Commit: Refactor tool use

- Anthropic needs to be implemented
- Mistral tends to repeat tool uses
- OpenAI works well
smathot committed Apr 12, 2024
1 parent 855fd0d commit 456960c
Showing 31 changed files with 633 additions and 532 deletions.
2 changes: 1 addition & 1 deletion heymans/__init__.py
@@ -1,3 +1,3 @@
 """AI-based chatbot that provides sensible answers based on documentation"""
 
-__version__ = '0.13.16'
+__version__ = '0.14.0'
15 changes: 7 additions & 8 deletions heymans/config.py
@@ -71,10 +71,9 @@
 '''
 # The default title of a new conversation
 default_conversation_title = 'New conversation'
-# The number of previous messages for which transient content should be
-# retained. Transient content are large chunks of information that are included
-# in AI messages, usually as the result of tool use.
-keep_transient = 4
+# The number of previous messages for which tool results should be
+# retained.
+keep_tool_results = 4
 
 # RATE LIMITS
 #
@@ -125,7 +124,7 @@
         'answer_model': 'claude-3-opus'
     },
     'mistral': {
-        'search_model': 'mistral-medium',
+        'search_model': 'mistral-large',
         'condense_model': 'mistral-medium',
         'answer_model': 'mistral-large'
     },
@@ -140,11 +139,11 @@
 #
 # Tools should match the names of classes from heymans.tools
 # Search tools are executed in the first documentation-search phase
-search_tools = ['TopicsTool', 'SearchTool']
+search_tools = ['search_documentation']
 # Answer tools are executed during the answer phase
 answer_tools_with_search = []
-answer_tools_without_search = ['CodeExecutionTool', 'GoogleScholarTool',
-                               'AttachmentsTool', 'DownloadTool']
+answer_tools_without_search = ['read_attachment', 'search_google_scholar',
+                               'execute_code', 'download']
 
 # SETTINGS
 #
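Note: the tool lists above now contain snake_case function names rather than the CamelCase class names used before. As a minimal sketch (not part of this commit), a deployment could trim the tool set by overriding these module-level settings before instantiating Heymans; the attribute names are the ones defined in heymans/config.py, and each tool name must match an attribute of heymans.tools:

from heymans import config

# Keep only documentation search and code execution
config.search_tools = ['search_documentation']
config.answer_tools_without_search = ['execute_code']
# Drop tool results once they are more than two messages old
config.keep_tool_results = 2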
1 change: 1 addition & 0 deletions heymans/database/manager.py
@@ -188,6 +188,7 @@ def add_attachment(self, attachment_data: dict) -> int:
             db.session.commit()
             return attachment.attachment_id
         except Exception as e:
+            breakpoint()
             logger.error(f"Error adding attachment: {e}")
             return -1
98 changes: 48 additions & 50 deletions heymans/heymans.py
@@ -49,9 +49,6 @@ def __init__(self, user_id: str, persistent: bool = False,
         ]
         self.documentation = Documentation(
             self, sources=[FAISSDocumentationSource(self)])
-        self.search_model = model(self, self.model_config['search_model'])
-        self.answer_model = model(self, self.model_config['answer_model'])
-        self.condense_model = model(self, self.model_config['condense_model'])
         self.messages = Messages(self, persistent)
         if search_tools is None:
             search_tools = config.search_tools
@@ -64,7 +61,23 @@
         # instantiated with heymans (self) as first argument
         self.search_tools = [getattr(tools, t)(self) for t in search_tools]
         self.answer_tools = [getattr(tools, t)(self) for t in answer_tools]
-        self.tools = self.answer_tools
+        # If there are search tools, the first one should always be used
+        if search_tools:
+            search_tool_choice = search_tools[0]
+        else:
+            search_tool_choice = None
+        # If there are answer tools, the model can choose freely
+        if answer_tools:
+            answer_tool_choice = 'auto'
+        else:
+            answer_tool_choice = None
+        self.search_model = model(self, self.model_config['search_model'],
+                                  tools=self.search_tools,
+                                  tool_choice=search_tool_choice)
+        self.answer_model = model(self, self.model_config['answer_model'],
+                                  tools=self.answer_tools,
+                                  tool_choice=answer_tool_choice)
+        self.condense_model = model(self, self.model_config['condense_model'])
 
     def send_user_message(self, message: str,
                           message_id: str=None) -> GeneratorType:
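Note: the constructor now hands the tools to the models up front. The search model is forced to call the first search tool, while the answer model may pick any tool or none ('auto'). The model() factory is not shown in this diff; the sketch below illustrates how such a tool_choice value would typically map onto an OpenAI-style chat-completion request. The openai_spec attribute is an assumption for illustration:

def build_request(messages, tools, tool_choice):
    """Illustrative only: translate tools/tool_choice into an
    OpenAI-style request body."""
    request = {'messages': messages}
    if tools:
        request['tools'] = [tool.openai_spec for tool in tools]
    if tool_choice == 'auto':
        # The model decides whether (and which) tool to call
        request['tool_choice'] = 'auto'
    elif tool_choice is not None:
        # Force one specific function, as done for the search model
        request['tool_choice'] = {'type': 'function',
                                  'function': {'name': tool_choice}}
    return request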
@@ -105,12 +118,14 @@ def _search(self, message: str) -> GeneratorType:
         self.documentation.search([message])
         # Then search based on the search-model queries derived from the user
         # question
-        self.tools = self.search_tools
         reply = self.search_model.predict(self.messages.prompt(
             system_prompt=prompt.SYSTEM_PROMPT_SEARCH))
         if config.log_replies:
             logger.info(f'[search state] reply: {reply}')
-        self._run_tools(reply)
+        if callable(reply):
+            reply()
+        else:
+            logger.warning(f'[search state] did not call search tool')
         self.documentation.strip_irrelevant(message)
         logger.info(
             f'[search state] {len(self.documentation._documents)} documents, {len(self.documentation)} characters')
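Note: predict() now returns either plain reply text or a zero-argument callable that executes the tool the model chose to call. Because the search model is forced to use the search tool, the warning branch only triggers for models that reply with text anyway. A toy, self-contained illustration of the dispatch (fake_predict stands in for the real search model, which this diff does not show):

def fake_predict(prompt_messages):
    def run_search():
        # A bound tool call returning (message, result, needs_feedback)
        return 'Searching documentation', 'matching doc chunks...', False
    return run_search

reply = fake_predict(['how do I create a loop?'])
if callable(reply):
    reply()  # _search() discards the returned tuple
else:
    print('model replied with text instead of calling the search tool')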
@@ -120,7 +135,6 @@ def _answer(self, state: str = 'answer') -> GeneratorType:
         yield {'action': 'set_loading_indicator',
                'message': f'{config.ai_name} is thinking and typing '}, {}
         logger.info(f'[{state} state] entering')
-        self.tools = self.answer_tools
         # We first collect a regular reply to the user message. While doing so
         # we also keep track of the number of tokens consumed.
         tokens_consumed_before = self.answer_model.total_tokens_consumed
@@ -131,53 +145,37 @@ def _answer(self, state: str = 'answer') -> GeneratorType:
         self.database.add_activity(tokens_consumed)
         if config.log_replies:
             logger.info(f'[{state} state] reply: {reply}')
-        # We then run tools based on the AI reply. This may modify the reply,
-        # mainly by stripping out any JSON commands in the reply
-        reply, result, needs_feedback = self._run_tools(reply)
-        if needs_feedback:
-            logger.info(f'[{state} state] tools need feedback')
-        # If the reply contains a NOT_DONE_YET marker, this is a way for the AI
-        # to indicate that it wants to perform additional actions. This makes
-        # it easier to perform tasks consisting of multiple responses and
-        # actions. The marker is stripped from the reply so that it's hidden
-        # from the user. We also check for a number of common linguistic
-        # indicators that the AI isn't done yet, such as "I will now". This is
-        # necessary because the explicit marker isn't reliably sent.
-        if self.answer_model.supports_not_done_yet and \
-                prompt.NOT_DONE_YET_MARKER in reply:
-            reply = reply.replace(prompt.NOT_DONE_YET_MARKER, '')
-            logger.info(f'[{state} state] not-done-yet marker received')
-            needs_feedback = True
-        # If there is still a non-empty reply after running the tools (i.e.
-        # stripping the JSON hasn't cleared the reply entirely), then yield and
-        # remember it.
-        if reply:
+        # If the reply is a callable, then it's a tool that we need to run
+        if callable(reply):
+            tool_message, result, needs_feedback = reply()
+            if needs_feedback:
+                logger.info(f'[{state} state] tools need feedback')
+            metadata = self.messages.append('assistant', tool_message)
+            yield tool_message, metadata
+            # If the tool has a result, yield and remember it
+            if result:
+                metadata = self.messages.append('tool', result)
+                yield result, metadata
+        # Otherwise the reply is a regular AI message
+        else:
             metadata = self.messages.append('assistant', reply)
             yield reply, metadata
-        else:
-            metadata = self.messages.metadata()
-        # If the tools have a result, yield and remember it
-        if result:
-            self.messages.append('assistant', result)
-            yield result, metadata
+            # If the reply contains a NOT_DONE_YET marker, this is a way for the AI
+            # to indicate that it wants to perform additional actions. This makes
+            # it easier to perform tasks consisting of multiple responses and
+            # actions. The marker is stripped from the reply so that it's hidden
+            # from the user. We also check for a number of common linguistic
+            # indicators that the AI isn't done yet, such as "I will now". This is
+            # necessary because the explicit marker isn't reliably sent.
+            if self.answer_model.supports_not_done_yet and \
+                    prompt.NOT_DONE_YET_MARKER in reply:
+                reply = reply.replace(prompt.NOT_DONE_YET_MARKER, '')
+                logger.info(f'[{state} state] not-done-yet marker received')
+                needs_feedback = True
+            else:
+                needs_feedback = False
         # If feedback is required, either because the tools require it or
         # because the AI sent a NOT_DONE_YET marker, go for another round.
         if needs_feedback and not self._rate_limit_exceeded():
             for reply, metadata in self._answer(state='feedback'):
                 yield reply, metadata
-
-    def _run_tools(self, reply: str) -> Tuple[str, str, bool]:
-        """Runs all tools on a reply. Returns the modified reply, a string
-        that concatenates all output (an empty string if no output was
-        produced) and a bool indicating whether the AI should in turn respond
-        to the produced output.
-        """
-        logger.info(f'running tools')
-        results = []
-        needs_reply = []
-        for tool in self.tools:
-            reply, tool_results, tool_needs_reply = tool.run(reply)
-            if tool_results:
-                results += tool_results
-            needs_reply.append(tool_needs_reply)
-        return reply, '\n\n'.join(results), any(needs_reply)
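Note: _run_tools(), which parsed JSON commands out of the reply text, is gone; tool calls now arrive as callables that return a (tool_message, result, needs_feedback) tuple. The tool classes themselves are not shown in this diff, so the following is only a hedged sketch of what a tool following the new convention might look like; the bind() method and the exec-based sandbox are illustrative assumptions, not the project's actual API:

import io
import contextlib


class execute_code:
    """Illustrative tool; the real class lives in heymans.tools."""

    def __init__(self, heymans):
        self._heymans = heymans

    def bind(self, args):
        """Wraps the model's function-call arguments in the zero-argument
        callable that _answer() receives as a reply."""
        def run():
            buffer = io.StringIO()
            with contextlib.redirect_stdout(buffer):
                exec(args['code'], {})  # toy sandbox; not production-safe
            tool_message = f"Executing code:\n{args['code']}"
            # needs_feedback=True so the AI gets to respond to the output
            return tool_message, buffer.getvalue(), True
        return run

# Example: prints ('Executing code:\nprint(1 + 1)', '2\n', True)
reply = execute_code(None).bind({'code': 'print(1 + 1)'})
print(reply())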
39 changes: 14 additions & 25 deletions heymans/messages.py
@@ -8,15 +8,13 @@
 from cryptography.fernet import InvalidToken
 from .model import model
 from . import prompt, config, utils, attachments
-from langchain.schema import HumanMessage, AIMessage, SystemMessage
+from langchain.schema import HumanMessage, AIMessage, SystemMessage, \
+    FunctionMessage
 logger = logging.getLogger('heymans')
 
 
 class Messages:
 
-    regex_transient = re.compile(r"<div class=['\"]+.*?transient.*?>.*?</div>",
-                                 re.DOTALL)
-
     def __init__(self, heymans, persistent=False):
         self._heymans = heymans
         self._persistent = persistent
@@ -85,15 +83,9 @@ def delete(self, message_id):
         if self._persistent:
             self.save()
 
-    def _message_is_transient(self, content):
-        return self.regex_transient.search(content)
-
     def prompt(self, system_prompt=None):
         """The prompt consists of the system prompt followed by a sequence of
-        AI and user messages. Transient messages are special messages that are
-        hidden except when they are the last message. This allows the AI to
-        feed some information to itself to respond to without confounding the
-        rest of the conversation.
+        AI, user, and tool/function messages.
 
         If no system prompt is provided, one is automatically constructed.
         Typically, an explicit system_prompt is provided during the search
@@ -105,16 +97,15 @@ def prompt(self, system_prompt=None):
         msg_len = len(self._condensed_message_history)
         for msg_nr, (role, content) in enumerate(
                 self._condensed_message_history):
-            # Messages may contain transient content, such as attachment text,
-            # which is removed if it is a few messages away in the history.
-            # This avoids the prompt becoming too large.
-            if msg_nr + config.keep_transient < msg_len:
-                if self._message_is_transient(content):
-                    content = '<!--THIS MESSAGE IS NO LONGER AVAILABLE-->'
             if role == 'assistant':
                 model_prompt.append(AIMessage(content=content))
             elif role == 'user':
                 model_prompt.append(HumanMessage(content=content))
+            elif role == 'tool':
+                if msg_nr + config.keep_tool_results < msg_len:
+                    continue
+                model_prompt.append(FunctionMessage(content=content,
+                                                    name='tool_function'))
             else:
                 raise ValueError(f'Invalid role: {role}')
         return model_prompt
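Note: this is where the keep_tool_results setting from config.py takes effect: a tool message survives only while it is among the last keep_tool_results entries of the condensed history; older tool results are silently skipped. A quick worked example of the rule:

keep_tool_results = 4
history = [('user', 'q1'), ('tool', 'r1'), ('assistant', 'a1'),
           ('user', 'q2'), ('tool', 'r2'), ('assistant', 'a2')]
msg_len = len(history)
kept = [(role, content)
        for msg_nr, (role, content) in enumerate(history)
        if role != 'tool' or msg_nr + keep_tool_results >= msg_len]
# ('tool', 'r1') is dropped because 1 + 4 < 6;
# ('tool', 'r2') is kept because 4 + 4 >= 6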
@@ -129,7 +120,7 @@ def _condense_message_history(self):
         messages = [{"role": "system", "content": system_prompt}]
         prompt_length = sum(len(content) for role, content
                             in self._condensed_message_history
-                            if not self._message_is_transient(content))
+                            if role != 'tool')
         logger.info(f'system prompt length: {len(system_prompt)}')
         logger.info(f'prompt length (without system prompt): {prompt_length}')
         if prompt_length <= config.max_prompt_length:
@@ -169,24 +160,22 @@ def _system_prompt(self):
         # For models that support this, there is also an instruction indicating
         # that a special marker can be sent to indicate that the response isn't
         # done yet. Not all models support this to avoid infinite loops.
-        if self._heymans.answer_model.supports_not_done_yet and \
-                self._heymans.tools:
+        if self._heymans.answer_model.supports_not_done_yet:
             system_prompt.append(prompt.SYSTEM_PROMPT_NOT_DONE_YET)
-        # Each tool has an explanation
-        for tool in self._heymans.tools:
-            if tool.prompt:
-                system_prompt.append(tool.prompt)
         # If available, documentation is also included in the prompt
         if len(self._heymans.documentation):
             system_prompt.append(self._heymans.documentation.prompt())
+        system_prompt.append(
+            attachments.attachments_prompt(self._heymans.database))
         # And finally, if the message history has been condensed, this is also
         # included.
         if self._condensed_text:
             logger.info('appending condensed text to system prompt')
             system_prompt.append(prompt.render(
                 prompt.SYSTEM_PROMPT_CONDENSED,
                 summary=self._condensed_text))
-        return '\n\n'.join(system_prompt)
+        # Combine all non-empty prompt chunks
+        return '\n\n'.join(chunk for chunk in system_prompt if chunk.strip())
 
     def _update_title(self):
         """The conversation title is updated when there are at least two