-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- OpenAI and Anthropic pass expensive unit tests - Mistral still gets stuck in a loop when using tools - Cheap unit tests still broken
- Loading branch information
Showing
5 changed files
with
135 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,116 @@ | ||
from . import BaseModel | ||
from .. import config | ||
from .. import config, utils | ||
import logging | ||
logger = logging.getLogger('heymans') | ||
|
||
|
||
class AnthropicModel(BaseModel): | ||
|
||
max_retry = 3 | ||
|
||
def __init__(self, heymans, model): | ||
from langchain_anthropic import ChatAnthropic | ||
super().__init__(heymans) | ||
self._model = ChatAnthropic( | ||
model=model, anthropic_api_key=config.anthropic_api_key) | ||
def __init__(self, heymans, model, **kwargs): | ||
from anthropic import Anthropic, AsyncAnthropic | ||
super().__init__(heymans, **kwargs) | ||
self._model = model | ||
self._tool_use_id = 0 | ||
self._client = Anthropic(api_key=config.anthropic_api_key) | ||
self._async_client = AsyncAnthropic(api_key=config.anthropic_api_key) | ||
|
||
def predict(self, messages): | ||
if isinstance(messages, list): | ||
messages = utils.prepare_messages(messages, allow_ai_first=False, | ||
allow_ai_last=False, | ||
merge_consecutive=True) | ||
# Claude seems to crash occasionally, in which case a retry will do the | ||
# trick | ||
for i in range(self.max_retry): | ||
try: | ||
return super().predict(messages) | ||
except Exception as e: | ||
logger.warning(f'error in prediction (retrying): {e}') | ||
if isinstance(messages, str): | ||
return super().predict([self.convert_message(messages)]) | ||
messages = utils.prepare_messages(messages, allow_ai_first=False, | ||
allow_ai_last=False, | ||
merge_consecutive=True) | ||
messages = [self.convert_message(message) for message in messages] | ||
# The Anthropic messages API doesn't accept tool results in a separate | ||
# message. Instead, tool results are included as a special content | ||
# block in a user message. Since two subsequent user messages aren't | ||
# allowed, we need to convert a tool message to a user message and if | ||
# necessary merge it with the next user message. | ||
while True: | ||
logger.info('entering message postprocessing loop') | ||
for i, message in enumerate(messages): | ||
if message['role'] == 'tool': | ||
logger.info('converting tool message to user message') | ||
message['role'] = 'user' | ||
message['content'] = [{ | ||
'type': 'tool_result', | ||
'tool_use_id': str(self._tool_use_id), | ||
'content': [{ | ||
'type': 'text', | ||
'text': message['content'] | ||
}] | ||
}] | ||
if i > 0: | ||
# The previous message needs to have a tool-use block | ||
prev_message = messages[i - 1] | ||
prev_message['content'] = [ | ||
{'type': 'text', | ||
'text': prev_message['content']}, | ||
{'type': 'tool_use', | ||
'id': str(self._tool_use_id), | ||
'input': {'args': 'input args'}, | ||
'name': 'tool_function' | ||
} | ||
] | ||
self._tool_use_id += 1 | ||
if len(messages) > i + 1: | ||
logger.info('merging tool and user message') | ||
next_message = messages[i + 1] | ||
if next_message['role'] == 'user': | ||
message['content'].append([{ | ||
"type": "text", | ||
"text": next_message['content'] | ||
}]) | ||
break | ||
else: | ||
break | ||
logger.info('dropping duplicate user message') | ||
messages.remove(next_message) | ||
return super().predict(messages) | ||
|
||
def get_response(self, response): | ||
text = [] | ||
for block in response.content: | ||
if block.type == 'tool_use': | ||
for tool in self._tools: | ||
if tool.name == block.name: | ||
return tool.bind(block.input) | ||
return self.invalid_tool | ||
if block.type == 'text': | ||
text.append(block.text) | ||
return '\n'.join(text) | ||
|
||
|
||
def _tool_args(self): | ||
if not self._tools: | ||
return {} | ||
alternative_format_tools = [] | ||
for tool in self.tools(): | ||
if tool['type'] == 'function': | ||
function = tool['function'] | ||
alt_tool = { | ||
"name": function['name'], | ||
"description": function['description'], | ||
"input_schema": function['parameters'] | ||
} | ||
alternative_format_tools.append(alt_tool) | ||
return {'tools': alternative_format_tools} | ||
|
||
def _anthropic_invoke(self, fnc, messages): | ||
kwargs = self._tool_args() | ||
# If the first message is the system prompt, we need to separate this | ||
# from the user and assistant messages, because the Anthropic messages | ||
# API takes this as a separate keyword argument | ||
if messages[0]['role'] == 'system': | ||
kwargs['system'] = messages[0]['content'] | ||
messages = messages[1:] | ||
return fnc(model=self._model, max_tokens=config.anthropic_max_tokens, | ||
messages=messages, **kwargs) | ||
|
||
def invoke(self, messages): | ||
return self._anthropic_invoke( | ||
self._client.beta.tools.messages.create, messages) | ||
|
||
def async_invoke(self, messages): | ||
return self._anthropic_invoke( | ||
self._async_client.beta.tools.messages.create, messages) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters