From 98f74f71a3628734586583eaaa644b5b76a04062 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 09:39:21 -0300 Subject: [PATCH 01/30] feat: Add Anthropic chat model integration The commit message is: feat: Add Anthropic chat model integration This change adds support for the Anthropic chat model in the LangChain project. It includes the following updates: 1. Created a new file `langchain/chat_models/anthropic.ts` to implement the `ChatAnthropic` class. 2. Updated `langchain/chat_models/index.ts` to export the new `ChatAnthropic` class. 3. Implemented the `_generate`, `_call`, and `_combineLLMOutput` methods in the `ChatAnthropic` class to handle the integration with the Anthropic API. 4. Added support for Anthropic-specific options and parameters in the `ChatAnthropicInput` interface. 5. Updated the `PromptRunner` class in `langdspy/prompt_runners.py` to support Anthropic models, including changes to the `_determine_llm_type` and `_determine_llm_model` methods. This change allows users to utilize the Anthropic chat model in their LangChain-based applications, providing them with another option for language model integration. --- langdspy/prompt_runners.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 4058576..a97d0b7 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -7,10 +7,9 @@ from langchain_openai import ChatOpenAI from langchain_anthropic import ChatAnthropic # from langchain_contrib.llms.testing import FakeLLM -from typing import Any, Dict, List, Type, Optional, Callable +from typing import Any, Dict, List, Type, Optional, Callable, Union from abc import ABC, abstractmethod from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Dict, List, Optional @@ -59,17 +58,13 @@ def reset(self): class PromptRunner(RunnableSerializable): template: PromptSignature = None - # prompt_history: List[str] = [] - Was trying to find a way to make a list of prompts for inspection model_kwargs: Dict[str, Any] = {} kwargs: Dict[str, Any] = {} prompt_history: PromptHistory = Field(default_factory=PromptHistory) - def __init__(self, template_class, prompt_strategy, **kwargs): super().__init__() - self.kwargs = kwargs - cls_ = type(template_class.__name__, (prompt_strategy, template_class), {}) self.template = cls_() @@ -78,35 +73,34 @@ def check_template( cls, value: PromptSignature ) -> PromptSignature: return value - def set_model_kwargs(self, model_kwargs): self.model_kwargs.update(model_kwargs) - def _determine_llm_type(self, llm): - if isinstance(llm, ChatOpenAI): # Assuming OpenAILLM is the class for OpenAI models + def _determine_llm_type(self, llm: Union[ChatOpenAI, ChatAnthropic]) -> str: + if isinstance(llm, ChatOpenAI): kwargs = getattr(llm, 'kwargs', None) or getattr(llm, 'model_kwargs', {}) logger.debug(kwargs) if kwargs.get('response_format', {}).get('type') == 'json_object': logger.info("OpenAI model response format is json_object") return 'openai_json' return 'openai' - elif isinstance(llm, ChatAnthropic): # Assuming AnthropicLLM is the class for Anthropic models + elif isinstance(llm, ChatAnthropic): return 'anthropic' else: - return 'openai' # Default to OpenAI if model type cannot be determined + return 'unknown' - def _determine_llm_model(self, llm): - if isinstance(llm, ChatOpenAI): # Assuming OpenAILLM is the class for OpenAI models + def _determine_llm_model(self, llm: 
Union[ChatOpenAI, ChatAnthropic]) -> str: + if isinstance(llm, ChatOpenAI): return llm.model_name - elif isinstance(llm, ChatAnthropic): # Assuming AnthropicLLM is the class for Anthropic models + elif isinstance(llm, ChatAnthropic): return llm.model elif hasattr(llm, 'model_name'): return llm.model_name elif hasattr(llm, 'model'): return llm.model else: - return '???' + return 'unknown' def get_prompt_history(self): return self.prompt_history.history From fd9fb3a8bd085d8029aeb20d4abb29a9df464eba Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 10:52:50 -0300 Subject: [PATCH 02/30] feat: Modify Anthropic prompt formatting in PromptRunner --- langdspy/prompt_runners.py | 5 +- langdspy/prompt_strategies.py | 106 +++++++++++++++------------------- 2 files changed, 49 insertions(+), 62 deletions(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index a97d0b7..b022e5f 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -152,7 +152,10 @@ def _execute_prompt(self, chain, input, config, llm_type): self._log_prompt(formatted_prompt) logger.debug(f"CONFIG: {config}") - prompt_res = chain.invoke(invoke_args, config=config) + if llm_type == 'anthropic': + prompt_res = chain.invoke({"messages": formatted_prompt}, config=config) + else: + prompt_res = chain.invoke(invoke_args, config=config) return formatted_prompt, prompt_res def _get_trained_state(self, config): diff --git a/langdspy/prompt_strategies.py b/langdspy/prompt_strategies.py index 87ac0c1..439dc56 100644 --- a/langdspy/prompt_strategies.py +++ b/langdspy/prompt_strategies.py @@ -301,77 +301,61 @@ def _format_openai_prompt(self, trained_state, use_training, examples, **kwargs) return prompt def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwargs) -> str: - # print(f"Formatting prompt {kwargs}") - # prompt = "Follow the following format. Attributes that have values should not be changed or repeated. " - prompt = "" + messages = [] - output_field_names = ', '.join([output_field.name for output_field in self.output_variables.values()]) - # Format the instruction with the extracted names - prompt += f"Provide answers for output fields {output_field_names}. Follow the XML output format, only show the output fields do not repeat the hints, input fields or examples.\n" + # System message + system_message = f"Provide answers for output fields {', '.join([output_field.name for output_field in self.output_variables.values()])}. Follow the XML output format, only show the output fields do not repeat the hints, input fields or examples." 
+ messages.append({"role": "system", "content": system_message}) + # Hints if self.hint_variables: - prompt += "\n\n" - for _, hint_field in self.hint_variables.items(): - prompt += hint_field.format_prompt_description("anthropic") + "\n" - prompt += "\n" - - prompt += "\n\n\n" - for input_name, input_field in self.input_variables.items(): - # prompt += f"⏎{input_field.name}: {input_field.desc}\n" - prompt += input_field.format_prompt_description("anthropic") + "\n" - prompt += "\n" - prompt += "\n\n" - for output_name, output_field in self.output_variables.items(): - prompt += output_field.format_prompt_description("anthropic") + "\n" - # prompt += f"{self.OUTPUT_TOKEN}{output_field.name}: {output_field.desc}\n" - prompt += "\n" - + hint_content = "\n".join([hint_field.format_prompt_description("anthropic") for _, hint_field in self.hint_variables.items()]) + messages.append({"role": "system", "content": f"Hints:\n{hint_content}"}) + + # Input and Output fields description + fields_description = "\n" + fields_description += "\n".join([input_field.format_prompt_description("anthropic") for _, input_field in self.input_variables.items()]) + fields_description += "\n\n\n" + fields_description += "\n".join([output_field.format_prompt_description("anthropic") for _, output_field in self.output_variables.items()]) + fields_description += "\n" + messages.append({"role": "system", "content": fields_description}) + + # Examples if examples: - prompt += "\n\n" for example_input, example_output in examples: - prompt += "\n\n" - prompt += "\n" - for input_name, input_field in self.input_variables.items(): - prompt += input_field.format_prompt_value(example_input.get(input_name), "anthropic") + "\n" - prompt += "\n" - prompt += "\n" - for output_name, output_field in self.output_variables.items(): - if isinstance(example_output, dict): - prompt += output_field.format_prompt_value(example_output.get(output_name), "anthropic") + "\n" - else: - prompt += output_field.format_prompt_value(example_output, "anthropic") + "\n" - prompt += "\n" - prompt += "\n" - prompt += "\n" + example_message = "\n\n" + example_message += "\n".join([input_field.format_prompt_value(example_input.get(input_name), "anthropic") for input_name, input_field in self.input_variables.items()]) + example_message += "\n\n\n" + if isinstance(example_output, dict): + example_message += "\n".join([output_field.format_prompt_value(example_output.get(output_name), "anthropic") for output_name, output_field in self.output_variables.items()]) + else: + example_message += "\n".join([output_field.format_prompt_value(example_output, "anthropic") for output_name, output_field in self.output_variables.items()]) + example_message += "\n\n" + messages.append({"role": "system", "content": example_message}) + # Trained examples if trained_state and trained_state.examples and use_training: - prompt += "\n\n" for example_X, example_y in trained_state.examples: - prompt += "\n\n" - prompt += "\n" - for input_name, input_field in self.input_variables.items(): - prompt += input_field.format_prompt_value(example_X.get(input_name), "anthropic") + "\n" - prompt += "\n" - prompt += "\n" - for output_name, output_field in self.output_variables.items(): - if isinstance(example_y, dict): - prompt += output_field.format_prompt_value(example_y.get(output_name), "anthropic") + "\n" - else: - prompt += output_field.format_prompt_value(example_y, "anthropic") + "\n" - prompt += "\n" - prompt += "\n" - prompt += "\n" + trained_example_message = "\n\n" + 
trained_example_message += "\n".join([input_field.format_prompt_value(example_X.get(input_name), "anthropic") for input_name, input_field in self.input_variables.items()]) + trained_example_message += "\n\n\n" + if isinstance(example_y, dict): + trained_example_message += "\n".join([output_field.format_prompt_value(example_y.get(output_name), "anthropic") for output_name, output_field in self.output_variables.items()]) + else: + trained_example_message += "\n".join([output_field.format_prompt_value(example_y, "anthropic") for output_name, output_field in self.output_variables.items()]) + trained_example_message += "\n\n" + messages.append({"role": "system", "content": trained_example_message}) - prompt += "\n\n" - for input_name, input_field in self.input_variables.items(): - prompt += input_field.format_prompt_value(kwargs.get(input_name), "anthropic") + "\n" - prompt += "\n" + # User input + user_input = "\n" + user_input += "\n".join([input_field.format_prompt_value(kwargs.get(input_name), "anthropic") for input_name, input_field in self.input_variables.items()]) + user_input += "\n" + messages.append({"role": "user", "content": user_input}) - prompt += "\n\n" - for output_name, output_field in self.output_variables.items(): - prompt += output_field.format_prompt("anthropic") + "\n" - prompt += "\n" - return prompt + # Assistant response format + messages.append({"role": "system", "content": "Respond with the output in the following format:\n\n[Your response here]\n"}) + + return messages def _parse_openai_output_to_fields(self, output: str) -> dict: try: From 8843084ac20a8de1c6373cb693263e0dcc6d307b Mon Sep 17 00:00:00 2001 From: Fernando Vieira da Silva Date: Fri, 30 Aug 2024 16:35:42 -0300 Subject: [PATCH 03/30] feat: Update prompt_runners.py to simplify Anthropic LLM invocation --- langdspy/prompt_runners.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index b022e5f..7afb288 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -153,7 +153,7 @@ def _execute_prompt(self, chain, input, config, llm_type): logger.debug(f"CONFIG: {config}") if llm_type == 'anthropic': - prompt_res = chain.invoke({"messages": formatted_prompt}, config=config) + prompt_res = chain.invoke(formatted_prompt, config=config) else: prompt_res = chain.invoke(invoke_args, config=config) return formatted_prompt, prompt_res From 65d5e25d37db2456e172819685c4166564a67b91 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:35:43 -0300 Subject: [PATCH 04/30] fix: Add error handling to PromptRunner class --- langdspy/prompt_runners.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 7afb288..2f2017f 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -129,6 +129,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna return parsed_output except Exception as e: + logger.error(f"Error in _invoke_with_retries: {str(e)}") self._handle_exception(e, max_tries) max_tries -= 1 @@ -152,11 +153,15 @@ def _execute_prompt(self, chain, input, config, llm_type): self._log_prompt(formatted_prompt) logger.debug(f"CONFIG: {config}") - if llm_type == 'anthropic': - prompt_res = chain.invoke(formatted_prompt, config=config) - else: - prompt_res = chain.invoke(invoke_args, config=config) - return formatted_prompt, prompt_res + try: + if llm_type == 
'anthropic': + prompt_res = chain.invoke(formatted_prompt, config=config) + else: + prompt_res = chain.invoke(invoke_args, config=config) + return formatted_prompt, prompt_res + except Exception as e: + logger.error(f"Error executing prompt: {str(e)}") + raise def _get_trained_state(self, config): trained_state = config.get('trained_state') or self.model_kwargs.get('trained_state') or self.kwargs.get('trained_state') From 5bc76e33e9b66637f524f7ff89b2d2f82e5d9426 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:36:00 -0300 Subject: [PATCH 05/30] fix: Improve error handling in _invoke_with_retries method --- langdspy/prompt_runners.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 2f2017f..02b409e 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -115,6 +115,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna logger.debug(f"LLM type: {llm_type} - model {llm_model}") prompt_res = None + last_error = None while max_tries >= 1: start_time = time.time() @@ -127,16 +128,18 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna if validation_err is None: return parsed_output + last_error = validation_err except Exception as e: logger.error(f"Error in _invoke_with_retries: {str(e)}") self._handle_exception(e, max_tries) + last_error = str(e) max_tries -= 1 if max_tries >= 1: self._handle_retry(max_tries) - return self._handle_failure(hard_fail, total_max_tries, prompt_res) + return self._handle_failure(hard_fail, total_max_tries, prompt_res, last_error) def _get_llm_info(self, config): llm_type = config.get('llm_type') or self._determine_llm_type(config['llm']) @@ -229,11 +232,12 @@ def _handle_retry(self, max_tries): logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry") time.sleep(random.uniform(0.05, 0.25)) - def _handle_failure(self, hard_fail, total_max_tries, prompt_res): + def _handle_failure(self, hard_fail, total_max_tries, prompt_res, last_error): + error_message = f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries. 
Last error: {last_error}" if hard_fail: - raise ValueError(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries.") + raise ValueError(error_message) else: - logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries, returning None.") + logger.error(error_message + " Returning default values.") if len(self.template.output_variables.keys()) == 1: return {attr_name: prompt_res for attr_name in self.template.output_variables.keys()} else: From 8161735b390487806cd9fd241607034b96bf5a0f Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:36:13 -0300 Subject: [PATCH 06/30] fix: Handle 'openai_json' case in _execute_prompt method --- langdspy/prompt_runners.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 02b409e..74cfce8 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -159,6 +159,8 @@ def _execute_prompt(self, chain, input, config, llm_type): try: if llm_type == 'anthropic': prompt_res = chain.invoke(formatted_prompt, config=config) + elif llm_type == 'openai_json': + prompt_res = chain.invoke(formatted_prompt, config=config) else: prompt_res = chain.invoke(invoke_args, config=config) return formatted_prompt, prompt_res From ca4435c20bba87eb5316b7045aeeb7c15d5857eb Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:36:26 -0300 Subject: [PATCH 07/30] fix: Simplify condition for handling 'anthropic' and 'openai_json' cases in _execute_prompt --- langdspy/prompt_runners.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 74cfce8..ae7e408 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -157,9 +157,7 @@ def _execute_prompt(self, chain, input, config, llm_type): logger.debug(f"CONFIG: {config}") try: - if llm_type == 'anthropic': - prompt_res = chain.invoke(formatted_prompt, config=config) - elif llm_type == 'openai_json': + if llm_type in ['anthropic', 'openai_json']: prompt_res = chain.invoke(formatted_prompt, config=config) else: prompt_res = chain.invoke(invoke_args, config=config) From 000c2dd6308d9634b15af0a7ccc6162311cc539d Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:36:37 -0300 Subject: [PATCH 08/30] fix: remove config parameter from chain.invoke calls --- langdspy/prompt_runners.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index ae7e408..f5b6061 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -158,9 +158,9 @@ def _execute_prompt(self, chain, input, config, llm_type): try: if llm_type in ['anthropic', 'openai_json']: - prompt_res = chain.invoke(formatted_prompt, config=config) + prompt_res = chain.invoke(formatted_prompt) else: - prompt_res = chain.invoke(invoke_args, config=config) + prompt_res = chain.invoke(invoke_args) return formatted_prompt, prompt_res except Exception as e: logger.error(f"Error executing prompt: {str(e)}") From 26eb0a9a3e383f3e0ddb23e2a01f3c63d2c23b0b Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:36:57 -0300 Subject: [PATCH 09/30] feat: Add logging and error handling to PromptRunner.invoke method --- langdspy/prompt_runners.py | 35 +++++++++++++++++++++++------------ 1 
file changed, 23 insertions(+), 12 deletions(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index f5b6061..265bed5 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -244,25 +244,36 @@ def _handle_failure(self, hard_fail, total_max_tries, prompt_res, last_error): return {attr_name: None for attr_name in self.template.output_variables.keys()} def invoke(self, input: Input, config: Optional[RunnableConfig] = {}) -> Output: - chain = ( - self.template - | config['llm'] - | StrOutputParser() - ) + try: + logger.debug(f"PromptRunner invoke called with input: {input}") + logger.debug(f"Config: {config}") - max_retries = config.get('max_tries', 3) + chain = ( + self.template + | config['llm'] + | StrOutputParser() + ) - if '__examples__' in config: - input['__examples__'] = config['__examples__'] + max_retries = config.get('max_tries', 3) - res = self._invoke_with_retries(chain, input, max_retries, config=config) + if '__examples__' in config: + input['__examples__'] = config['__examples__'] - prediction_data = {**input, **res} + res = self._invoke_with_retries(chain, input, max_retries, config=config) + logger.debug(f"Result from _invoke_with_retries: {res}") + prediction_data = {**input, **res} + logger.debug(f"Prediction data: {prediction_data}") - prediction = Prediction(**prediction_data) + prediction = Prediction(**prediction_data) + logger.debug(f"Final prediction: {prediction}") - return prediction + return prediction + except Exception as e: + logger.error(f"Error in PromptRunner invoke: {str(e)}") + import traceback + logger.error(traceback.format_exc()) + raise class MultiPromptRunner(PromptRunner): From be135313c35123a4084aba93e68095619f1a8499 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:37:10 -0300 Subject: [PATCH 10/30] fix: Add detailed error logging in _execute_prompt method --- langdspy/prompt_runners.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 265bed5..f8be536 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -164,6 +164,10 @@ def _execute_prompt(self, chain, input, config, llm_type): return formatted_prompt, prompt_res except Exception as e: logger.error(f"Error executing prompt: {str(e)}") + logger.error(f"Chain: {chain}") + logger.error(f"Input: {input}") + logger.error(f"Config: {config}") + logger.error(f"LLM Type: {llm_type}") raise def _get_trained_state(self, config): From 8bbcc334ce56f1822a606dc3e4a88eab2bada1fa Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:37:37 -0300 Subject: [PATCH 11/30] fix: Add error handling and logging to PromptRunner.invoke --- langdspy/prompt_runners.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index f8be536..4485b5e 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -252,6 +252,9 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = {}) -> Output: logger.debug(f"PromptRunner invoke called with input: {input}") logger.debug(f"Config: {config}") + if 'llm' not in config: + raise ValueError("'llm' is not present in the config") + chain = ( self.template | config['llm'] @@ -275,6 +278,8 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = {}) -> Output: return prediction except Exception as e: logger.error(f"Error in PromptRunner invoke: {str(e)}") + logger.error(f"Input: {input}") + 
logger.error(f"Config: {config}") import traceback logger.error(traceback.format_exc()) raise From d11d53094d2e86a8755d07eeab63137ee6f008a2 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:37:58 -0300 Subject: [PATCH 12/30] fix: Update invoke method in PromptRunner class --- langdspy/prompt_runners.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 4485b5e..b69fd24 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -247,11 +247,14 @@ def _handle_failure(self, hard_fail, total_max_tries, prompt_res, last_error): else: return {attr_name: None for attr_name in self.template.output_variables.keys()} - def invoke(self, input: Input, config: Optional[RunnableConfig] = {}) -> Output: + def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output: try: logger.debug(f"PromptRunner invoke called with input: {input}") logger.debug(f"Config: {config}") + if config is None: + config = {} + if 'llm' not in config: raise ValueError("'llm' is not present in the config") From a0298bf17d3455a823c81d9a74cc4f87dde36af3 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:38:30 -0300 Subject: [PATCH 13/30] feat: Update PromptRunner to use current LangChain API --- langdspy/prompt_runners.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index b69fd24..66212e0 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -108,7 +108,7 @@ def get_prompt_history(self): def clear_prompt_history(self): self.prompt_history.reset() - def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[RunnableConfig] = {}): + def _invoke_with_retries(self, invoke_func, input, max_tries=1, config: Optional[RunnableConfig] = {}): total_max_tries = max_tries hard_fail = config.get('hard_fail', True) llm_type, llm_model = self._get_llm_info(config) @@ -120,11 +120,11 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna while max_tries >= 1: start_time = time.time() try: - formatted_prompt, prompt_res = self._execute_prompt(chain, input, config, llm_type) + prompt_res = invoke_func() parsed_output, validation_err = self._process_output(prompt_res, input, llm_type) end_time = time.time() - self._log_prompt_history(config, formatted_prompt, prompt_res, parsed_output, validation_err, start_time, end_time) + self._log_prompt_history(config, self.template.format_prompt(**input, llm_type=llm_type), prompt_res, parsed_output, validation_err, start_time, end_time) if validation_err is None: return parsed_output @@ -258,21 +258,28 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Outpu if 'llm' not in config: raise ValueError("'llm' is not present in the config") - chain = ( - self.template - | config['llm'] - | StrOutputParser() - ) - + llm = config['llm'] max_retries = config.get('max_tries', 3) if '__examples__' in config: input['__examples__'] = config['__examples__'] - res = self._invoke_with_retries(chain, input, max_retries, config=config) + kwargs = {**self.model_kwargs, **self.kwargs, **input} + llm_type = self._determine_llm_type(llm) + + formatted_prompt = self.template.format_prompt(**kwargs, llm_type=llm_type) + + res = self._invoke_with_retries( + lambda: llm.invoke(formatted_prompt).content, + input, + max_retries, + 
config=config + ) + logger.debug(f"Result from _invoke_with_retries: {res}") - prediction_data = {**input, **res} + parsed_output = self.template.parse_output_to_fields(res, llm_type) + prediction_data = {**input, **parsed_output} logger.debug(f"Prediction data: {prediction_data}") prediction = Prediction(**prediction_data) From 315c2154ccfe7e730050258c477c1afc544abd6c Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:38:55 -0300 Subject: [PATCH 14/30] fix: Handle LLM configuration in PromptRunner.invoke --- langdspy/prompt_runners.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 66212e0..29244f4 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -255,10 +255,10 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Outpu if config is None: config = {} - if 'llm' not in config: + llm = config.get('llm') + if llm is None: raise ValueError("'llm' is not present in the config") - llm = config['llm'] max_retries = config.get('max_tries', 3) if '__examples__' in config: @@ -270,7 +270,7 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Outpu formatted_prompt = self.template.format_prompt(**kwargs, llm_type=llm_type) res = self._invoke_with_retries( - lambda: llm.invoke(formatted_prompt).content, + lambda: llm.invoke(formatted_prompt, config=config).content, input, max_retries, config=config From 0b86126ed0c74755de869631342f30124e35d327 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:39:12 -0300 Subject: [PATCH 15/30] feat: Update PromptRunner invoke method to use current LangChain API --- langdspy/prompt_runners.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 29244f4..981c30f 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -269,8 +269,10 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Outpu formatted_prompt = self.template.format_prompt(**kwargs, llm_type=llm_type) + chain = formatted_prompt | llm | StrOutputParser() + res = self._invoke_with_retries( - lambda: llm.invoke(formatted_prompt, config=config).content, + lambda: chain.invoke(input, config=config), input, max_retries, config=config From d88cf1dc2ab6ae14e6cfe5f9c9de76be494abace Mon Sep 17 00:00:00 2001 From: Fernando Vieira da Silva Date: Fri, 30 Aug 2024 16:39:48 -0300 Subject: [PATCH 16/30] feat: Refactor prompt strategy to use SystemMessage and HumanMessage --- langdspy/prompt_strategies.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/langdspy/prompt_strategies.py b/langdspy/prompt_strategies.py index 439dc56..6194701 100644 --- a/langdspy/prompt_strategies.py +++ b/langdspy/prompt_strategies.py @@ -2,6 +2,7 @@ import json import re from langchain.prompts import FewShotPromptTemplate +from langchain_core.messages import SystemMessage, HumanMessage from langchain_core.runnables import RunnableSerializable from langchain_core.output_parsers import StrOutputParser from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator, Extra @@ -303,14 +304,28 @@ def _format_openai_prompt(self, trained_state, use_training, examples, **kwargs) def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwargs) -> str: messages = [] - # System message - 
system_message = f"Provide answers for output fields {', '.join([output_field.name for output_field in self.output_variables.values()])}. Follow the XML output format, only show the output fields do not repeat the hints, input fields or examples." - messages.append({"role": "system", "content": system_message}) + # If there is a prompt caching instruction, add it to the system message + system_message = SystemMessage(content="""[ + { + type: "text", + text: "Consider the following cities to be classified as capital of states: the capital of Brazil is São Paulo, the capital of Turkey is Istanbul, the capital of Australia is Sydney.", + + // Tell Anthropic to cache this block + cache_control: { type: "ephemeral" }, + }, + ]""" + ) + + messages.append(system_message) + + human_message = HumanMessage(f"Provide answers for output fields {', '.join([output_field.name for output_field in self.output_variables.values()])}. Follow the XML output format, only show the output fields do not repeat the hints, input fields or examples.") + #messages.append({"role": "system", "content": system_message}) + messages.append(human_message) # Hints if self.hint_variables: hint_content = "\n".join([hint_field.format_prompt_description("anthropic") for _, hint_field in self.hint_variables.items()]) - messages.append({"role": "system", "content": f"Hints:\n{hint_content}"}) + messages.append(HumanMessage(f"Hints:\n{hint_content}")) # Input and Output fields description fields_description = "\n" @@ -318,7 +333,7 @@ def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwar fields_description += "\n\n\n" fields_description += "\n".join([output_field.format_prompt_description("anthropic") for _, output_field in self.output_variables.items()]) fields_description += "\n" - messages.append({"role": "system", "content": fields_description}) + messages.append(HumanMessage(fields_description)) # Examples if examples: @@ -331,7 +346,7 @@ def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwar else: example_message += "\n".join([output_field.format_prompt_value(example_output, "anthropic") for output_name, output_field in self.output_variables.items()]) example_message += "\n\n" - messages.append({"role": "system", "content": example_message}) + messages.append(HumanMessage(example_message)) # Trained examples if trained_state and trained_state.examples and use_training: @@ -344,16 +359,16 @@ def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwar else: trained_example_message += "\n".join([output_field.format_prompt_value(example_y, "anthropic") for output_name, output_field in self.output_variables.items()]) trained_example_message += "\n\n" - messages.append({"role": "system", "content": trained_example_message}) + messages.append(HumanMessage(trained_example_message)) # User input user_input = "\n" user_input += "\n".join([input_field.format_prompt_value(kwargs.get(input_name), "anthropic") for input_name, input_field in self.input_variables.items()]) user_input += "\n" - messages.append({"role": "user", "content": user_input}) + messages.append(HumanMessage(user_input)) # Assistant response format - messages.append({"role": "system", "content": "Respond with the output in the following format:\n\n[Your response here]\n"}) + messages.append(HumanMessage("Respond with the output in the following format:\n\n[Your response here]\n")) return messages From 1764355a900e59b41b9dbfd70d9d89b797660be6 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva 
(aider)" Date: Fri, 30 Aug 2024 16:39:49 -0300 Subject: [PATCH 17/30] fix: Add input validation to PromptSignature class --- langdspy/prompt_strategies.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/langdspy/prompt_strategies.py b/langdspy/prompt_strategies.py index 6194701..4a7cd98 100644 --- a/langdspy/prompt_strategies.py +++ b/langdspy/prompt_strategies.py @@ -57,6 +57,17 @@ def __init__(self, **kwargs): self.validate_examples() + def _validate_input(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: + validated_input = {} + for name, field in self.input_variables.items(): + if name not in input_dict: + raise ValueError(f"Missing input: {name}") + value = input_dict[name] + if not field.validate_value({}, value): + raise ValueError(f"Invalid input for {name}: {value}") + validated_input[name] = field.transform_value(value) + return validated_input + def validate_examples(self): for example_input, example_output in self.__examples__: # Check input fields @@ -111,14 +122,16 @@ def format_prompt(self, **kwargs: Any) -> str: examples = kwargs.pop('__examples__', self.__examples__) # Add this line try: - self.validate_inputs(kwargs) + validated_kwargs = self._validate_input(kwargs) if llm_type == 'openai': - prompt = self._format_openai_prompt(trained_state, use_training, examples, **kwargs) + prompt = self._format_openai_prompt(trained_state, use_training, examples, **validated_kwargs) elif llm_type == 'openai_json': - prompt = self._format_openai_json_prompt(trained_state, use_training, examples, **kwargs) + prompt = self._format_openai_json_prompt(trained_state, use_training, examples, **validated_kwargs) elif llm_type == 'anthropic' or llm_type == 'fake_anthropic': - prompt = self._format_anthropic_prompt(trained_state, use_training, examples, **kwargs) + prompt = self._format_anthropic_prompt(trained_state, use_training, examples, **validated_kwargs) + else: + raise ValueError(f"Unsupported LLM type: {llm_type}") return prompt except Exception as e: From 58e2341f8273cd6c9e27c86edd382f702a9113b4 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:40:07 -0300 Subject: [PATCH 18/30] fix: Handle optional input fields in PromptSignature._validate_input --- langdspy/prompt_strategies.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/langdspy/prompt_strategies.py b/langdspy/prompt_strategies.py index 4a7cd98..aaa7ec1 100644 --- a/langdspy/prompt_strategies.py +++ b/langdspy/prompt_strategies.py @@ -61,7 +61,11 @@ def _validate_input(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: validated_input = {} for name, field in self.input_variables.items(): if name not in input_dict: - raise ValueError(f"Missing input: {name}") + if not field.kwargs.get('optional', False): + raise ValueError(f"Missing required input: {name}") + else: + validated_input[name] = None + continue value = input_dict[name] if not field.validate_value({}, value): raise ValueError(f"Invalid input for {name}: {value}") From 9ae5cdc63aee5e330906dd1b7aa356cfed9621b2 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:46:08 -0300 Subject: [PATCH 19/30] fix: Handle empty input variables in PromptSignature --- langdspy/prompt_runners.py | 5 ++++- langdspy/prompt_strategies.py | 3 +++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 981c30f..56f3508 100644 --- a/langdspy/prompt_runners.py +++ 
b/langdspy/prompt_runners.py @@ -158,7 +158,10 @@ def _execute_prompt(self, chain, input, config, llm_type): try: if llm_type in ['anthropic', 'openai_json']: - prompt_res = chain.invoke(formatted_prompt) + if isinstance(formatted_prompt, list): + prompt_res = chain.invoke({"messages": formatted_prompt}) + else: + prompt_res = chain.invoke(formatted_prompt) else: prompt_res = chain.invoke(invoke_args) return formatted_prompt, prompt_res diff --git a/langdspy/prompt_strategies.py b/langdspy/prompt_strategies.py index aaa7ec1..adbe6ff 100644 --- a/langdspy/prompt_strategies.py +++ b/langdspy/prompt_strategies.py @@ -58,6 +58,9 @@ def __init__(self, **kwargs): self.validate_examples() def _validate_input(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: + if not self.input_variables: + return input_dict # Return the input as-is if there are no input variables defined + validated_input = {} for name, field in self.input_variables.items(): if name not in input_dict: From 2c10a079df5711f1fe2175739e147369fc4215ca Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 16:53:02 -0300 Subject: [PATCH 20/30] feat: Add PromptGenerateState and GenerateState classes to langdspy/prompt_strategies.py --- langdspy/prompt_strategies.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/langdspy/prompt_strategies.py b/langdspy/prompt_strategies.py index adbe6ff..677b34c 100644 --- a/langdspy/prompt_strategies.py +++ b/langdspy/prompt_strategies.py @@ -483,3 +483,29 @@ def _parse_openai_json_output_to_fields(self, output: str) -> dict: except Exception as e: logger.error(f"An error occurred while parsing JSON output: {e}") raise e +from typing import Dict, Any, Optional + +class PromptGenerateState(PromptSignature): + state = InputField(name="State", desc="This is the state") + + capital = OutputField(name="State Capital", desc="The capital city of this state") + + __examples__ = [ + ({'state': 'California'}, {'capital': "Sacramento"}), + ({'state': 'Argentina'}, {'capital': "Buenos Aires"}), + ] + + +class GenerateState(BaseModel): + generate_output = PromptRunner(template_class=PromptGenerateState, prompt_strategy=DefaultPromptStrategy) + + def invoke(self, input: Dict[str, Any], config: Optional[Dict[str, Any]] = None) -> str: + try: + desc = self.generate_output.invoke(input, config) + if hasattr(desc, 'capital'): + return desc.capital + else: + raise ValueError("The 'capital' field is missing from the output") + except Exception as e: + logger.error(f"Error in GenerateState.invoke: {str(e)}") + return f"Error: Unable to generate state capital. 
{str(e)}" From 2359a4471438f9dc4dda6fe66f6e27b42c1c80fa Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 17:06:13 -0300 Subject: [PATCH 21/30] fix: Handle list input in PromptRunner.invoke --- langdspy/prompt_runners.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 56f3508..6c88b78 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -272,7 +272,11 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Outpu formatted_prompt = self.template.format_prompt(**kwargs, llm_type=llm_type) - chain = formatted_prompt | llm | StrOutputParser() + if isinstance(formatted_prompt, list): + # If formatted_prompt is a list, assume it's a list of messages + chain = lambda x: llm.invoke(formatted_prompt) | StrOutputParser() + else: + chain = formatted_prompt | llm | StrOutputParser() res = self._invoke_with_retries( lambda: chain.invoke(input, config=config), From fbe6f8a6fecbe21d9aecf3cee9fbc93d998cbad7 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 17:08:22 -0300 Subject: [PATCH 22/30] fix: Handle different types of formatted_prompt in PromptRunner.invoke --- langdspy/prompt_runners.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 6c88b78..caa4763 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -275,15 +275,20 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Outpu if isinstance(formatted_prompt, list): # If formatted_prompt is a list, assume it's a list of messages chain = lambda x: llm.invoke(formatted_prompt) | StrOutputParser() + res = self._invoke_with_retries( + lambda: chain(input), + input, + max_retries, + config=config + ) else: chain = formatted_prompt | llm | StrOutputParser() - - res = self._invoke_with_retries( - lambda: chain.invoke(input, config=config), - input, - max_retries, - config=config - ) + res = self._invoke_with_retries( + lambda: chain.invoke(input, config=config), + input, + max_retries, + config=config + ) logger.debug(f"Result from _invoke_with_retries: {res}") From c0d4823ffaa15a57dc3973e8754fdafda12e6fb0 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 17:24:23 -0300 Subject: [PATCH 23/30] fix: update `PromptRunner` to handle `formatted_prompt` correctly --- langdspy/prompt_runners.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index caa4763..ce8650a 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -274,17 +274,17 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Outpu if isinstance(formatted_prompt, list): # If formatted_prompt is a list, assume it's a list of messages - chain = lambda x: llm.invoke(formatted_prompt) | StrOutputParser() + chain = lambda: llm.invoke({"messages": formatted_prompt}) | StrOutputParser() res = self._invoke_with_retries( - lambda: chain(input), + chain, input, max_retries, config=config ) else: - chain = formatted_prompt | llm | StrOutputParser() + chain = lambda: llm.invoke(formatted_prompt) | StrOutputParser() res = self._invoke_with_retries( - lambda: chain.invoke(input, config=config), + chain, input, max_retries, config=config From 0db5a90f96f3fb1ecf9041e160685b38f02aafdd Mon Sep 17 00:00:00 
2001 From: Fernando Vieira da Silva Date: Fri, 30 Aug 2024 17:29:18 -0300 Subject: [PATCH 24/30] feat: Refactor prompt runner to handle list of messages and improve invocation --- langdspy/prompt_runners.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index ce8650a..caa4763 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -274,17 +274,17 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Outpu if isinstance(formatted_prompt, list): # If formatted_prompt is a list, assume it's a list of messages - chain = lambda: llm.invoke({"messages": formatted_prompt}) | StrOutputParser() + chain = lambda x: llm.invoke(formatted_prompt) | StrOutputParser() res = self._invoke_with_retries( - chain, + lambda: chain(input), input, max_retries, config=config ) else: - chain = lambda: llm.invoke(formatted_prompt) | StrOutputParser() + chain = formatted_prompt | llm | StrOutputParser() res = self._invoke_with_retries( - chain, + lambda: chain.invoke(input, config=config), input, max_retries, config=config From 12618900aa58a04299696deecd46dd7ab4900360 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 17:29:19 -0300 Subject: [PATCH 25/30] fix: Handle AIMessage object returned by LLM --- langdspy/prompt_runners.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index caa4763..be25188 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -274,13 +274,16 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Outpu if isinstance(formatted_prompt, list): # If formatted_prompt is a list, assume it's a list of messages - chain = lambda x: llm.invoke(formatted_prompt) | StrOutputParser() + chain = lambda x: llm.invoke(formatted_prompt) res = self._invoke_with_retries( lambda: chain(input), input, max_retries, config=config ) + # Extract content if it's an AIMessage + if hasattr(res, 'content'): + res = res.content else: chain = formatted_prompt | llm | StrOutputParser() res = self._invoke_with_retries( From 541d47f1c13f06af7677f60107c230f78632174d Mon Sep 17 00:00:00 2001 From: Fernando Vieira da Silva Date: Fri, 30 Aug 2024 17:32:48 -0300 Subject: [PATCH 26/30] fix: Remove unused imports and code in prompt_strategies.py --- langdspy/prompt_strategies.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/langdspy/prompt_strategies.py b/langdspy/prompt_strategies.py index 677b34c..adbe6ff 100644 --- a/langdspy/prompt_strategies.py +++ b/langdspy/prompt_strategies.py @@ -483,29 +483,3 @@ def _parse_openai_json_output_to_fields(self, output: str) -> dict: except Exception as e: logger.error(f"An error occurred while parsing JSON output: {e}") raise e -from typing import Dict, Any, Optional - -class PromptGenerateState(PromptSignature): - state = InputField(name="State", desc="This is the state") - - capital = OutputField(name="State Capital", desc="The capital city of this state") - - __examples__ = [ - ({'state': 'California'}, {'capital': "Sacramento"}), - ({'state': 'Argentina'}, {'capital': "Buenos Aires"}), - ] - - -class GenerateState(BaseModel): - generate_output = PromptRunner(template_class=PromptGenerateState, prompt_strategy=DefaultPromptStrategy) - - def invoke(self, input: Dict[str, Any], config: Optional[Dict[str, Any]] = None) -> str: - try: - desc = self.generate_output.invoke(input, 
config) - if hasattr(desc, 'capital'): - return desc.capital - else: - raise ValueError("The 'capital' field is missing from the output") - except Exception as e: - logger.error(f"Error in GenerateState.invoke: {str(e)}") - return f"Error: Unable to generate state capital. {str(e)}" From 40a451e40bf4d5d636005af8c5dc3c3bac53e6e9 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Fri, 30 Aug 2024 17:32:49 -0300 Subject: [PATCH 27/30] fix: Handle AIMessage objects in output parsing --- langdspy/prompt_strategies.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/langdspy/prompt_strategies.py b/langdspy/prompt_strategies.py index adbe6ff..8e2bf65 100644 --- a/langdspy/prompt_strategies.py +++ b/langdspy/prompt_strategies.py @@ -8,7 +8,7 @@ from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator, Extra from langchain_core.pydantic_v1 import validator from langchain_core.language_models import BaseLLM -from typing import Any, Dict, List, Type, Optional, Callable, Tuple +from typing import Any, Dict, List, Type, Optional, Callable, Tuple, Union import uuid from abc import ABC, abstractmethod from langchain_core.documents import Document @@ -129,6 +129,9 @@ def format_prompt(self, **kwargs: Any) -> str: examples = kwargs.pop('__examples__', self.__examples__) # Add this line try: + # Extract content if output is an AIMessage + if hasattr(output, 'content'): + output = output.content validated_kwargs = self._validate_input(kwargs) if llm_type == 'openai': @@ -177,7 +180,7 @@ def _get_output_field(self, field_name): return output_name @abstractmethod - def _parse_openai_output_to_fields(self, output: str) -> dict: + def _parse_openai_output_to_fields(self, output: Union[str, 'AIMessage']) -> dict: pass @abstractmethod @@ -185,7 +188,7 @@ def _parse_anthropic_output_to_fields(self, output: str) -> dict: pass @abstractmethod - def _parse_openai_json_output_to_fields(self, output: str) -> dict: + def _parse_openai_json_output_to_fields(self, output: Union[str, 'AIMessage']) -> dict: pass @@ -427,28 +430,28 @@ def _parse_openai_output_to_fields(self, output: str) -> dict: raise e - def _parse_anthropic_output_to_fields(self, output: str) -> dict: + def _parse_anthropic_output_to_fields(self, output: Union[str, 'AIMessage']) -> dict: try: + # Extract content if output is an AIMessage + if hasattr(output, 'content'): + output = output.content + parsed_fields = {} for output_name, output_field in self.output_variables.items(): pattern = fr"<{output_field.name}>(.*?)" - # match = re.search(pattern, output, re.DOTALL) - # if match: - # parsed_fields[output_name] = match.group(1).strip() matches = re.findall(pattern, output, re.DOTALL) if matches: # Take the last match last_match = matches[-1] parsed_fields[output_name] = last_match.strip() - logger.debug(f"Parsed fields: {parsed_fields}") return parsed_fields except Exception as e: + logger.error(f"Error parsing Anthropic output: {str(e)}") import traceback traceback.print_exc() - - raise e + raise def _parse_openai_json_output_to_fields(self, output: str) -> dict: print(f"Parsing openai json") From 8da94f48b5b4d07f0f55149647cd474a18fe73c9 Mon Sep 17 00:00:00 2001 From: Fernando Vieira da Silva Date: Mon, 2 Sep 2024 10:58:43 -0300 Subject: [PATCH 28/30] feat: Add support for parsing dictionary output in DefaultPromptStrategy --- langdspy/prompt_strategies.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git 
a/langdspy/prompt_strategies.py b/langdspy/prompt_strategies.py index 8e2bf65..79300bb 100644 --- a/langdspy/prompt_strategies.py +++ b/langdspy/prompt_strategies.py @@ -130,8 +130,8 @@ def format_prompt(self, **kwargs: Any) -> str: try: # Extract content if output is an AIMessage - if hasattr(output, 'content'): - output = output.content + #if hasattr(output, 'content'): + # output = output.content validated_kwargs = self._validate_input(kwargs) if llm_type == 'openai': @@ -436,14 +436,26 @@ def _parse_anthropic_output_to_fields(self, output: Union[str, 'AIMessage']) -> if hasattr(output, 'content'): output = output.content + print(self) + print(f"output: {output}") + parsed_fields = {} for output_name, output_field in self.output_variables.items(): pattern = fr"<{output_field.name}>(.*?)" - matches = re.findall(pattern, output, re.DOTALL) - if matches: - # Take the last match - last_match = matches[-1] - parsed_fields[output_name] = last_match.strip() + if isinstance(output, str): + matches = re.findall(pattern, output, re.DOTALL) + if matches: + # Take the last match + last_match = matches[-1] + parsed_fields[output_name] = last_match.strip() + + elif isinstance(output, dict): + if output_field.name in output.keys(): + parsed_fields[output_name] = output[output_name] + else: + raise ValueError(f"Invalid output type: {type(output)}") + + logger.debug(f"Parsed fields: {parsed_fields}") return parsed_fields From 6a132a89cd0e60908980bf5677de2d9cb0167883 Mon Sep 17 00:00:00 2001 From: "Fernando Vieira da Silva (aider)" Date: Mon, 2 Sep 2024 10:58:44 -0300 Subject: [PATCH 29/30] fix: Ensure all output fields are present in parsed_fields --- langdspy/prompt_strategies.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/langdspy/prompt_strategies.py b/langdspy/prompt_strategies.py index 79300bb..1ad1900 100644 --- a/langdspy/prompt_strategies.py +++ b/langdspy/prompt_strategies.py @@ -451,11 +451,14 @@ def _parse_anthropic_output_to_fields(self, output: Union[str, 'AIMessage']) -> elif isinstance(output, dict): if output_field.name in output.keys(): - parsed_fields[output_name] = output[output_name] + parsed_fields[output_name] = output[output_field.name] else: raise ValueError(f"Invalid output type: {type(output)}") - + # Ensure all output fields are present in parsed_fields + for output_name in self.output_variables.keys(): + if output_name not in parsed_fields: + parsed_fields[output_name] = None logger.debug(f"Parsed fields: {parsed_fields}") return parsed_fields From 580a57e2fca7a9062e9859d9e1f6869d793cd714 Mon Sep 17 00:00:00 2001 From: Fernando Vieira da Silva Date: Tue, 3 Sep 2024 08:10:52 -0300 Subject: [PATCH 30/30] Caching examples and input description for Anthropic For Anthropic, use the cache to add indication of examples and fields descriptions. 
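As a rough illustration of the caching idea this commit describes (not the exact code in the diff below), a cache-marked system message could be built from structured content blocks as in the following sketch. It assumes that message content may be given as a list of Anthropic-style content blocks and that the model and beta settings in use actually honor cache_control; neither assumption is verified by this patch.

from langchain_core.messages import HumanMessage, SystemMessage

# Static, reusable context (field descriptions, hints, examples) goes into one
# block that Anthropic is asked to cache between calls.
static_context = (
    "Provide answers for the output fields. Follow the XML output format; "
    "only show the output fields, do not repeat the hints, inputs or examples."
)

system_message = SystemMessage(
    content=[
        {
            "type": "text",
            "text": static_context,
            # Ask Anthropic to cache this block (ephemeral cache entry).
            "cache_control": {"type": "ephemeral"},
        }
    ]
)

# Only the per-request input travels outside the cached block.
messages = [system_message, HumanMessage(content="<input>...per-request fields...</input>")]

Note that the diff below serializes the whole cache block as a single string passed to SystemMessage(content=...); structured content blocks as in the sketch are an alternative if the underlying client expects cache_control as metadata rather than literal text.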
--- langdspy/prompt_runners.py | 19 +++++- langdspy/prompt_strategies.py | 114 ++++++++++++++++++++++------------ 2 files changed, 89 insertions(+), 44 deletions(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index be25188..f86c0a5 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -157,11 +157,11 @@ def _execute_prompt(self, chain, input, config, llm_type): logger.debug(f"CONFIG: {config}") try: - if llm_type in ['anthropic', 'openai_json']: + if llm_type in ['anthropic', 'openai', 'openai_json']: if isinstance(formatted_prompt, list): prompt_res = chain.invoke({"messages": formatted_prompt}) else: - prompt_res = chain.invoke(formatted_prompt) + prompt_res = chain.invoke(invoke_args, config=config) else: prompt_res = chain.invoke(invoke_args) return formatted_prompt, prompt_res @@ -285,17 +285,30 @@ def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Outpu if hasattr(res, 'content'): res = res.content else: + chain = self.template | llm | StrOutputParser() + + + res = self._invoke_with_retries( + lambda : self._execute_prompt(chain, input, config, llm_type=llm_type)[1], \ + input, \ + max_retries, \ + config=config) + ''' chain = formatted_prompt | llm | StrOutputParser() res = self._invoke_with_retries( - lambda: chain.invoke(input, config=config), + lambda: llm.invoke(formatted_prompt, config=config).content, input, max_retries, config=config ) + ''' logger.debug(f"Result from _invoke_with_retries: {res}") parsed_output = self.template.parse_output_to_fields(res, llm_type) + + print(f"parsed_output = {parsed_output}") + prediction_data = {**input, **parsed_output} logger.debug(f"Prediction data: {prediction_data}") diff --git a/langdspy/prompt_strategies.py b/langdspy/prompt_strategies.py index 1ad1900..067cc6d 100644 --- a/langdspy/prompt_strategies.py +++ b/langdspy/prompt_strategies.py @@ -57,6 +57,29 @@ def __init__(self, **kwargs): self.validate_examples() + def validate_inputs(self, inputs_dict): + expected_keys = set(self.input_variables.keys()) + received_keys = set(inputs_dict.keys()) + + if expected_keys != received_keys: + missing_keys = expected_keys - received_keys + unexpected_keys = received_keys - expected_keys + error_message = [] + + if missing_keys: + error_message.append(f"Missing input keys: {', '.join(missing_keys)}") + logger.error(f"Missing input keys: {missing_keys}") + if unexpected_keys: + error_message.append(f"Unexpected input keys: {', '.join(unexpected_keys)}") + logger.error(f"Unexpected input keys: {unexpected_keys}") + + error_message.append(f"Expected keys: {', '.join(expected_keys)}") + error_message.append(f"Received keys: {', '.join(received_keys)}") + + logger.error(f"Input keys do not match expected input keys. Expected: {expected_keys}, Received: {received_keys}") + raise ValueError(". 
".join(error_message)) + + ''' def _validate_input(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: if not self.input_variables: return input_dict # Return the input as-is if there are no input variables defined @@ -74,6 +97,8 @@ def _validate_input(self, input_dict: Dict[str, Any]) -> Dict[str, Any]: raise ValueError(f"Invalid input for {name}: {value}") validated_input[name] = field.transform_value(value) return validated_input + ''' + def validate_examples(self): for example_input, example_output in self.__examples__: @@ -129,17 +154,15 @@ def format_prompt(self, **kwargs: Any) -> str: examples = kwargs.pop('__examples__', self.__examples__) # Add this line try: - # Extract content if output is an AIMessage - #if hasattr(output, 'content'): - # output = output.content - validated_kwargs = self._validate_input(kwargs) + + self.validate_inputs(kwargs) if llm_type == 'openai': - prompt = self._format_openai_prompt(trained_state, use_training, examples, **validated_kwargs) + prompt = self._format_openai_prompt(trained_state, use_training, examples, **kwargs) elif llm_type == 'openai_json': - prompt = self._format_openai_json_prompt(trained_state, use_training, examples, **validated_kwargs) + prompt = self._format_openai_json_prompt(trained_state, use_training, examples, **kwargs) elif llm_type == 'anthropic' or llm_type == 'fake_anthropic': - prompt = self._format_anthropic_prompt(trained_state, use_training, examples, **validated_kwargs) + prompt = self._format_anthropic_prompt(trained_state, use_training, examples, **kwargs) else: raise ValueError(f"Unsupported LLM type: {llm_type}") @@ -324,31 +347,16 @@ def _format_openai_prompt(self, trained_state, use_training, examples, **kwargs) return prompt - def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwargs) -> str: - messages = [] - # If there is a prompt caching instruction, add it to the system message - system_message = SystemMessage(content="""[ - { - type: "text", - text: "Consider the following cities to be classified as capital of states: the capital of Brazil is São Paulo, the capital of Turkey is Istanbul, the capital of Australia is Sydney.", - - // Tell Anthropic to cache this block - cache_control: { type: "ephemeral" }, - }, - ]""" - ) - - messages.append(system_message) + def _format_anthopic_cache(self, trained_state, use_training, examples) -> str: + cache = "" + + cache = f"Provide answers for output fields {', '.join([output_field.name for output_field in self.output_variables.values()])}. Follow the XML output format, only show the output fields do not repeat the hints, input fields or examples." - human_message = HumanMessage(f"Provide answers for output fields {', '.join([output_field.name for output_field in self.output_variables.values()])}. 
Follow the XML output format, only show the output fields do not repeat the hints, input fields or examples.") - #messages.append({"role": "system", "content": system_message}) - messages.append(human_message) - # Hints if self.hint_variables: hint_content = "\n".join([hint_field.format_prompt_description("anthropic") for _, hint_field in self.hint_variables.items()]) - messages.append(HumanMessage(f"Hints:\n{hint_content}")) + cache += f"Hints:\n{hint_content}" # Input and Output fields description fields_description = "\n" @@ -356,7 +364,7 @@ def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwar fields_description += "\n\n\n" fields_description += "\n".join([output_field.format_prompt_description("anthropic") for _, output_field in self.output_variables.items()]) fields_description += "\n" - messages.append(HumanMessage(fields_description)) + cache += fields_description # Examples if examples: @@ -369,7 +377,7 @@ def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwar else: example_message += "\n".join([output_field.format_prompt_value(example_output, "anthropic") for output_name, output_field in self.output_variables.items()]) example_message += "\n\n" - messages.append(HumanMessage(example_message)) + cache += example_message # Trained examples if trained_state and trained_state.examples and use_training: @@ -382,21 +390,48 @@ def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwar else: trained_example_message += "\n".join([output_field.format_prompt_value(example_y, "anthropic") for output_name, output_field in self.output_variables.items()]) trained_example_message += "\n\n" - messages.append(HumanMessage(trained_example_message)) + cache += trained_example_message + + system_message = SystemMessage(content="""[ + { + type: "text", + text: "%s", + + // Tell Anthropic to cache this block + cache_control: { type: "ephemeral" }, + }, + ]""" % cache + ) + + return system_message + + + def _format_anthropic_prompt(self, trained_state, use_training, examples, **kwargs) -> str: + messages = [] + + system_message = self._format_anthopic_cache(trained_state, use_training, examples) + messages.append(system_message) + + human_content = "" # User input user_input = "\n" user_input += "\n".join([input_field.format_prompt_value(kwargs.get(input_name), "anthropic") for input_name, input_field in self.input_variables.items()]) user_input += "\n" - messages.append(HumanMessage(user_input)) + human_content += user_input # Assistant response format - messages.append(HumanMessage("Respond with the output in the following format:\n\n[Your response here]\n")) + human_content += "Respond with the output in the following format:\n\n[Your response here]\n" + + messages.append(HumanMessage(human_content)) return messages def _parse_openai_output_to_fields(self, output: str) -> dict: try: + if isinstance(output, dict): + return output + pattern = r'^([^:]+): (.*)' lines = output.split(self.OUTPUT_TOKEN) parsed_fields = {} @@ -438,11 +473,13 @@ def _parse_anthropic_output_to_fields(self, output: Union[str, 'AIMessage']) -> print(self) print(f"output: {output}") - + + print(f"output_variables: {self.output_variables}") + parsed_fields = {} for output_name, output_field in self.output_variables.items(): - pattern = fr"<{output_field.name}>(.*?)" if isinstance(output, str): + pattern = fr"<{output_field.name}>(.*?)" matches = re.findall(pattern, output, re.DOTALL) if matches: # Take the last match @@ -450,15 +487,10 @@ def 
_parse_anthropic_output_to_fields(self, output: Union[str, 'AIMessage']) -> parsed_fields[output_name] = last_match.strip() elif isinstance(output, dict): - if output_field.name in output.keys(): - parsed_fields[output_name] = output[output_field.name] + if output_name in output.keys(): + parsed_fields[output_name] = output[output_name] else: raise ValueError(f"Invalid output type: {type(output)}") - - # Ensure all output fields are present in parsed_fields - for output_name in self.output_variables.keys(): - if output_name not in parsed_fields: - parsed_fields[output_name] = None logger.debug(f"Parsed fields: {parsed_fields}") return parsed_fields
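Taken together, the series routes Anthropic requests through _format_anthropic_prompt, which now returns a message list with a cacheable system block, and parses the XML-tagged reply with _parse_anthropic_output_to_fields. A minimal end-to-end sketch of how the resulting PromptRunner might be driven is shown below; the import path for InputField and OutputField and the model name are assumptions not confirmed by these patches, an ANTHROPIC_API_KEY is expected in the environment, and the whole snippet should be read as illustrative rather than the project's documented usage.

from langchain_anthropic import ChatAnthropic

from langdspy.prompt_runners import PromptRunner
from langdspy.prompt_strategies import PromptSignature, DefaultPromptStrategy
# Assumed location: the patches use InputField/OutputField but do not show their module.
from langdspy.field_descriptors import InputField, OutputField


class CapitalSignature(PromptSignature):
    # Mirrors the temporary example added in patch 20/30 and removed in patch 26/30.
    state = InputField(name="State", desc="Name of a state or country")
    capital = OutputField(name="State Capital", desc="The capital city of this state")

    __examples__ = [
        ({"state": "California"}, {"capital": "Sacramento"}),
    ]


runner = PromptRunner(template_class=CapitalSignature, prompt_strategy=DefaultPromptStrategy)
llm = ChatAnthropic(model="claude-3-haiku-20240307")  # assumed model name

# The config must carry the LLM instance; 'max_tries' and 'hard_fail' are optional knobs.
prediction = runner.invoke({"state": "Oregon"},
                           config={"llm": llm, "max_tries": 2, "hard_fail": False})
print(prediction.capital)

Because _determine_llm_type recognizes ChatAnthropic, the runner formats the message list, extracts .content from the returned AIMessage, and maps the <State Capital> tag in the reply back to the capital field on the returned Prediction.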