add more info to the prompt history #15

Merged · 1 commit · Mar 20, 2024
2 changes: 1 addition & 1 deletion .gitignore

@@ -15,4 +15,4 @@ __pycache__/
 .pytest_cache/
 htmlcov/
 .coverage
-
+*.swp
44 changes: 24 additions & 20 deletions langdspy/prompt_runners.py

@@ -39,13 +39,15 @@ def __init__(self, **kwargs):
 class PromptHistory(BaseModel):
     history: List[Dict[str, Any]] = Field(default_factory=list)
 
-    def add_entry(self, prompt, llm_response, parsed_output, success, timestamp):
+    def add_entry(self, llm, prompt, llm_response, parsed_output, error, start_time, end_time):
         self.history.append({
-            "prompt": prompt,
+            "duration_ms": round((end_time - start_time) * 1000),
+            "llm": llm,
             "llm_response": llm_response,
             "parsed_output": parsed_output,
-            "success": success,
-            "timestamp": timestamp
+            "prompt": prompt,
+            "error": error,
+            "timestamp": end_time,
         })
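To make the new schema concrete, here is a minimal sketch of what one PromptHistory entry looks like after this change; the model name, prompt text, and 0.84 s timing are invented for illustration, not taken from the PR.

    import time

    history = []

    start_time = time.time()
    end_time = start_time + 0.84   # pretend the attempt took ~0.84 s

    history.append({
        "duration_ms": round((end_time - start_time) * 1000),  # 840
        "llm": "gpt-3.5-turbo",             # assumed; the PR passes config["llm"].model_name
        "llm_response": "Answer: 42",       # raw LLM text (placeholder)
        "parsed_output": {"answer": "42"},  # parsed fields (placeholder)
        "prompt": "Question: ...",          # formatted prompt (placeholder)
        "error": None,                      # validation error message, or None on success
        "timestamp": end_time,
    })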


@@ -101,6 +103,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
         formatted_prompt = None
 
         while max_tries >= 1:
+            start_time = time.time()
             try:
                 kwargs = {**self.model_kwargs, **self.kwargs}
                 # logger.debug(f"PromptRunner invoke with input {input} and kwargs {kwargs} and config {config}")
@@ -137,7 +140,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
                 continue
 
 
-            validation = True
+            validation_err = None
 
             # logger.debug(f"Raw output for prompt runner {self.template.__class__.__name__}: {res}")
             if print_prompt:
@@ -150,50 +153,51 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
             except Exception as e:
                 import traceback
                 traceback.print_exc()
-                logger.error(f"Failed to parse output for prompt runner {self.template.__class__.__name__}")
-                validation = False
+                validation_err = f"Failed to parse output for prompt runner {self.template.__class__.__name__}"
+                logger.error(validation_err)
             # logger.debug(f"Parsed output: {parsed_output}")
 
             len_parsed_output = len(parsed_output.keys())
             len_output_variables = len(self.template.output_variables.keys())
             # logger.debug(f"Parsed output keys: {parsed_output.keys()} [{len_parsed_output}] Expected output keys: {self.template.output_variables.keys()} [{len_output_variables}]")
 
             if len(parsed_output.keys()) != len(self.template.output_variables.keys()):
-                logger.error(f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}")
-                validation = False
+                validation_err = f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}"
+                logger.error(validation_err)
 
-            self.prompt_history.add_entry(formatted_prompt, res, parsed_output, validation, time.time())
-
-            if validation:
+            if validation_err is None:
                 # Transform and validate the outputs
                 for attr_name, output_field in self.template.output_variables.items():
                     output_value = parsed_output.get(attr_name)
                     if not output_value:
-                        logger.error(f"Failed to get output value for field {attr_name} for prompt runner {self.template.__class__.__name__}")
-                        validation = False
+                        validation_err = f"Failed to get output value for field {attr_name} for prompt runner {self.template.__class__.__name__}"
+                        logger.error(validation_err)
                         continue
 
                     # Validate the transformed value
                     if not output_field.validate_value(input, output_value):
-                        validation = False
-                        logger.error(f"Failed to validate field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}")
+                        validation_err = f"Failed to validate field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}"
+                        logger.error(validation_err)
 
                     # Get the transformed value
                     try:
                         transformed_val = output_field.transform_value(output_value)
                     except Exception as e:
                         import traceback
                         traceback.print_exc()
-                        logger.error(f"Failed to transform field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}")
-                        validation = False
+                        validation_err = f"Failed to transform field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}"
+                        logger.error(validation_err)
                         continue
 
                     # Update the output with the transformed value
                     parsed_output[attr_name] = transformed_val
 
+            end_time = time.time()
+            self.prompt_history.add_entry(config["llm"].model_name, formatted_prompt, res, parsed_output, validation_err, start_time, end_time)
+
             res = {attr_name: parsed_output.get(attr_name, None) for attr_name in self.template.output_variables.keys()}
 
-            if validation:
+            if validation_err is None:
                 return res
 
             logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry")
@@ -263,4 +267,4 @@ def run_task():
 
 
         # logger.debug(f"MultiPromptRunner predictions: {self.predictions}")
-        return predictions
+        return predictions