diff --git a/.gitignore b/.gitignore
index 35a44d2..7bcee1b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,4 +15,4 @@ __pycache__/
 .pytest_cache/
 htmlcov/
 .coverage
-
+*.swp
diff --git a/langdspy/model.py b/langdspy/model.py
index 55d01dc..0a5de52 100644
--- a/langdspy/model.py
+++ b/langdspy/model.py
@@ -92,7 +92,7 @@ def get_failed_prompts(self):
         failed_prompts = []
         prompt_history = self.get_prompt_history()
         for runner_name, entry in prompt_history:
-            if not entry["success"]:
+            if entry["error"] is not None:
                 failed_prompts.append((runner_name, entry))
         return failed_prompts

@@ -100,7 +100,7 @@ def get_successful_prompts(self):
         successful_prompts = []
         prompt_history = self.get_prompt_history()
         for runner_name, entry in prompt_history:
-            if entry["success"]:
+            if entry["error"] is None:
                 successful_prompts.append((runner_name, entry))
         return successful_prompts

@@ -156,4 +156,4 @@ def evaluate_subset(subset):
         logger.debug(f"Best score: {best_score} with subset: {best_subset}")
         self.trained_state.examples = best_subset

-        return self
\ No newline at end of file
+        return self
diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py
index 55faa91..fd1534e 100644
--- a/langdspy/prompt_runners.py
+++ b/langdspy/prompt_runners.py
@@ -39,13 +39,15 @@ def __init__(self, **kwargs):
 class PromptHistory(BaseModel):
     history: List[Dict[str, Any]] = Field(default_factory=list)

-    def add_entry(self, prompt, llm_response, parsed_output, success, timestamp):
+    def add_entry(self, llm, prompt, llm_response, parsed_output, error, start_time, end_time):
         self.history.append({
-            "prompt": prompt,
+            "duration_ms": round((end_time - start_time) * 1000),
+            "llm": llm,
             "llm_response": llm_response,
             "parsed_output": parsed_output,
-            "success": success,
-            "timestamp": timestamp
+            "prompt": prompt,
+            "error": error,
+            "timestamp": end_time,
         })


@@ -84,6 +86,18 @@ def _determine_llm_type(self, llm):
         else:
             return 'openai' # Default to OpenAI if model type cannot be determined

+    def _determine_llm_model(self, llm):
+        if isinstance(llm, ChatOpenAI): # Assuming OpenAILLM is the class for OpenAI models
+            return llm.model_name
+        elif isinstance(llm, ChatAnthropic): # Assuming AnthropicLLM is the class for Anthropic models
+            return llm.model
+        elif hasattr(llm, 'model_name'):
+            return llm.model_name
+        elif hasattr(llm, 'model'):
+            return llm.model
+        else:
+            return '???'
+
     def get_prompt_history(self):
         return self.prompt_history.history

@@ -101,6 +115,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
         formatted_prompt = None

         while max_tries >= 1:
+            start_time = time.time()
             try:
                 kwargs = {**self.model_kwargs, **self.kwargs}
                 # logger.debug(f"PromptRunner invoke with input {input} and kwargs {kwargs} and config {config}")
@@ -137,7 +152,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna

                continue

-            validation = True
+            validation_err = None

             # logger.debug(f"Raw output for prompt runner {self.template.__class__.__name__}: {res}")
             if print_prompt:
@@ -150,8 +165,8 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
             except Exception as e:
                 import traceback
                 traceback.print_exc()
-                logger.error(f"Failed to parse output for prompt runner {self.template.__class__.__name__}")
-                validation = False
+                validation_err = f"Failed to parse output for prompt runner {self.template.__class__.__name__}"
+                logger.error(validation_err)

             # logger.debug(f"Parsed output: {parsed_output}")
             len_parsed_output = len(parsed_output.keys())
@@ -159,24 +174,22 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna

             # logger.debug(f"Parsed output keys: {parsed_output.keys()} [{len_parsed_output}] Expected output keys: {self.template.output_variables.keys()} [{len_output_variables}]")
             if len(parsed_output.keys()) != len(self.template.output_variables.keys()):
-                logger.error(f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}")
-                validation = False
+                validation_err = f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}"
+                logger.error(validation_err)

-            self.prompt_history.add_entry(formatted_prompt, res, parsed_output, validation, time.time())
-
-            if validation:
+            if validation_err is None:
                 # Transform and validate the outputs
                 for attr_name, output_field in self.template.output_variables.items():
                     output_value = parsed_output.get(attr_name)
                     if not output_value:
-                        logger.error(f"Failed to get output value for field {attr_name} for prompt runner {self.template.__class__.__name__}")
-                        validation = False
+                        validation_err = f"Failed to get output value for field {attr_name} for prompt runner {self.template.__class__.__name__}"
+                        logger.error(validation_err)
                         continue

                     # Validate the transformed value
                     if not output_field.validate_value(input, output_value):
-                        validation = False
-                        logger.error(f"Failed to validate field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}")
+                        validation_err = f"Failed to validate field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}"
+                        logger.error(validation_err)

                     # Get the transformed value
                     try:
@@ -184,16 +197,19 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
                     except Exception as e:
                         import traceback
                         traceback.print_exc()
-                        logger.error(f"Failed to transform field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}")
-                        validation = False
+                        validation_err = f"Failed to transform field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}"
+                        logger.error(validation_err)
                         continue

                     # Update the output with the transformed value
                     parsed_output[attr_name] = transformed_val

+            end_time = time.time()
+            self.prompt_history.add_entry(self._determine_llm_type(config['llm']) + " " + self._determine_llm_model(config['llm']), formatted_prompt, res, parsed_output, validation_err, start_time, end_time)
+
             res = {attr_name: parsed_output.get(attr_name, None) for attr_name in self.template.output_variables.keys()}

-            if validation:
+            if validation_err is None:
                 return res

             logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry")
@@ -263,4 +279,4 @@ def run_task():

             # logger.debug(f"MultiPromptRunner predictions: {self.predictions}")

-        return predictions
\ No newline at end of file
+        return predictions
diff --git a/tests/test_prompt_history.py b/tests/test_prompt_history.py
index 47e16fc..8e8330f 100644
--- a/tests/test_prompt_history.py
+++ b/tests/test_prompt_history.py
@@ -13,6 +13,8 @@ class TestPromptSignature2(PromptSignature):
 @pytest.fixture
 def llm():
     class FakeLLM:
+        model = "test one"
+
         def __call__(self, prompt, stop=None):
             return "Fake LLM response"
     return FakeLLM()
@@ -43,17 +45,23 @@ def invoke(self, input_dict, config):
     runner_name1, entry1 = prompt_history[0]
     assert runner_name1 == "prompt_runner1"
     assert "prompt" in entry1
+    assert "llm" in entry1
     assert "llm_response" in entry1
+    assert entry1["llm"] == "openai test one"
     assert entry1["parsed_output"] == {"output1": "Fake LLM response"}
-    assert entry1["success"] == True
+    assert entry1["error"] is None
+    assert "duration_ms" in entry1
     assert "timestamp" in entry1

     runner_name2, entry2 = prompt_history[1]
     assert runner_name2 == "prompt_runner2"
     assert "prompt" in entry2
+    assert "llm" in entry2
     assert "llm_response" in entry2
+    assert entry2["llm"] == "openai test one"
     assert entry2["parsed_output"] == {"output2": "Fake LLM response"}
-    assert entry2["success"] == True
+    assert entry2["error"] is None
+    assert "duration_ms" in entry2
     assert "timestamp" in entry2

 def test_failed_prompts(llm):
@@ -81,8 +89,8 @@ def invoke(self, input_dict, config):

     runner_name1, entry1 = successful_prompts[0]
     assert runner_name1 == "prompt_runner1"
-    assert entry1["success"] == True
+    assert entry1["error"] is None

     runner_name2, entry2 = successful_prompts[1]
     assert runner_name2 == "prompt_runner2"
-    assert entry2["success"] == True
\ No newline at end of file
+    assert entry2["error"] is None
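A minimal usage sketch, not part of the patch, of how the reshaped history entries could be inspected after this change. It assumes `model` is a langdspy Model instance whose prompt runners have already been invoked; the entry keys match the new PromptHistory fields introduced above.

    # Hypothetical inspection of the new prompt-history shape; `model` is assumed
    # to be an already-invoked langdspy Model with prompt runners attached.
    for runner_name, entry in model.get_failed_prompts():
        # "error" now carries the validation message; "llm" is "<type> <model>".
        print(f"{runner_name}: {entry['llm']} failed after {entry['duration_ms']} ms: {entry['error']}")

    for runner_name, entry in model.get_successful_prompts():
        print(f"{runner_name}: {entry['llm']} returned {list(entry['parsed_output'])} in {entry['duration_ms']} ms")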