Prompt caching anthropic #29

Open
wants to merge 30 commits into base: main

Changes from all commits (30 commits)
98f74f7
feat: Add Anthropic chat model integration
fernandocratejoy Aug 30, 2024
fd9fb3a
feat: Modify Anthropic prompt formatting in PromptRunner
fernandocratejoy Aug 30, 2024
8843084
feat: Update prompt_runners.py to simplify Anthropic LLM invocation
fernandocratejoy Aug 30, 2024
65d5e25
fix: Add error handling to PromptRunner class
fernandocratejoy Aug 30, 2024
5bc76e3
fix: Improve error handling in _invoke_with_retries method
fernandocratejoy Aug 30, 2024
8161735
fix: Handle 'openai_json' case in _execute_prompt method
fernandocratejoy Aug 30, 2024
ca4435c
fix: Simplify condition for handling 'anthropic' and 'openai_json' ca…
fernandocratejoy Aug 30, 2024
000c2dd
fix: remove config parameter from chain.invoke calls
fernandocratejoy Aug 30, 2024
26eb0a9
feat: Add logging and error handling to PromptRunner.invoke method
fernandocratejoy Aug 30, 2024
be13531
fix: Add detailed error logging in _execute_prompt method
fernandocratejoy Aug 30, 2024
8bbcc33
fix: Add error handling and logging to PromptRunner.invoke
fernandocratejoy Aug 30, 2024
d11d530
fix: Update invoke method in PromptRunner class
fernandocratejoy Aug 30, 2024
a0298bf
feat: Update PromptRunner to use current LangChain API
fernandocratejoy Aug 30, 2024
315c215
fix: Handle LLM configuration in PromptRunner.invoke
fernandocratejoy Aug 30, 2024
0b86126
feat: Update PromptRunner invoke method to use current LangChain API
fernandocratejoy Aug 30, 2024
d88cf1d
feat: Refactor prompt strategy to use SystemMessage and HumanMessage
fernandocratejoy Aug 30, 2024
1764355
fix: Add input validation to PromptSignature class
fernandocratejoy Aug 30, 2024
58e2341
fix: Handle optional input fields in PromptSignature._validate_input
fernandocratejoy Aug 30, 2024
9ae5cdc
fix: Handle empty input variables in PromptSignature
fernandocratejoy Aug 30, 2024
2c10a07
feat: Add PromptGenerateState and GenerateState classes to langdspy/p…
fernandocratejoy Aug 30, 2024
2359a44
fix: Handle list input in PromptRunner.invoke
fernandocratejoy Aug 30, 2024
fbe6f8a
fix: Handle different types of formatted_prompt in PromptRunner.invoke
fernandocratejoy Aug 30, 2024
c0d4823
fix: update `PromptRunner` to handle `formatted_prompt` correctly
fernandocratejoy Aug 30, 2024
0db5a90
feat: Refactor prompt runner to handle list of messages and improve i…
fernandocratejoy Aug 30, 2024
1261890
fix: Handle AIMessage object returned by LLM
fernandocratejoy Aug 30, 2024
541d47f
fix: Remove unused imports and code in prompt_strategies.py
fernandocratejoy Aug 30, 2024
40a451e
fix: Handle AIMessage objects in output parsing
fernandocratejoy Aug 30, 2024
8da94f4
feat: Add support for parsing dictionary output in DefaultPromptStrategy
fernandocratejoy Sep 2, 2024
6a132a8
fix: Ensure all output fields are present in parsed_fields
fernandocratejoy Sep 2, 2024
580a57e
Caching examples and input description for Anthropic
fernandocratejoy Sep 3, 2024
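
For context on what that last commit is aiming at, here is a minimal, hedged sketch of how Anthropic prompt caching is typically requested through cache_control content blocks with langchain_anthropic. The model id and message text are placeholders, and the exact wiring inside langdspy/prompt_runners.py may differ from this.

```python
# Hedged sketch only: requesting Anthropic prompt caching via a cache_control
# content block. Assumes the installed langchain_anthropic version forwards
# cache_control blocks to the Anthropic API unchanged; model id is illustrative.
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage

llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")

system_message = SystemMessage(
    content=[
        {
            "type": "text",
            "text": "Long, reusable instructions and few-shot examples go here ...",
            # Marks this block as cacheable so repeated calls can reuse the prefix.
            "cache_control": {"type": "ephemeral"},
        }
    ]
)

response = llm.invoke([system_message, HumanMessage(content="Summarize the input.")])
print(response.content)
```
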
154 changes: 110 additions & 44 deletions langdspy/prompt_runners.py
@@ -7,10 +7,9 @@
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
# from langchain_contrib.llms.testing import FakeLLM
from typing import Any, Dict, List, Type, Optional, Callable
from typing import Any, Dict, List, Type, Optional, Callable, Union
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional



@@ -59,17 +58,13 @@ def reset(self):

class PromptRunner(RunnableSerializable):
template: PromptSignature = None
# prompt_history: List[str] = [] - Was trying to find a way to make a list of prompts for inspection
model_kwargs: Dict[str, Any] = {}
kwargs: Dict[str, Any] = {}
prompt_history: PromptHistory = Field(default_factory=PromptHistory)


def __init__(self, template_class, prompt_strategy, **kwargs):
super().__init__()

self.kwargs = kwargs

cls_ = type(template_class.__name__, (prompt_strategy, template_class), {})
self.template = cls_()

@@ -78,70 +73,73 @@ def check_template(
cls, value: PromptSignature
) -> PromptSignature:
return value


def set_model_kwargs(self, model_kwargs):
self.model_kwargs.update(model_kwargs)

def _determine_llm_type(self, llm):
if isinstance(llm, ChatOpenAI): # Assuming OpenAILLM is the class for OpenAI models
def _determine_llm_type(self, llm: Union[ChatOpenAI, ChatAnthropic]) -> str:
if isinstance(llm, ChatOpenAI):
kwargs = getattr(llm, 'kwargs', None) or getattr(llm, 'model_kwargs', {})
logger.debug(kwargs)
if kwargs.get('response_format', {}).get('type') == 'json_object':
logger.info("OpenAI model response format is json_object")
return 'openai_json'
return 'openai'
elif isinstance(llm, ChatAnthropic): # Assuming AnthropicLLM is the class for Anthropic models
elif isinstance(llm, ChatAnthropic):
return 'anthropic'
else:
return 'openai' # Default to OpenAI if model type cannot be determined
return 'unknown'
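
Reviewer note (not part of the diff): a quick sketch of the kind of configuration that makes the check above return 'openai_json'. The model id is illustrative; the response_format lives in model_kwargs, which is the attribute _determine_llm_type reads.

```python
# Hedged sketch: an OpenAI chat model configured for JSON-object output, which
# the response_format check above classifies as 'openai_json'.
from langchain_openai import ChatOpenAI

json_llm = ChatOpenAI(
    model="gpt-4o-mini",  # illustrative model id
    model_kwargs={"response_format": {"type": "json_object"}},
)
```
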

def _determine_llm_model(self, llm):
if isinstance(llm, ChatOpenAI): # Assuming OpenAILLM is the class for OpenAI models
def _determine_llm_model(self, llm: Union[ChatOpenAI, ChatAnthropic]) -> str:
if isinstance(llm, ChatOpenAI):
return llm.model_name
elif isinstance(llm, ChatAnthropic): # Assuming AnthropicLLM is the class for Anthropic models
elif isinstance(llm, ChatAnthropic):
return llm.model
elif hasattr(llm, 'model_name'):
return llm.model_name
elif hasattr(llm, 'model'):
return llm.model
else:
return '???'
return 'unknown'

def get_prompt_history(self):
return self.prompt_history.history

def clear_prompt_history(self):
self.prompt_history.reset()

def _invoke_with_retries(self, chain, input, max_tries=1, config: Optional[RunnableConfig] = {}):
def _invoke_with_retries(self, invoke_func, input, max_tries=1, config: Optional[RunnableConfig] = {}):
total_max_tries = max_tries
hard_fail = config.get('hard_fail', True)
llm_type, llm_model = self._get_llm_info(config)

logger.debug(f"LLM type: {llm_type} - model {llm_model}")
prompt_res = None
last_error = None

while max_tries >= 1:
start_time = time.time()
try:
formatted_prompt, prompt_res = self._execute_prompt(chain, input, config, llm_type)
prompt_res = invoke_func()
parsed_output, validation_err = self._process_output(prompt_res, input, llm_type)

end_time = time.time()
self._log_prompt_history(config, formatted_prompt, prompt_res, parsed_output, validation_err, start_time, end_time)
self._log_prompt_history(config, self.template.format_prompt(**input, llm_type=llm_type), prompt_res, parsed_output, validation_err, start_time, end_time)

if validation_err is None:
return parsed_output
last_error = validation_err

except Exception as e:
logger.error(f"Error in _invoke_with_retries: {str(e)}")
self._handle_exception(e, max_tries)
last_error = str(e)

max_tries -= 1
if max_tries >= 1:
self._handle_retry(max_tries)

return self._handle_failure(hard_fail, total_max_tries, prompt_res)
return self._handle_failure(hard_fail, total_max_tries, prompt_res, last_error)

def _get_llm_info(self, config):
llm_type = config.get('llm_type') or self._determine_llm_type(config['llm'])
@@ -158,8 +156,22 @@ def _execute_prompt(self, chain, input, config, llm_type):
self._log_prompt(formatted_prompt)
logger.debug(f"CONFIG: {config}")

prompt_res = chain.invoke(invoke_args, config=config)
return formatted_prompt, prompt_res
try:
if llm_type in ['anthropic', 'openai', 'openai_json']:
if isinstance(formatted_prompt, list):
prompt_res = chain.invoke({"messages": formatted_prompt})
else:
prompt_res = chain.invoke(invoke_args, config=config)
else:
prompt_res = chain.invoke(invoke_args)
return formatted_prompt, prompt_res
except Exception as e:
logger.error(f"Error executing prompt: {str(e)}")
logger.error(f"Chain: {chain}")
logger.error(f"Input: {input}")
logger.error(f"Config: {config}")
logger.error(f"LLM Type: {llm_type}")
raise
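
Reviewer note (not part of the diff): when formatted_prompt is already a list of chat messages, a LangChain chat model can be invoked with that list directly. A minimal sketch under that assumption, with illustrative message text:

```python
# Minimal sketch, not part of the change: invoking a LangChain chat model with
# an explicit message list and pulling the text out of the returned AIMessage.
from langchain_anthropic import ChatAnthropic
from langchain_core.messages import HumanMessage, SystemMessage

llm = ChatAnthropic(model="claude-3-5-sonnet-20240620")  # illustrative model id
messages = [
    SystemMessage(content="You are a careful extraction assistant."),
    HumanMessage(content="Extract the requested fields from the input below ..."),
]
result = llm.invoke(messages)  # returns an AIMessage
text = result.content          # plain string for downstream parsing
```
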

def _get_trained_state(self, config):
trained_state = config.get('trained_state') or self.model_kwargs.get('trained_state') or self.kwargs.get('trained_state')
@@ -227,36 +239,90 @@ def _handle_retry(self, max_tries):
logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__}, pausing before we retry")
time.sleep(random.uniform(0.05, 0.25))

def _handle_failure(self, hard_fail, total_max_tries, prompt_res):
def _handle_failure(self, hard_fail, total_max_tries, prompt_res, last_error):
error_message = f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries. Last error: {last_error}"
if hard_fail:
raise ValueError(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries.")
raise ValueError(error_message)
else:
logger.error(f"Output validation failed for prompt runner {self.template.__class__.__name__} after {total_max_tries} tries, returning None.")
logger.error(error_message + " Returning default values.")
if len(self.template.output_variables.keys()) == 1:
return {attr_name: prompt_res for attr_name in self.template.output_variables.keys()}
else:
return {attr_name: None for attr_name in self.template.output_variables.keys()}

def invoke(self, input: Input, config: Optional[RunnableConfig] = {}) -> Output:
chain = (
self.template
| config['llm']
| StrOutputParser()
)

max_retries = config.get('max_tries', 3)

if '__examples__' in config:
input['__examples__'] = config['__examples__']

res = self._invoke_with_retries(chain, input, max_retries, config=config)

prediction_data = {**input, **res}


prediction = Prediction(**prediction_data)

return prediction
def invoke(self, input: Input, config: Optional[RunnableConfig] = None) -> Output:
try:
logger.debug(f"PromptRunner invoke called with input: {input}")
logger.debug(f"Config: {config}")

if config is None:
config = {}

llm = config.get('llm')
if llm is None:
raise ValueError("'llm' is not present in the config")

max_retries = config.get('max_tries', 3)

if '__examples__' in config:
input['__examples__'] = config['__examples__']

kwargs = {**self.model_kwargs, **self.kwargs, **input}
llm_type = self._determine_llm_type(llm)

formatted_prompt = self.template.format_prompt(**kwargs, llm_type=llm_type)

if isinstance(formatted_prompt, list):
# If formatted_prompt is a list, assume it's a list of messages
chain = lambda x: llm.invoke(formatted_prompt)
res = self._invoke_with_retries(
lambda: chain(input),
input,
max_retries,
config=config
)
# Extract content if it's an AIMessage
if hasattr(res, 'content'):
res = res.content
else:
chain = self.template | llm | StrOutputParser()


res = self._invoke_with_retries(
lambda : self._execute_prompt(chain, input, config, llm_type=llm_type)[1], \
input, \
max_retries, \
config=config)
'''
chain = formatted_prompt | llm | StrOutputParser()
res = self._invoke_with_retries(
lambda: llm.invoke(formatted_prompt, config=config).content,
input,
max_retries,
config=config
)
'''

logger.debug(f"Result from _invoke_with_retries: {res}")

parsed_output = self.template.parse_output_to_fields(res, llm_type)

print(f"parsed_output = {parsed_output}")

prediction_data = {**input, **parsed_output}
logger.debug(f"Prediction data: {prediction_data}")

prediction = Prediction(**prediction_data)
logger.debug(f"Final prediction: {prediction}")

return prediction
except Exception as e:
logger.error(f"Error in PromptRunner invoke: {str(e)}")
logger.error(f"Input: {input}")
logger.error(f"Config: {config}")
import traceback
logger.error(traceback.format_exc())
raise


class MultiPromptRunner(PromptRunner):
(rest of the diff collapsed)
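
For anyone trying this branch, a hedged usage sketch of the config keys the updated invoke path reads ('llm', 'max_tries', 'hard_fail'). MyTaskSignature is a hypothetical PromptSignature subclass defined elsewhere, and the import paths for PromptRunner and DefaultPromptStrategy are assumed from the file and commit names in this PR.

```python
# Hedged usage sketch against this branch; MyTaskSignature is hypothetical and
# the config keys mirror what invoke() and _invoke_with_retries read above.
from langchain_anthropic import ChatAnthropic
from langdspy.prompt_runners import PromptRunner               # assumed import path
from langdspy.prompt_strategies import DefaultPromptStrategy   # assumed import path

runner = PromptRunner(MyTaskSignature, DefaultPromptStrategy)  # MyTaskSignature defined elsewhere

prediction = runner.invoke(
    {"document": "raw text to process ..."},  # input field names depend on the signature
    config={
        "llm": ChatAnthropic(model="claude-3-5-sonnet-20240620"),
        "max_tries": 3,      # retries before _handle_failure runs
        "hard_fail": False,  # return default/None output fields instead of raising
    },
)
```
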