Skip to content

Commit

Permalink
Implement new Anthropic model
Browse files Browse the repository at this point in the history
- OpenAI and Anthropic pass expensive unit tests
- Mistral still gets stuck in a loop when using tools
- Cheap unit tests still broken
  • Loading branch information
smathot committed Apr 15, 2024
1 parent 456960c commit 7360b96
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 48 deletions.
6 changes: 4 additions & 2 deletions heymans/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@
'answer_model': 'gpt-4'
},
'anthropic': {
'search_model': 'mistral-medium',
'condense_model': 'mistral-medium',
'search_model': 'claude-3-sonnet',
'condense_model': 'claude-3-sonnet',
'answer_model': 'claude-3-opus'
},
'mistral': {
Expand All @@ -134,6 +134,8 @@
'answer_model': 'dummy'
}
}
# Model-specific arguments
# Maximum number of tokens an Anthropic (Claude) model may generate per
# response; passed as `max_tokens` to the Anthropic messages API, which
# requires this argument explicitly.
anthropic_max_tokens = 1024

# TOOLS
#
Expand Down
126 changes: 107 additions & 19 deletions heymans/model/_anthropic_model.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,116 @@
from . import BaseModel
from .. import config
from .. import config, utils
import logging
logger = logging.getLogger('heymans')


class AnthropicModel(BaseModel):
    """Model class for the Anthropic messages API (Claude models).

    This commit's diff showed both the removed LangChain-based ``__init__``
    and the new SDK-based one; only the current implementation is kept here,
    since a second ``def __init__`` would silently shadow the first.
    """

    # Number of retries for failed predictions.
    # NOTE(review): the rewritten predict() no longer uses max_retry —
    # confirm whether it can be removed.
    max_retry = 3

    def __init__(self, heymans, model, **kwargs):
        """Initialize the model wrapper.

        Parameters
        ----------
        heymans:
            The application object passed through to BaseModel.
        model: str
            Name of the Anthropic model, e.g. 'claude-3-opus'.
        **kwargs:
            Additional keyword arguments forwarded to BaseModel
            (e.g. tools, tool_choice).
        """
        # Imported lazily so the anthropic package is only required when
        # this model class is actually used.
        from anthropic import Anthropic, AsyncAnthropic
        super().__init__(heymans, **kwargs)
        self._model = model
        # Counter used to generate tool_use ids when predict() converts
        # tool messages into Anthropic content blocks.
        self._tool_use_id = 0
        self._client = Anthropic(api_key=config.anthropic_api_key)
        self._async_client = AsyncAnthropic(api_key=config.anthropic_api_key)

def predict(self, messages):
    """Predict a response for *messages* (a plain string or a list of
    message objects).

    Messages are normalized to dicts and then postprocessed for the
    Anthropic messages API, which does not accept tool results as
    separate messages: each tool message becomes a user message holding a
    ``tool_result`` content block, the preceding (assistant) message gets
    a matching ``tool_use`` block, and a directly following user message
    is merged in, because two consecutive user messages are not allowed.

    The diff for this method interleaved the removed retry-based body
    with the new one; only the new implementation is kept. Two defects
    are fixed relative to the added lines: the merged user text is
    appended as a content-block dict (the original appended a list,
    which is not a valid Anthropic content block), and the duplicate
    user message is removed before rescanning, so the loop terminates
    and never removes a stale message.
    """
    if isinstance(messages, str):
        return super().predict([self.convert_message(messages)])
    messages = utils.prepare_messages(messages, allow_ai_first=False,
                                      allow_ai_last=False,
                                      merge_consecutive=True)
    messages = [self.convert_message(message) for message in messages]
    # Repeatedly scan the messages. Each pass converts at most one tool
    # message; the loop ends when a full pass finds none.
    while True:
        logger.info('entering message postprocessing loop')
        for i, message in enumerate(messages):
            if message['role'] != 'tool':
                continue
            logger.info('converting tool message to user message')
            message['role'] = 'user'
            message['content'] = [{
                'type': 'tool_result',
                'tool_use_id': str(self._tool_use_id),
                'content': [{
                    'type': 'text',
                    'text': message['content']
                }]
            }]
            if i > 0:
                # The previous message needs a tool_use block whose id
                # matches the tool_result above.
                prev_message = messages[i - 1]
                prev_message['content'] = [
                    {'type': 'text',
                     'text': prev_message['content']},
                    {'type': 'tool_use',
                     'id': str(self._tool_use_id),
                     'input': {'args': 'input args'},
                     'name': 'tool_function'}
                ]
            self._tool_use_id += 1
            if len(messages) > i + 1:
                next_message = messages[i + 1]
                if next_message['role'] == 'user':
                    logger.info('merging tool and user message')
                    # Append the content block itself (a dict), not a
                    # list wrapping it.
                    message['content'].append({
                        'type': 'text',
                        'text': next_message['content']
                    })
                    logger.info('dropping duplicate user message')
                    messages.remove(next_message)
            # Restart the scan: indices shifted if a message was removed.
            break
        else:
            # A full pass found no tool message: postprocessing is done.
            break
    return super().predict(messages)

def get_response(self, response):
    """Extract the payload from an Anthropic API *response*.

    If the response contains a tool-use block, return the matching tool
    bound to the model-provided input (or ``self.invalid_tool`` when no
    tool by that name is registered). Otherwise concatenate all text
    blocks, separated by newlines, and return the resulting string.
    """
    collected_text = []
    for content_block in response.content:
        if content_block.type == 'tool_use':
            # A tool-use block takes priority over any text blocks.
            matching = [t for t in self._tools
                        if t.name == content_block.name]
            if matching:
                return matching[0].bind(content_block.input)
            return self.invalid_tool
        if content_block.type == 'text':
            collected_text.append(content_block.text)
    return '\n'.join(collected_text)


def _tool_args(self):
if not self._tools:
return {}
alternative_format_tools = []
for tool in self.tools():
if tool['type'] == 'function':
function = tool['function']
alt_tool = {
"name": function['name'],
"description": function['description'],
"input_schema": function['parameters']
}
alternative_format_tools.append(alt_tool)
return {'tools': alternative_format_tools}

def _anthropic_invoke(self, fnc, messages):
    """Call *fnc* (a sync or async Anthropic messages endpoint) with the
    model name, token limit, tool definitions, and *messages*.

    The Anthropic messages API does not accept a leading system message
    in the messages list; the system prompt must be passed as a separate
    ``system`` keyword argument, so it is split off here when present.
    """
    call_kwargs = self._tool_args()
    first_message = messages[0]
    if first_message['role'] == 'system':
        call_kwargs['system'] = first_message['content']
        messages = messages[1:]
    return fnc(model=self._model,
               max_tokens=config.anthropic_max_tokens,
               messages=messages, **call_kwargs)

def invoke(self, messages):
    """Synchronously send *messages* to the Anthropic tools-beta
    messages endpoint and return the raw API response."""
    create_fnc = self._client.beta.tools.messages.create
    return self._anthropic_invoke(create_fnc, messages)

def async_invoke(self, messages):
    """Send *messages* to the Anthropic tools-beta messages endpoint via
    the async client; returns whatever the async create call yields
    (an awaitable when using the real SDK)."""
    async_create_fnc = self._async_client.beta.tools.messages.create
    return self._anthropic_invoke(async_create_fnc, messages)
28 changes: 23 additions & 5 deletions heymans/model/_base_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import logging
import asyncio
import time
from langchain.schema import SystemMessage, AIMessage, HumanMessage, \
FunctionMessage
logger = logging.getLogger('heymans')


Expand All @@ -17,7 +19,7 @@ def __init__(self, heymans, tools=None, tool_choice='auto'):
self.prompt_tokens_consumed = 0
self.completion_tokens_consumed = 0

# Diff hunk: old and new signatures of invalid_tool are both shown; the
# commit only adds the -> str return annotation.
def invalid_tool(self):
def invalid_tool(self) -> str:
    # Fallback returned by get_response() when the model requests a tool
    # name that is not registered.
    return 'Invalid tool'

def get_response(self, response) -> [str, callable]:
Expand All @@ -28,16 +30,31 @@ def tools(self):
for t in self._tools if t.tool_spec]

# Diff hunk: BaseModel.invoke()/async_invoke() previously delegated to a
# LangChain model object (the `self._model.invoke`/`ainvoke` lines are
# the removed version); after this commit they are abstract and must be
# overridden by concrete model classes (the NotImplementedError lines).
def invoke(self, messages):
    return self._model.invoke(messages)
    raise NotImplementedError()

def async_invoke(self, messages):
    return self._model.ainvoke(messages)
    raise NotImplementedError()

# Diff hunk: both the removed and the added version of messages_length
# appear here; the two def lines are the old and new (annotated)
# signatures of the same method.
def messages_length(self, messages):
def messages_length(self, messages) -> int:
    # Total length in characters of *messages*, which is either a plain
    # string or a list of message objects/dicts.
    if isinstance(messages, str):
        return len(messages)
    # NOTE(review): 'lebase_format_toolsn' looks like scrape corruption
    # of a removed diff line (presumably 'len(messages)') -- confirm
    # against the repository.
    return lebase_format_toolsn(messages)
    # New implementation: sum content lengths, supporting both
    # object-style (m.content) and dict-style (m['content']) messages.
    return sum([len(m.content if hasattr(m, 'content') else m['content'])
                for m in messages])

def convert_message(self, message):
    """Convert *message* into the dict format used by the chat APIs.

    Plain strings become user messages. LangChain message classes are
    mapped to the conventional role names; any other type raises
    ValueError.
    """
    if isinstance(message, str):
        return {'role': 'user', 'content': message}
    role_by_type = ((SystemMessage, 'system'),
                    (AIMessage, 'assistant'),
                    (HumanMessage, 'user'),
                    (FunctionMessage, 'tool'))
    for message_cls, role_name in role_by_type:
        if isinstance(message, message_cls):
            return {'role': role_name, 'content': message.content}
    raise ValueError(f'Unknown message type: {message}')

def predict(self, messages, track_tokens=True):
t0 = time.time()
Expand All @@ -63,6 +80,7 @@ def predict_multiple(self, prompts):
"""Predicts multiple simple (non-message history) prompts using asyncio
if possible.
"""
prompts = [[self.convert_message(prompt)] for prompt in prompts]
try:
loop = asyncio.get_event_loop()
if not loop.is_running():
Expand Down
22 changes: 0 additions & 22 deletions heymans/model/_openai_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from .. import config
from . import BaseModel
from langchain.schema import SystemMessage, AIMessage, HumanMessage, \
FunctionMessage


class OpenAIModel(BaseModel):
Expand All @@ -18,22 +16,6 @@ def __init__(self, heymans, model, **kwargs):
self._client = Client(api_key=config.openai_api_key)
self._async_client = AsyncClient(api_key=config.openai_api_key)

# Removed in this commit: convert_message moved from OpenAIModel up to
# BaseModel (see _base_model.py in this diff) so all model classes share it.
def convert_message(self, message):
    # OpenAI expects messages as dict objects
    if isinstance(message, str):
        # Plain strings are treated as user messages.
        return dict(role='user', content=message)
    if isinstance(message, SystemMessage):
        role = 'system'
    elif isinstance(message, AIMessage):
        role = 'assistant'
    elif isinstance(message, HumanMessage):
        role = 'user'
    elif isinstance(message, FunctionMessage):
        role = 'tool'
    else:
        raise ValueError(f'Unknown message type: {message}')
    return dict(role=role, content=message.content)

def predict(self, messages):
# Strings need to be converted a list of length one with a single
# message dict
Expand Down Expand Up @@ -61,10 +43,6 @@ def predict(self, messages):
message['tool_call_id'] = tool_call_id
return super().predict(messages)

# Removed in this commit: the prompt conversion now happens in
# BaseModel.predict_multiple (see _base_model.py in this diff).
def predict_multiple(self, prompts):
    # Wrap each prompt as a single-message conversation before delegating.
    prompts = [[self.convert_message(prompt)] for prompt in prompts]
    return super().predict_multiple(prompts)

def get_response(self, response):
tool_calls = response.choices[0].message.tool_calls
if tool_calls:
Expand Down
1 change: 1 addition & 0 deletions heymans/tools/_base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def name(self):
return self.__class__.__name__

def bind(self, args):
    """Bind *args* to this tool and return the resulting callable.

    Parameters
    ----------
    args: str or dict
        Keyword arguments for the tool, either as a dict or as a JSON
        object string (as produced by the model APIs).
    """
    # The debug print added in this commit is removed: library code
    # should not write to stdout.
    if isinstance(args, str):
        args = json.loads(args)
    return functools.partial(self, **args)
Expand Down

0 comments on commit 7360b96

Please sign in to comment.