Skip to content

Commit

Permalink
add more info to the prompt history
Browse files Browse the repository at this point in the history
  • Loading branch information
jonasi committed Mar 21, 2024
1 parent 81d31fe commit edb9cd4
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 28 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ __pycache__/
.pytest_cache/
htmlcov/
.coverage

*.swp
6 changes: 3 additions & 3 deletions langdspy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,15 +92,15 @@ def get_failed_prompts(self):
failed_prompts = []
prompt_history = self.get_prompt_history()
for runner_name, entry in prompt_history:
if not entry["success"]:
if entry["error"] is not None:
failed_prompts.append((runner_name, entry))
return failed_prompts

def get_successful_prompts(self):
successful_prompts = []
prompt_history = self.get_prompt_history()
for runner_name, entry in prompt_history:
if entry["success"]:
if entry["error"] is None:
successful_prompts.append((runner_name, entry))
return successful_prompts

Expand Down Expand Up @@ -156,4 +156,4 @@ def evaluate_subset(subset):
logger.debug(f"Best score: {best_score} with subset: {best_subset}")

self.trained_state.examples = best_subset
return self
return self
56 changes: 36 additions & 20 deletions langdspy/prompt_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ def __init__(self, **kwargs):
class PromptHistory(BaseModel):
history: List[Dict[str, Any]] = Field(default_factory=list)

def add_entry(self, prompt, llm_response, parsed_output, success, timestamp):
def add_entry(self, llm, prompt, llm_response, parsed_output, error, start_time, end_time):
self.history.append({
"prompt": prompt,
"duration_ms": round((end_time - start_time) * 1000),
"llm": llm,
"llm_response": llm_response,
"parsed_output": parsed_output,
"success": success,
"timestamp": timestamp
"prompt": prompt,
"error": error,
"timestamp": end_time,
})


Expand Down Expand Up @@ -84,6 +86,18 @@ def _determine_llm_type(self, llm):
else:
return 'openai' # Default to OpenAI if model type cannot be determined

def _determine_llm_model(self, llm):
if isinstance(llm, ChatOpenAI): # Assuming OpenAILLM is the class for OpenAI models
return llm.model_name
elif isinstance(llm, ChatAnthropic): # Assuming AnthropicLLM is the class for Anthropic models
return llm.model
elif hasattr(llm, 'model_name'):
return llm.model_name
elif hasattr(llm, 'model'):
return llm.model
else:
return '???'

def get_prompt_history(self):
return self.prompt_history.history

Expand All @@ -101,6 +115,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
formatted_prompt = None

while max_tries >= 1:
start_time = time.time()
try:
kwargs = {**self.model_kwargs, **self.kwargs}
# logger.debug(f"PromptRunner invoke with input {input} and kwargs {kwargs} and config {config}")
Expand Down Expand Up @@ -137,7 +152,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
continue


validation = True
validation_err = None

# logger.debug(f"Raw output for prompt runner {self.template.__class__.__name__}: {res}")
if print_prompt:
Expand All @@ -150,50 +165,51 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"Failed to parse output for prompt runner {self.template.__class__.__name__}")
validation = False
validation_err = f"Failed to parse output for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)
# logger.debug(f"Parsed output: {parsed_output}")

len_parsed_output = len(parsed_output.keys())
len_output_variables = len(self.template.output_variables.keys())
# logger.debug(f"Parsed output keys: {parsed_output.keys()} [{len_parsed_output}] Expected output keys: {self.template.output_variables.keys()} [{len_output_variables}]")

if len(parsed_output.keys()) != len(self.template.output_variables.keys()):
logger.error(f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}")
validation = False
validation_err = f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)

self.prompt_history.add_entry(formatted_prompt, res, parsed_output, validation, time.time())

if validation:
if validation_err is None:
# Transform and validate the outputs
for attr_name, output_field in self.template.output_variables.items():
output_value = parsed_output.get(attr_name)
if not output_value:
logger.error(f"Failed to get output value for field {attr_name} for prompt runner {self.template.__class__.__name__}")
validation = False
validation_err = f"Failed to get output value for field {attr_name} for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)
continue

# Validate the transformed value
if not output_field.validate_value(input, output_value):
validation = False
logger.error(f"Failed to validate field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}")
validation_err = f"Failed to validate field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)

# Get the transformed value
try:
transformed_val = output_field.transform_value(output_value)
except Exception as e:
import traceback
traceback.print_exc()
logger.error(f"Failed to transform field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}")
validation = False
validation_err = f"Failed to transform field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}"
logger.error(validation_err)
continue

# Update the output with the transformed value
parsed_output[attr_name] = transformed_val

end_time = time.time()
self.prompt_history.add_entry(self._determine_llm_type(config['llm']) + " " + self._determine_llm_model(config['llm']), formatted_prompt, res, parsed_output, validation_err, start_time, end_time)

res = {attr_name: parsed_output.get(attr_name, None) for attr_name in self.template.output_variables.keys()}

if validation:
if validation_err is None:
return res

logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry")
Expand Down Expand Up @@ -263,4 +279,4 @@ def run_task():


# logger.debug(f"MultiPromptRunner predictions: {self.predictions}")
return predictions
return predictions
16 changes: 12 additions & 4 deletions tests/test_prompt_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class TestPromptSignature2(PromptSignature):
@pytest.fixture
def llm():
class FakeLLM:
model = "test one"

def __call__(self, prompt, stop=None):
return "Fake LLM response"
return FakeLLM()
Expand Down Expand Up @@ -43,17 +45,23 @@ def invoke(self, input_dict, config):
runner_name1, entry1 = prompt_history[0]
assert runner_name1 == "prompt_runner1"
assert "prompt" in entry1
assert "llm" in entry1
assert "llm_response" in entry1
assert entry1["llm"] == "openai test one"
assert entry1["parsed_output"] == {"output1": "Fake LLM response"}
assert entry1["success"] == True
assert entry1["error"] is None
assert "duration_ms" in entry1
assert "timestamp" in entry1

runner_name2, entry2 = prompt_history[1]
assert runner_name2 == "prompt_runner2"
assert "prompt" in entry2
assert "llm" in entry2
assert "llm_response" in entry2
assert entry2["llm"] == "openai test one"
assert entry2["parsed_output"] == {"output2": "Fake LLM response"}
assert entry2["success"] == True
assert entry2["error"] is None
assert "duration_ms" in entry2
assert "timestamp" in entry2

def test_failed_prompts(llm):
Expand Down Expand Up @@ -81,8 +89,8 @@ def invoke(self, input_dict, config):

runner_name1, entry1 = successful_prompts[0]
assert runner_name1 == "prompt_runner1"
assert entry1["success"] == True
assert entry1["error"] is None

runner_name2, entry2 = successful_prompts[1]
assert runner_name2 == "prompt_runner2"
assert entry2["success"] == True
assert entry2["error"] is None

0 comments on commit edb9cd4

Please sign in to comment.