#!/usr/bin/env python
# coding: utf-8

# ## Evaluating structured outputs in LangChain
#

# In[1]:


from langchain_core.prompts import ChatPromptTemplate
from langchain_anthropic import ChatAnthropic
from langchain_ollama import ChatOllama
from pydantic import BaseModel, Field


# Experiment parameters
#

# In[ ]:


claude_api_key = ""
experiment_date = "9-12-24"
n_iter = 2


# ### Prompt and problem setup
#
# For this test I’m going to start with a stand-in task: writing an article for a science magazine that answers a reader's question, with the response returned in a specific format.
#
# Here we specify the prompt and the inputs used to vary the problem (the list of questions).
#

# In[ ]:


test_science_prompt_txt = """
You are a professional science writer tasked with responding to members of the
general public who write in asking questions about science.
Write an article responding to a writer's question for publication in a science
magazine intended for a general readership with a high-school education.
You should write clearly and compellingly, include all relevant context, and
provide motivating stories where applicable.
Your response must be less than 200 words.

The question given to you is the following:
{question}
"""

questions = [
    "What is the oldest recorded fossil?",
    "What is a black hole?",
    "How far away is the sun?",
    "Which other planet in the Solar System has a surface gravity closest to that of the Earth?",
    "Eris, Haumea, Makemake and Ceres are all examples of what?",
    "Why does earth have seasons? Do other planets exhibit seasons too?",
    "What causes the aurora borealis?",
    "Why is the sky blue?",
    "How do bees communicate?",
    "What is the smallest unit of life?",
    "How do plants make their own food?",
    "Why do we dream?",
    "What is the theory of relativity?",
    "How do volcanoes erupt?",
    "What is the speed of light?",
    "How do magnets work?",
    "What is the purpose of DNA?",
    "What are the different types of galaxies?",
    "Why do some animals hibernate?",
    "How do vaccines work?",
]

prompt_direct = ChatPromptTemplate.from_template(test_science_prompt_txt)

prompt_system_format = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "Answer the user query.\n{format_instructions}",
        ),
        ("human", test_science_prompt_txt),
    ]
)

prompt_user_format = ChatPromptTemplate.from_template(
    test_science_prompt_txt + "\n{format_instructions}"
)
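# As a quick check (illustrative only), we can render the direct prompt for the first
# question to see exactly what the models receive:
#

# In[ ]:


print(prompt_direct.invoke(dict(question=questions[0])).to_string())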
# ### JSON output format specs
#
# #### Pydantic structures
#
# To see how these models and output methods cope with schemas of different complexity, I define four example schemas in increasing order of complexity.
#

# In[ ]:


# Simple types
class ArticleResponse1(BaseModel):
    """Structured article for publication answering a reader's question."""

    title: str = Field(description="Title of the article")
    answer: str = Field(
        description="Provide a detailed description of historical events to answer the question."
    )
    number: int = Field(
        description="An arbitrary number that is most relevant to the question."
    )


# Nested types
class HistoricalEvent(BaseModel):
    """The year and explanation of a historical event."""

    year: int = Field(description="The year of the historical event")
    description: str = Field(
        description="A clear description of what happened in this event"
    )


class ArticleResponse2(BaseModel):
    """Structured article for publication answering a reader's question."""

    title: str = Field(description="Title of the article")
    historical_event_1: HistoricalEvent = Field(
        description="Provide a detailed description of one historical event to answer the question."
    )
    historical_event_2: HistoricalEvent = Field(
        description="Provide a detailed description of one historical event to answer the question."
    )


# Lists of simple types
class ArticleResponse3(BaseModel):
    """Structured article for publication answering a reader's question."""

    title: str = Field(description="Title of the article")
    further_questions: list[str] = Field(
        description="A list of related questions that may be of interest to the readers."
    )


# Lists of custom types
class ArticleResponse4(BaseModel):
    """Structured article for publication answering a reader's question."""

    title: str = Field(description="Title of the article")
    historical_timeline: list[HistoricalEvent] = Field(
        description="Provide a compelling account of the historical context of the question"
    )


structured_formats = [
    dict(pydantic=ArticleResponse1),
    dict(pydantic=ArticleResponse2),
    dict(pydantic=ArticleResponse3),
    dict(pydantic=ArticleResponse4),
]
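# As a minimal illustration (assuming Pydantic v2), here is what a valid instance of the
# nested schema looks like when serialised to JSON; this is the kind of object we are
# asking each model to produce:
#

# In[ ]:


example_article = ArticleResponse2(
    title="Example title",
    historical_event_1=HistoricalEvent(year=1905, description="An example first event."),
    historical_event_2=HistoricalEvent(year=1915, description="An example second event."),
)
print(example_article.model_dump_json(indent=2))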
# ### Models to evaluate
#

# In[ ]:


# Default temperature
temperature = 0.8


# In[ ]:


llm_models = {
    # "Anthropic_Sonnet": ChatAnthropic(
    #     model="claude-3-5-sonnet-20241022", api_key=claude_api_key
    # ),
    # "Anthropic_Haiku": ChatAnthropic(model="claude-3-5-haiku-20241022", api_key=claude_api_key),
    # "Anthropic_Haiku": ChatAnthropic(
    #     model="claude-3-haiku-20240307", api_key=claude_api_key
    # ),
    "Ollama_llama32": ChatOllama(model="llama3.2", temperature=temperature),
    "nemotron-mini": ChatOllama(model="nemotron-mini", temperature=temperature),
    "Ollama_gemma2": ChatOllama(model="gemma2", temperature=temperature),
    "Ollama_phi3": ChatOllama(model="phi3", temperature=temperature),
}

llm_models_jsonmode = {
    "Ollama_llama32": ChatOllama(
        model="llama3.2", format="json", temperature=temperature
    ),
    "nemotron-mini": ChatOllama(
        model="nemotron-mini", format="json", temperature=temperature
    ),
    "Ollama_gemma2": ChatOllama(model="gemma2", format="json", temperature=temperature),
    "Ollama_phi3": ChatOllama(model="phi3", format="json", temperature=temperature),
}


# ## Evaluation
#
# Let's loop over the different structured outputs and check adherence using the tool-calling API (structured output mode).
#
# ### Evaluate Tool Calling API with Pydantic objects
#
# Question: of the models that have tool calling, what complexity of structure can they support?
#
# #### Method 1 : Tool-calling API
#

# In[ ]:


structure_support_by_model = {}
n_questions = len(questions)

for model_name, llm_model in llm_models.items():
    structure_support_by_model[model_name] = {}
    for structure in structured_formats:
        pydantic_obj = structure["pydantic"]
        print(f"Model: {model_name} Output: {pydantic_obj.__name__}")

        # Iterate over questions
        output_valid = 0
        tool_use = 0
        error_messages = []
        outputs = []
        for kk in range(n_iter):
            for ii in range(n_questions):
                test_chain = prompt_direct | llm_model.with_structured_output(
                    pydantic_obj, include_raw=True
                )
                try:
                    output = test_chain.invoke(dict(question=questions[ii]))
                    tool_use += 1
                    if output["parsing_error"] is None:
                        output_valid += 1
                    else:
                        print(output["parsing_error"])
                        error_messages.append(output["parsing_error"])
                    outputs.append(output)
                except Exception as e:
                    print(f" Tool use error \n{type(e).__name__}: {e}")

        structure_support_by_model[model_name][pydantic_obj.__name__] = dict(
            valid=output_valid / (n_iter * n_questions),
            tool_use=tool_use / (n_iter * n_questions),
            errors=error_messages,
            outputs=outputs,
        )


# #### Method 2 : Output parser
#
# Let's do the same for the output parser formatting. Note that because many models seem to ignore these instructions, this takes a long time.
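# Before running the full sweep, here is a minimal sketch of a single output-parser call
# for one model, one schema and one question (illustrative only; the experiment below
# wraps exactly this pattern in a loop):
#

# In[ ]:


from langchain_core.output_parsers import PydanticOutputParser

sketch_parser = PydanticOutputParser(pydantic_object=ArticleResponse1)
sketch_prompt = prompt_user_format.partial(
    format_instructions=sketch_parser.get_format_instructions()
)
sketch_chain = sketch_prompt | llm_models["Ollama_llama32"] | sketch_parser
# Requires a running Ollama server with the llama3.2 model pulled:
# sketch_chain.invoke(dict(question=questions[0]))  # -> ArticleResponse1 on success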
# In[ ]:


from langchain_core.output_parsers import PydanticOutputParser


def run_experiment_with_op(prompt_format, llm_models, n_iter):
    ss_results = {}
    n_questions = len(questions)

    for model_name, llm_model in llm_models.items():
        ss_results[model_name] = {}
        for structure in structured_formats:
            pydantic_obj = structure["pydantic"]
            print(f"Model: {model_name} Output: {pydantic_obj.__name__}")

            # Iterate over questions
            output_valid = 0
            tool_use = 0  # not applicable to the output-parser method; kept so the result dicts share a structure
            error_messages = []
            outputs = []
            for kk in range(n_iter):
                for ii in range(n_questions):
                    parser = PydanticOutputParser(pydantic_object=pydantic_obj)
                    prompt = prompt_format.partial(
                        format_instructions=parser.get_format_instructions()
                    )
                    test_chain = prompt | llm_model | parser
                    try:
                        output = test_chain.invoke(dict(question=questions[ii]))
                        assert isinstance(output, pydantic_obj)
                        output_valid += 1
                        outputs.append(output)
                    except Exception as e:
                        print(f" Invalid output ({type(e)})")
                        error_messages.append(f"{type(e).__name__}, {e}")

            ss_results[model_name][pydantic_obj.__name__] = dict(
                valid=output_valid / (n_iter * n_questions),
                tool_use=tool_use / (n_iter * n_questions),
                errors=error_messages,
                outputs=outputs,
            )
    return ss_results


# In[ ]:


structure_support_by_model_op = run_experiment_with_op(
    prompt_user_format, llm_models, n_iter
)


# In[ ]:


structure_support_by_model_op_jsonmode = run_experiment_with_op(
    prompt_user_format, llm_models_jsonmode, n_iter
)


# In[ ]:


structure_support_by_model_op_system = run_experiment_with_op(
    prompt_system_format, llm_models, n_iter
)


# ### Error analysis
#

# In[ ]:


import pandas as pd


# In[ ]:


def results_to_df(ss_results, key="valid"):
    # The stored values are already fractions of (n_iter * n_questions) runs,
    # so converting to a percentage only needs the factor of 100.
    df = pd.DataFrame.from_dict(
        {
            mname: {
                tname: ss_results[mname][tname][key] * 100
                for tname in ss_results[mname].keys()
            }
            for mname in ss_results.keys()
        },
        orient="index",
    )
    return df


def analyse_errors_from_results(ss_results, key="errors"):
    # Classify errors by substring matching on the exception message.
    error_counts = {}
    for mname in ss_results.keys():
        error_counts[mname] = {}
        for tname in ss_results[mname].keys():
            validation_error = 0
            json_error = 0
            unknown_error = 0
            errors = ss_results[mname][tname][key]
            for error in errors:
                error_str = str(error)
                if error_str.lower().find("invalid json output") >= 0:
                    json_error += 1
                elif error_str.lower().find("validation error") >= 0:
                    validation_error += 1
                else:
                    unknown_error += 1
            error_counts[mname][(tname, "invalid_json")] = json_error
            error_counts[mname][(tname, "validation")] = validation_error
            error_counts[mname][(tname, "unknown")] = unknown_error
    return pd.DataFrame.from_dict(error_counts, orient="index")


# In[ ]:


errors_df = analyse_errors_from_results(structure_support_by_model_op, "errors")
errors_df


# In[ ]:


errors_df = analyse_errors_from_results(structure_support_by_model_op_system, "errors")
errors_df


# In[ ]:


errors_df = analyse_errors_from_results(
    structure_support_by_model_op_jsonmode, "errors"
)
errors_df


# In[ ]:


structure_support_by_model_op_jsonmode["Ollama_llama32"]["ArticleResponse2"]["errors"]


# In[ ]:


# Inspect the raw tool-call arguments recorded for a few of the llama3.2 outputs
for ii in range(10):
    try:
        print(
            structure_support_by_model["Ollama_llama32"]["ArticleResponse2"]["outputs"][
                ii
            ]["raw"].response_metadata["message"]["tool_calls"][0]["function"][
                "arguments"
            ]
        )
        print()
    except (KeyError, IndexError, AttributeError):
        print("No tool call recorded for this output")


# Errors in tool use
#

# In[ ]:


(
    pd.Series(
        [
            type(e)
            for exp in structure_support_by_model["Ollama_llama32"].values()
            for e in exp["errors"]
        ]
    )
).value_counts()


# In[ ]:


(
    pd.Series(
        [
            e.split(",")[0]
            for exp in structure_support_by_model_op["Ollama_llama32"].values()
            for e in exp["errors"]
        ]
    )
).value_counts()


# ### Results
#

# In[ ]:


import pandas as pd


# In[ ]:


results_list = {
    "Tool-calling API": structure_support_by_model,
    "Output Parser User": structure_support_by_model_op,
    "Output Parser JSONMode": structure_support_by_model_op_jsonmode,
    "Output Parser System": structure_support_by_model_op_system,
}

df_results = {}
for name, ss_results in results_list.items():
    df_results[name] = pd.DataFrame.from_dict(
        {
            mname: {
                tname: ss_results[mname][tname]["valid"] * 100
                for tname in ss_results[mname].keys()
            }
            for mname in ss_results.keys()
        },
        orient="index",
    )
    display(name)


# In[ ]:


df = pd.concat(df_results)
df


# In[ ]:


import tabulate

print(
    tabulate.tabulate(
        df.reset_index(), headers="keys", tablefmt="pipe", showindex=False
    )
)


# Save results
#

# In[ ]:


import pickle

with open(file=f"exp4_summary_df_{experiment_date}.json", mode="w") as f:
    df.to_json(f)

with open(file=f"exp4_all_models_{experiment_date}.pkl", mode="wb") as f:
    pickle.dump(
        dict(
            structure_support_by_model=structure_support_by_model,
            structure_support_by_model_op=structure_support_by_model_op,
            structure_support_by_model_op_system=structure_support_by_model_op_system,
            structure_support_by_model_op_jsonmode=structure_support_by_model_op_jsonmode,
        ),
        f,
    )


# Load results
#

# In[ ]:


import pickle
import pandas as pd

with open(file=f"exp4_summary_df_{experiment_date}.json", mode="rb") as f:
    df = pd.read_json(f)

with open(file=f"exp4_all_models_{experiment_date}.pkl", mode="rb") as f:
    data = pickle.load(f)

# Inject the saved result dicts into the top-level namespace without
# overwriting anything already defined in this session
namespace = locals()
for key, value in data.items():
    if key not in namespace:
        print(f"Loaded {key}")
        namespace[key] = value
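# As an optional visual summary (an illustrative sketch; assumes matplotlib is installed),
# the adherence percentages for the tool-calling method can be plotted straight from the
# combined dataframe:
#

# In[ ]:


import matplotlib.pyplot as plt

ax = df.loc["Tool-calling API"].plot(kind="bar", figsize=(10, 4))
ax.set_ylabel("% valid structured outputs")
ax.set_title("Tool-calling API: schema adherence by model")
plt.tight_layout()
plt.show()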