add more info to the prompt history #22

Merged · 1 commit · Mar 21, 2024
2 changes: 1 addition & 1 deletion .gitignore
@@ -15,4 +15,4 @@ __pycache__/
.pytest_cache/
htmlcov/
.coverage

*.swp
6 changes: 3 additions & 3 deletions langdspy/model.py
@@ -92,15 +92,15 @@ def get_failed_prompts(self):
failed_prompts = []
prompt_history = self.get_prompt_history()
for runner_name, entry in prompt_history:
if not entry["success"]:
if entry["error"] is not None:
failed_prompts.append((runner_name, entry))
return failed_prompts

def get_successful_prompts(self):
successful_prompts = []
prompt_history = self.get_prompt_history()
for runner_name, entry in prompt_history:
if entry["success"]:
if entry["error"] is None:
successful_prompts.append((runner_name, entry))
return successful_prompts

@@ -156,4 +156,4 @@ def evaluate_subset(subset):
logger.debug(f"Best score: {best_score} with subset: {best_subset}")

self.trained_state.examples = best_subset
return self
return self
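For orientation, a minimal usage sketch of the reworked helpers; the `model` object is a placeholder for any already-invoked langdspy Model instance and is not part of this PR:

# `model` is assumed to be a langdspy Model whose prompt runners have already run.
for runner_name, entry in model.get_failed_prompts():
    # Failed entries are now identified by a non-None entry["error"] message
    # rather than a success boolean.
    print(runner_name, entry["error"])

for runner_name, entry in model.get_successful_prompts():
    # Successful entries have entry["error"] == None plus the new llm and
    # duration_ms metadata recorded by the prompt runner.
    print(runner_name, entry["llm"], entry["duration_ms"])
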
56 changes: 36 additions & 20 deletions langdspy/prompt_runners.py
@@ -39,13 +39,15 @@ def __init__(self, **kwargs):
class PromptHistory(BaseModel):
history: List[Dict[str, Any]] = Field(default_factory=list)

def add_entry(self, prompt, llm_response, parsed_output, success, timestamp):
def add_entry(self, llm, prompt, llm_response, parsed_output, error, start_time, end_time):
self.history.append({
"prompt": prompt,
"duration_ms": round((end_time - start_time) * 1000),
"llm": llm,
"llm_response": llm_response,
"parsed_output": parsed_output,
"success": success,
"timestamp": timestamp
"prompt": prompt,
"error": error,
"timestamp": end_time,
})


@@ -84,6 +84,18 @@ def _determine_llm_type(self, llm):
else:
return 'openai' # Default to OpenAI if model type cannot be determined

def _determine_llm_model(self, llm):
if isinstance(llm, ChatOpenAI): # Assuming OpenAILLM is the class for OpenAI models
return llm.model_name
elif isinstance(llm, ChatAnthropic): # Assuming AnthropicLLM is the class for Anthropic models
return llm.model
elif hasattr(llm, 'model_name'):
return llm.model_name
elif hasattr(llm, 'model'):
return llm.model
else:
return '???'

def get_prompt_history(self):
return self.prompt_history.history

@@ -101,6 +115,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
formatted_prompt = None

while max_tries >= 1:
start_time = time.time()
try:
kwargs = {**self.model_kwargs, **self.kwargs}
# logger.debug(f"PromptRunner invoke with input {input} and kwargs {kwargs} and config {config}")
@@ -137,7 +152,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
continue


validation = True
validation_err = None

# logger.debug(f"Raw output for prompt runner {self.template.__class__.__name__}: {res}")
if print_prompt:
@@ -150,50 +165,51 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"Failed to parse output for prompt runner {self.template.__class__.__name__}")
validation = False
validation_err = f"Failed to parse output for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)
# logger.debug(f"Parsed output: {parsed_output}")

len_parsed_output = len(parsed_output.keys())
len_output_variables = len(self.template.output_variables.keys())
# logger.debug(f"Parsed output keys: {parsed_output.keys()} [{len_parsed_output}] Expected output keys: {self.template.output_variables.keys()} [{len_output_variables}]")

if len(parsed_output.keys()) != len(self.template.output_variables.keys()):
logger.error(f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}")
validation = False
validation_err = f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)

self.prompt_history.add_entry(formatted_prompt, res, parsed_output, validation, time.time())

if validation:
if validation_err is None:
# Transform and validate the outputs
for attr_name, output_field in self.template.output_variables.items():
output_value = parsed_output.get(attr_name)
if not output_value:
logger.error(f"Failed to get output value for field {attr_name} for prompt runner {self.template.__class__.__name__}")
validation = False
validation_err = f"Failed to get output value for field {attr_name} for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)
continue

# Validate the transformed value
if not output_field.validate_value(input, output_value):
validation = False
logger.error(f"Failed to validate field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}")
validation_err = f"Failed to validate field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)

# Get the transformed value
try:
transformed_val = output_field.transform_value(output_value)
except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"Failed to transform field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}")
validation = False
validation_err = f"Failed to transform field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)
continue

# Update the output with the transformed value
parsed_output[attr_name] = transformed_val

end_time = time.time()
self.prompt_history.add_entry(self._determine_llm_type(config['llm']) + " " + self._determine_llm_model(config['llm']), formatted_prompt, res, parsed_output, validation_err, start_time, end_time)

res = {attr_name: parsed_output.get(attr_name, None) for attr_name in self.template.output_variables.keys()}

if validation:
if validation_err is None:
return res

logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry")
@@ -263,4 +279,4 @@ def run_task():


# logger.debug(f"MultiPromptRunner predictions: {self.predictions}")
return predictions
return predictions
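As a reference for the new history schema, here is a minimal sketch that exercises PromptHistory directly; the llm label and field values are illustrative, and in practice the runner builds the label from _determine_llm_type and _determine_llm_model:

import time
from langdspy.prompt_runners import PromptHistory

history = PromptHistory()
start_time = time.time()
# ... the prompt runner would format the prompt and call the LLM here ...
end_time = time.time()

history.add_entry(
    llm="openai gpt-3.5-turbo",      # illustrative "<type> <model>" label
    prompt="formatted prompt text",
    llm_response="raw LLM response",
    parsed_output={"answer": "42"},
    error=None,                       # None on success, a message string on failure
    start_time=start_time,
    end_time=end_time,
)

entry = history.history[-1]
assert entry["error"] is None
assert entry["duration_ms"] >= 0      # rounded milliseconds between start and end
assert entry["timestamp"] == end_time
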
16 changes: 12 additions & 4 deletions tests/test_prompt_history.py
@@ -13,6 +13,8 @@ class TestPromptSignature2(PromptSignature):
@pytest.fixture
def llm():
class FakeLLM:
model = "test one"

def __call__(self, prompt, stop=None):
return "Fake LLM response"
return FakeLLM()
@@ -43,17 +45,23 @@ def invoke(self, input_dict, config):
runner_name1, entry1 = prompt_history[0]
assert runner_name1 == "prompt_runner1"
assert "prompt" in entry1
assert "llm" in entry1
assert "llm_response" in entry1
assert entry1["llm"] == "openai test one"
assert entry1["parsed_output"] == {"output1": "Fake LLM response"}
assert entry1["success"] == True
assert entry1["error"] is None
assert "duration_ms" in entry1
assert "timestamp" in entry1

runner_name2, entry2 = prompt_history[1]
assert runner_name2 == "prompt_runner2"
assert "prompt" in entry2
assert "llm" in entry2
assert "llm_response" in entry2
assert entry2["llm"] == "openai test one"
assert entry2["parsed_output"] == {"output2": "Fake LLM response"}
assert entry2["success"] == True
assert entry2["error"] is None
assert "duration_ms" in entry2
assert "timestamp" in entry2

def test_failed_prompts(llm):
@@ -81,8 +89,8 @@ def invoke(self, input_dict, config):

runner_name1, entry1 = successful_prompts[0]
assert runner_name1 == "prompt_runner1"
assert entry1["success"] == True
assert entry1["error"] is None

runner_name2, entry2 = successful_prompts[1]
assert runner_name2 == "prompt_runner2"
assert entry2["success"] == True
assert entry2["error"] is None
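For clarity on the "openai test one" assertions above, a short sketch of how that label comes together for the fake LLM; the type falls back to _determine_llm_type's 'openai' default because FakeLLM is neither ChatOpenAI nor ChatAnthropic, and the model name is taken from its `model` attribute:

class FakeLLM:
    model = "test one"

llm = FakeLLM()
llm_type = "openai"  # default branch of _determine_llm_type for unrecognized classes
# _determine_llm_model prefers model_name, then model, before giving up with '???'.
llm_model = getattr(llm, "model_name", None) or getattr(llm, "model", "???")
assert f"{llm_type} {llm_model}" == "openai test one"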