From f97a3a231234a6d78751f93505476501b894fa9f Mon Sep 17 00:00:00 2001 From: Isao Jonas Date: Tue, 19 Mar 2024 21:27:15 -0500 Subject: [PATCH 1/7] add more info to the prompt history --- .gitignore | 2 +- langdspy/prompt_runners.py | 44 +++++++++++++++++++++----------------- 2 files changed, 25 insertions(+), 21 deletions(-) diff --git a/.gitignore b/.gitignore index 35a44d2..7bcee1b 100644 --- a/.gitignore +++ b/.gitignore @@ -15,4 +15,4 @@ __pycache__/ .pytest_cache/ htmlcov/ .coverage - +*.swp diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 55faa91..1110f47 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -39,13 +39,15 @@ def __init__(self, **kwargs): class PromptHistory(BaseModel): history: List[Dict[str, Any]] = Field(default_factory=list) - def add_entry(self, prompt, llm_response, parsed_output, success, timestamp): + def add_entry(self, llm, prompt, llm_response, parsed_output, error, start_time, end_time): self.history.append({ - "prompt": prompt, + "duration_ms": round((end_time - start_time) * 1000), + "llm": llm, "llm_response": llm_response, "parsed_output": parsed_output, - "success": success, - "timestamp": timestamp + "prompt": prompt, + "error": error, + "timestamp": end_time, }) @@ -101,6 +103,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna formatted_prompt = None while max_tries >= 1: + start_time = time.time() try: kwargs = {**self.model_kwargs, **self.kwargs} # logger.debug(f"PromptRunner invoke with input {input} and kwargs {kwargs} and config {config}") @@ -137,7 +140,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna continue - validation = True + validation_err = None # logger.debug(f"Raw output for prompt runner {self.template.__class__.__name__}: {res}") if print_prompt: @@ -150,8 +153,8 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna except Exception as e: import traceback 
traceback.print_exc() - logger.error(f"Failed to parse output for prompt runner {self.template.__class__.__name__}") - validation = False + validation_err = f"Failed to parse output for prompt runner {self.template.__class__.__name__}" + logger.error(validation_err) # logger.debug(f"Parsed output: {parsed_output}") len_parsed_output = len(parsed_output.keys()) @@ -159,24 +162,22 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna # logger.debug(f"Parsed output keys: {parsed_output.keys()} [{len_parsed_output}] Expected output keys: {self.template.output_variables.keys()} [{len_output_variables}]") if len(parsed_output.keys()) != len(self.template.output_variables.keys()): - logger.error(f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}") - validation = False + validation_err = f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}" + logger.error(validation_err) - self.prompt_history.add_entry(formatted_prompt, res, parsed_output, validation, time.time()) - - if validation: + if validation_err is None: # Transform and validate the outputs for attr_name, output_field in self.template.output_variables.items(): output_value = parsed_output.get(attr_name) if not output_value: - logger.error(f"Failed to get output value for field {attr_name} for prompt runner {self.template.__class__.__name__}") - validation = False + validation_err = f"Failed to get output value for field {attr_name} for prompt runner {self.template.__class__.__name__}" + logger.error(validation_err) continue # Validate the transformed value if not output_field.validate_value(input, output_value): - validation = False - logger.error(f"Failed to validate field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}") + validation_err = f"Failed to validate field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}" 
+ logger.error(validation_err) # Get the transformed value try: @@ -184,16 +185,19 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna except Exception as e: import traceback traceback.print_exc() - logger.error(f"Failed to transform field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}") - validation = False + validation_err = f"Failed to transform field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}" + logger.error(validation_err) continue # Update the output with the transformed value parsed_output[attr_name] = transformed_val + end_time = time.time() + self.prompt_history.add_entry(config["llm"].model_name, formatted_prompt, res, parsed_output, validation_err, start_time, end_time) + res = {attr_name: parsed_output.get(attr_name, None) for attr_name in self.template.output_variables.keys()} - if validation: + if validation_err is None: return res logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry") @@ -263,4 +267,4 @@ def run_task(): # logger.debug(f"MultiPromptRunner predictions: {self.predictions}") - return predictions \ No newline at end of file + return predictions From b5700302cc90bd0f8e18fc6ec7470b47721367ab Mon Sep 17 00:00:00 2001 From: blake messer Date: Tue, 19 Mar 2024 17:40:46 -0500 Subject: [PATCH 2/7] update langchain --- pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a35a850..a45dc41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,7 @@ joblib = "1.3.2" jsonpatch = "1.33" jsonpointer = "2.4" keyring = "24.3.1" -langchain = "0.1.7" +langchain = "0.1.11" langchain-anthropic = "0.1.3" langchain-community = "0.0.20" langchain-core = "0.1.23" @@ -107,4 +107,4 @@ build-backend = "poetry.core.masonry.api" [tool.poetry.scripts] test = "scripts:test" -coverage = "scripts:coverage" \ No newline 
at end of file +coverage = "scripts:coverage" From f77cd7cd34adbff0fb9edbc03b743088b9bb77a9 Mon Sep 17 00:00:00 2001 From: blake messer Date: Wed, 20 Mar 2024 12:11:30 -0500 Subject: [PATCH 3/7] relax requirements for langchain deps --- pyproject.toml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a45dc41..b8e20e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,11 +45,11 @@ joblib = "1.3.2" jsonpatch = "1.33" jsonpointer = "2.4" keyring = "24.3.1" -langchain = "0.1.11" -langchain-anthropic = "0.1.3" -langchain-community = "0.0.20" -langchain-core = "0.1.23" -langchain-openai = "0.0.6" +langchain = "^0.1.11" +langchain-anthropic = "^0.1.3" +langchain-community = "^0.0.20" +langchain-core = "^0.1.23" +langchain-openai = "^0.0.6" markdown-it-py = "3.0.0" marshmallow = "3.20.2" mdurl = "0.1.2" From 1f9046febbddfde443a222408406275dcea402cf Mon Sep 17 00:00:00 2001 From: blake messer Date: Wed, 20 Mar 2024 12:21:10 -0500 Subject: [PATCH 4/7] more deps --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b8e20e4..027dbd0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,10 +45,10 @@ joblib = "1.3.2" jsonpatch = "1.33" jsonpointer = "2.4" keyring = "24.3.1" -langchain = "^0.1.11" +langchain = "^0.1.12" langchain-anthropic = "^0.1.3" -langchain-community = "^0.0.20" -langchain-core = "^0.1.23" +langchain-community = "^0.0.28" +langchain-core = "^0.1.32" langchain-openai = "^0.0.6" markdown-it-py = "3.0.0" marshmallow = "3.20.2" From 73dc5694c84bd851820ad8ad19853e15b8bb4265 Mon Sep 17 00:00:00 2001 From: blake messer Date: Wed, 20 Mar 2024 21:25:09 -0500 Subject: [PATCH 5/7] tolerate higher and lower versions of langchain_anthropic --- langdspy/prompt_runners.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 1110f47..9eab540 100644 --- 
a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -193,6 +193,11 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna parsed_output[attr_name] = transformed_val end_time = time.time() + try: + model_name = config["llm"].model_name + except AttributeError: + model_name = config["llm"].model + self.prompt_history.add_entry(config["llm"].model_name, formatted_prompt, res, parsed_output, validation_err, start_time, end_time) res = {attr_name: parsed_output.get(attr_name, None) for attr_name in self.template.output_variables.keys()} From b46d2d1163dcaca35eb7e18595ae8606f7663061 Mon Sep 17 00:00:00 2001 From: blake messer Date: Wed, 20 Mar 2024 21:28:24 -0500 Subject: [PATCH 6/7] leave a comment to explain --- langdspy/prompt_runners.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 9eab540..de96481 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -193,6 +193,8 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna parsed_output[attr_name] = transformed_val end_time = time.time() + + # The model name may be called model_name or model depending on version of langchain_anthropic try: model_name = config["llm"].model_name except AttributeError: model_name = config["llm"].model From 1b8979571b275683a9de1e14797a6d916dc76e9f Mon Sep 17 00:00:00 2001 From: blake messer Date: Wed, 20 Mar 2024 21:28:57 -0500 Subject: [PATCH 7/7] forgot the most important part --- langdspy/prompt_runners.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index de96481..1bfb6f9 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -200,7 +200,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna except AttributeError: model_name = config["llm"].model - self.prompt_history.add_entry(config["llm"].model_name, formatted_prompt, res, 
parsed_output, validation_err, start_time, end_time) + self.prompt_history.add_entry(model_name, formatted_prompt, res, parsed_output, validation_err, start_time, end_time) res = {attr_name: parsed_output.get(attr_name, None) for attr_name in self.template.output_variables.keys()}