From 3be13e7a0865b9d5b2b6f7db9e75b7497c786de6 Mon Sep 17 00:00:00 2001 From: Isao Jonas Date: Wed, 20 Mar 2024 15:18:20 -0500 Subject: [PATCH] be more defensive about model type check --- langdspy/prompt_runners.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/langdspy/prompt_runners.py b/langdspy/prompt_runners.py index 1110f47..7000f6f 100644 --- a/langdspy/prompt_runners.py +++ b/langdspy/prompt_runners.py @@ -86,6 +86,18 @@ def _determine_llm_type(self, llm): else: return 'openai' # Default to OpenAI if model type cannot be determined + def _determine_llm_model(self, llm): + if isinstance(llm, ChatOpenAI): # Assuming OpenAILLM is the class for OpenAI models + return llm.model_name + elif isinstance(llm, ChatAnthropic): # Assuming AnthropicLLM is the class for Anthropic models + return llm.model + elif hasattr(llm, 'model_name'): + return llm.model_name + elif hasattr(llm, 'model'): + return llm.model + else: + return '???' + def get_prompt_history(self): return self.prompt_history.history @@ -193,7 +205,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna parsed_output[attr_name] = transformed_val end_time = time.time() - self.prompt_history.add_entry(config["llm"].model_name, formatted_prompt, res, parsed_output, validation_err, start_time, end_time) + self.prompt_history.add_entry(self._determine_llm_type() + " " + self._determine_llm_model(), formatted_prompt, res, parsed_output, validation_err, start_time, end_time) res = {attr_name: parsed_output.get(attr_name, None) for attr_name in self.template.output_variables.keys()}