Don't choke on new (or old) versions of langchain_anthropic #19

Open · wants to merge 9 commits into main
2 changes: 1 addition & 1 deletion .gitignore
@@ -15,4 +15,4 @@ __pycache__/
 .pytest_cache/
 htmlcov/
 .coverage
-
+*.swp
51 changes: 31 additions & 20 deletions langdspy/prompt_runners.py
@@ -39,13 +39,15 @@ def __init__(self, **kwargs):
 class PromptHistory(BaseModel):
     history: List[Dict[str, Any]] = Field(default_factory=list)
 
-    def add_entry(self, prompt, llm_response, parsed_output, success, timestamp):
+    def add_entry(self, llm, prompt, llm_response, parsed_output, error, start_time, end_time):
         self.history.append({
-            "prompt": prompt,
+            "duration_ms": round((end_time - start_time) * 1000),
+            "llm": llm,
             "llm_response": llm_response,
             "parsed_output": parsed_output,
-            "success": success,
-            "timestamp": timestamp
+            "prompt": prompt,
+            "error": error,
+            "timestamp": end_time,
         })
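
For context, here is a minimal sketch (not part of the diff; the model identifier and values are illustrative) of how the reworked `add_entry` is called and what each history entry now records:

```python
import time

from langdspy.prompt_runners import PromptHistory

history = PromptHistory()

start = time.time()
# ... format the prompt, invoke the LLM, parse the response ...
end = time.time()

history.add_entry(
    llm="claude-3-sonnet-20240229",  # illustrative model identifier
    prompt="formatted prompt text",
    llm_response="raw completion text",
    parsed_output={"answer": "42"},
    error=None,  # or the validation error message on failure
    start_time=start,
    end_time=end,
)

# Each entry now carries duration_ms = round((end - start) * 1000) and an
# error string (None on success) instead of the old boolean success flag.
print(history.history[0]["duration_ms"])
```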


@@ -101,6 +103,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
         formatted_prompt = None
 
         while max_tries >= 1:
+            start_time = time.time()
             try:
                 kwargs = {**self.model_kwargs, **self.kwargs}
                 # logger.debug(f"PromptRunner invoke with input {input} and kwargs {kwargs} and config {config}")
@@ -137,7 +140,7 @@ def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[Runna
                 continue
 
 
-            validation = True
+            validation_err = None
 
             # logger.debug(f"Raw output for prompt runner {self.template.__class__.__name__}: {res}")
             if print_prompt:
@@ -150,50 +153,58 @@
             except Exception as e:
                 import traceback
                 traceback.print_exc()
-                logger.error(f"Failed to parse output for prompt runner {self.template.__class__.__name__}")
-                validation = False
+                validation_err = f"Failed to parse output for prompt runner {self.template.__class__.__name__}"
+                logger.error(validation_err)
             # logger.debug(f"Parsed output: {parsed_output}")
 
             len_parsed_output = len(parsed_output.keys())
             len_output_variables = len(self.template.output_variables.keys())
             # logger.debug(f"Parsed output keys: {parsed_output.keys()} [{len_parsed_output}] Expected output keys: {self.template.output_variables.keys()} [{len_output_variables}]")
 
             if len(parsed_output.keys()) != len(self.template.output_variables.keys()):
-                logger.error(f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}")
-                validation = False
+                validation_err = f"Output keys do not match expected output keys for prompt runner {self.template.__class__.__name__}"
+                logger.error(validation_err)
 
-            self.prompt_history.add_entry(formatted_prompt, res, parsed_output, validation, time.time())
-
-            if validation:
+            if validation_err is None:
                 # Transform and validate the outputs
                 for attr_name, output_field in self.template.output_variables.items():
                     output_value = parsed_output.get(attr_name)
                     if not output_value:
-                        logger.error(f"Failed to get output value for field {attr_name} for prompt runner {self.template.__class__.__name__}")
-                        validation = False
+                        validation_err = f"Failed to get output value for field {attr_name} for prompt runner {self.template.__class__.__name__}"
+                        logger.error(validation_err)
                         continue
 
                     # Validate the transformed value
                     if not output_field.validate_value(input, output_value):
-                        validation = False
-                        logger.error(f"Failed to validate field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}")
+                        validation_err = f"Failed to validate field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}"
+                        logger.error(validation_err)
 
                     # Get the transformed value
                     try:
                         transformed_val = output_field.transform_value(output_value)
                     except Exception as e:
                         import traceback
                         traceback.print_exc()
-                        logger.error(f"Failed to transform field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}")
-                        validation = False
+                        validation_err = f"Failed to transform field {attr_name} value {output_value} for prompt runner {self.template.__class__.__name__}"
+                        logger.error(validation_err)
                         continue
 
                     # Update the output with the transformed value
                     parsed_output[attr_name] = transformed_val
 
+            end_time = time.time()
+
+            # The model attribute may be called model_name or model depending on the version of langchain_anthropic
+            try:
+                model_name = config["llm"].model_name
+            except AttributeError:
+                model_name = config["llm"].model
+
+            self.prompt_history.add_entry(model_name, formatted_prompt, res, parsed_output, validation_err, start_time, end_time)
+
             res = {attr_name: parsed_output.get(attr_name, None) for attr_name in self.template.output_variables.keys()}
 
-            if validation:
+            if validation_err is None:
                 return res
 
             logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry")
@@ -263,4 +274,4 @@ def run_task():
 
 
         # logger.debug(f"MultiPromptRunner predictions: {self.predictions}")
-        return predictions
\ No newline at end of file
+        return predictions
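
The heart of the fix is the hunk above that reads the model identifier: older langchain_anthropic releases expose it as `model_name`, newer ones as `model`. Here is a standalone sketch of the same pattern (the helper name is ours, not from the PR), with a getattr-based near-equivalent noted as an alternative:

```python
def resolve_model_name(llm):
    """Return the model identifier across langchain_anthropic versions.

    Older releases expose `.model_name`; newer releases expose `.model`.
    """
    try:
        return llm.model_name
    except AttributeError:
        return llm.model


# A near-equivalent one-liner, if you prefer avoiding try/except:
# model_name = getattr(llm, "model_name", getattr(llm, "model", None))
```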
12 changes: 6 additions & 6 deletions pyproject.toml
@@ -45,11 +45,11 @@ joblib = "1.3.2"
 jsonpatch = "1.33"
 jsonpointer = "2.4"
 keyring = "24.3.1"
-langchain = "0.1.7"
-langchain-anthropic = "0.1.3"
-langchain-community = "0.0.20"
-langchain-core = "0.1.23"
-langchain-openai = "0.0.6"
+langchain = "^0.1.12"
+langchain-anthropic = "^0.1.3"
+langchain-community = "^0.0.28"
+langchain-core = "^0.1.32"
+langchain-openai = "^0.0.6"
 markdown-it-py = "3.0.0"
 marshmallow = "3.20.2"
 mdurl = "0.1.2"
@@ -107,4 +107,4 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry.scripts]
 test = "scripts:test"
-coverage = "scripts:coverage"
\ No newline at end of file
+coverage = "scripts:coverage"
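The dependency changes above swap exact pins for Poetry caret constraints, which accept compatible upgrades rather than a single fixed release. A small illustration (using the third-party packaging library; not part of this PR) of the ranges the carets expand to:

```python
from packaging.specifiers import SpecifierSet
from packaging.version import Version

# Poetry's caret keeps the left-most non-zero version digit fixed:
#   ^0.1.12 -> >=0.1.12,<0.2.0
#   ^0.0.28 -> >=0.0.28,<0.0.29
langchain_range = SpecifierSet(">=0.1.12,<0.2.0")

assert Version("0.1.32") in langchain_range      # compatible upgrade: allowed
assert Version("0.2.0") not in langchain_range   # breaking minor bump: rejected
```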