Skip to content

Commit

Permalink
refactor!: return every message in full_round; richer returns in func…
Browse files Browse the repository at this point in the history
…tion call methods
  • Loading branch information
zhudotexe committed Oct 4, 2023
1 parent 7a91dbf commit 940e098
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 50 deletions.
13 changes: 13 additions & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ AI Function
.. autoclass:: kani.AIParam
:members:

Internals
---------
.. autoclass:: kani.FunctionCallResult
:members:

.. autoclass:: kani.ExceptionHandleResult
:members:

Engines
-------
See :doc:`engine_reference`.
Expand All @@ -45,3 +53,8 @@ Utilities
.. autofunction:: kani.chat_in_terminal

.. autofunction:: kani.chat_in_terminal_async

Message Formatters
^^^^^^^^^^^^^^^^^^
.. automodule:: kani.utils.message_formatters
:members:
13 changes: 8 additions & 5 deletions docs/customization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ A requested function call can error out for a variety of reasons:
By default, kani will add a :class:`.ChatMessage` to the chat history, giving the model feedback
on what occurred. The model can then retry the call up to *retry_attempts* times.

:meth:`.Kani.handle_function_call_exception` controls this behaviour, adding the message and returning whether or not
:meth:`.Kani.handle_function_call_exception` controls this behaviour, returning the message to add and whether or not
the model should be allowed to retry. By overriding this method, you can control the error prompt, log the error, or
implement custom retry logic.

Expand All @@ -258,15 +258,18 @@ Here's an example of providing custom prompts on an exception:
`GitHub repo <https://github.com/zhudotexe/kani/blob/main/examples/3_customization_custom_exception_prompt.py>`__.

.. code-block:: python
:emphasize-lines: 2-7
:emphasize-lines: 2-10
class CustomExceptionPromptKani(Kani):
async def handle_function_call_exception(self, call, err, attempt):
self.chat_history.append(ChatMessage.system(
# get the standard retry logic...
result = await super().handle_function_call_exception(call, err, attempt)
# but override the returned message with our own
result.message = ChatMessage.system(
"The call encountered an error. "
f"Relay this error message to the user in a sarcastic manner: {err}"
))
return attempt < self.retry_attempts and err.retry
)
return result
@ai_function()
def get_time(self):
Expand Down
11 changes: 6 additions & 5 deletions examples/3_customization_exception_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@

class CustomExceptionPromptKani(Kani):
async def handle_function_call_exception(self, call, err, attempt):
self.chat_history.append(
ChatMessage.system(
f"The call encountered an error. Relay this error message to the user in a sarcastic manner: {err}"
)
# get the standard retry logic...
result = await super().handle_function_call_exception(call, err, attempt)
# but override the returned message with our own
result.message = ChatMessage.system(
f"The call encountered an error. Relay this error message to the user in a sarcastic manner: {err}"
)
return attempt < self.retry_attempts and err.retry
return result

@ai_function()
def get_time(self):
Expand Down
17 changes: 10 additions & 7 deletions examples/colab_examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -381,10 +381,13 @@
"\n",
"class CustomExceptionKani(Kani):\n",
" async def handle_function_call_exception(self, call, err, attempt):\n",
" self.chat_history.append(\n",
" ChatMessage.system(f\"The call encountered an error. Relay it to the user sarcastically: {err}\")\n",
" # get the standard retry logic...\n",
" result = await super().handle_function_call_exception(call, err, attempt)\n",
" # but override the returned message with our own\n",
" result.message = ChatMessage.system(\n",
" f\"The call encountered an error. Relay this error message to the user in a sarcastic manner: {err}\"\n",
" )\n",
" return attempt < self.retry_attempts\n",
" return result\n",
"\n",
" @ai_function()\n",
" def get_time(self):\n",
Expand Down Expand Up @@ -662,7 +665,7 @@
"import collections\n",
"import datetime\n",
"\n",
"from kani import Kani, chat_in_terminal, ai_function, ChatMessage\n",
"from kani import Kani, chat_in_terminal, ai_function, ChatMessage, ExceptionHandleResult\n",
"from kani.engines.openai import OpenAIEngine\n",
"\n",
"from kani.exceptions import FunctionCallException\n",
Expand All @@ -675,8 +678,8 @@
" self.failed_calls = collections.Counter()\n",
"\n",
" async def handle_function_call_exception(self, call, err, attempt):\n",
" self.chat_history.append(ChatMessage.system(str(err)))\n",
" return attempt < self.retry_attempts\n",
" msg = ChatMessage.system(str(err))\n",
" return ExceptionHandleResult(should_retry=attempt < self.retry_attempts, message=msg)\n",
"\n",
" async def do_function_call(self, call):\n",
" try:\n",
Expand Down Expand Up @@ -728,4 +731,4 @@
]
}
]
}
}
1 change: 1 addition & 0 deletions kani/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from . import engines, exceptions, utils
from .ai_function import AIFunction, AIParam, ai_function
from .internal import ExceptionHandleResult, FunctionCallResult
from .kani import Kani
from .models import ChatMessage, ChatRole, FunctionCall
from .utils.cli import chat_in_terminal, chat_in_terminal_async
26 changes: 26 additions & 0 deletions kani/internal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from .models import ChatMessage


class FunctionCallResult:
"""A model requested a function call, and the kani runtime resolved it."""

def __init__(self, is_model_turn: bool, message: ChatMessage):
"""
:param is_model_turn: True if the model should immediately react; False if the user speaks next.
:param message: The message containing the result of the function call, to add to the chat history.
"""
self.is_model_turn = is_model_turn
self.message = message


class ExceptionHandleResult:
"""A function call raised an exception, and the kani runtime has prompted the model with exception information."""

def __init__(self, should_retry: bool, message: ChatMessage):
"""
:param should_retry: Whether the model should be allowed to retry the call that caused this exception.
:param message: The message containing details about the exception and/or instructions to retry, to add to the
chat history.
"""
self.should_retry = should_retry
self.message = message
62 changes: 35 additions & 27 deletions kani/kani.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from .ai_function import AIFunction
from .engines.base import BaseCompletion, BaseEngine
from .exceptions import FunctionCallException, MessageTooLong, NoSuchFunction, WrappedCallException
from .internal import ExceptionHandleResult, FunctionCallResult
from .models import ChatMessage, ChatRole, FunctionCall
from .utils.message_formatters import assistant_message_contents
from .utils.typing import PathLike, SavedKani

log = logging.getLogger("kani")
Expand All @@ -22,11 +24,11 @@ class Kani:
``chat_round(query: str, **kwargs) -> ChatMessage``
``full_round(query: str, function_call_formatter: Callable[[ChatMessage], str], **kwargs) -> AsyncIterable[ChatMessage]``
``full_round(query: str, **kwargs) -> AsyncIterable[ChatMessage]``
``chat_round_str(query: str, **kwargs) -> str``
``full_round_str(query: str, function_call_formatter: Callable[[ChatMessage], str], **kwargs) -> AsyncIterable[str]``
``full_round_str(query: str, message_formatter: Callable[[ChatMessage], str], **kwargs) -> AsyncIterable[str]``
**Function Calling**
Expand Down Expand Up @@ -152,7 +154,8 @@ async def chat_round_str(self, query: str, **kwargs) -> str:
async def full_round(self, query: str, **kwargs) -> AsyncIterable[ChatMessage]:
"""Perform a full chat round (user -> model [-> function -> model -> ...] -> user).
Yields each of the model's ChatMessages. A ChatMessage must have at least one of (content, function_call).
Yields each non-user ChatMessage created during the round.
A ChatMessage will have at least one of (content, function_call).
Use this in an async for loop, like so::
Expand All @@ -179,12 +182,23 @@ async def full_round(self, query: str, **kwargs) -> AsyncIterable[ChatMessage]:
return

try:
is_model_turn = await self.do_function_call(message.function_call)
function_call_result = await self.do_function_call(message.function_call)
is_model_turn = function_call_result.is_model_turn
call_message = function_call_result.message
# save the result to the chat history
await self.add_to_history(call_message)
yield call_message
except FunctionCallException as e:
should_retry = await self.handle_function_call_exception(message.function_call, e, retry)
exception_handling_result = await self.handle_function_call_exception(
message.function_call, e, retry
)
# save the result to the chat history
exc_message = exception_handling_result.message
await self.add_to_history(exc_message)
yield exc_message
# retry if we have retry attempts left
retry += 1
if not should_retry:
if not exception_handling_result.should_retry:
# disable function calling on the next go
kwargs = {**kwargs, "include_functions": False}
continue
Expand All @@ -194,23 +208,20 @@ async def full_round(self, query: str, **kwargs) -> AsyncIterable[ChatMessage]:
async def full_round_str(
self,
query: str,
function_call_formatter: Callable[[ChatMessage], str | None] = lambda _: None,
message_formatter: Callable[[ChatMessage], str | None] = assistant_message_contents,
**kwargs,
) -> AsyncIterable[str]:
"""Like :meth:`full_round`, but each yielded element is a str rather than a ChatMessage.
:param query: The content of the user's chat message.
:param function_call_formatter: A function that returns a string to yield when the model decides to call a
function (or None to yield nothing). By default, ``full_round_str`` does not yield on a function call.
:param message_formatter: A function that returns a string to yield for each message. By default, `
`full_round_str`` yields the content of each assistant message.
:param kwargs: Additional arguments to pass to the model engine (e.g. hyperparameters).
"""
async for message in self.full_round(query, **kwargs):
if text := message.content:
if text := message_formatter(message):
yield text

if message.function_call and (fn_msg := function_call_formatter(message)):
yield fn_msg

# ==== helpers ====
@property
def always_len(self) -> int:
Expand Down Expand Up @@ -312,15 +323,16 @@ async def get_prompt(self) -> list[ChatMessage]:
return self.always_included_messages
return self.always_included_messages + self.chat_history[-to_keep:]

async def do_function_call(self, call: FunctionCall) -> bool:
async def do_function_call(self, call: FunctionCall) -> FunctionCallResult:
"""Resolve a single function call.
By default, any exception raised from this method will be an instance of a :class:`.FunctionCallException`.
You may implement an override to add instrumentation around function calls (e.g. tracking success counts
for varying prompts). See :ref:`do_function_call`.
:returns: True (default) if the model should immediately react; False if the user speaks next.
:returns: A :class:`.FunctionCallResult` including whose turn it is next and the message with the result of the
function call.
:raises NoSuchFunction: The requested function does not exist.
:raises WrappedCallException: The function raised an exception.
"""
Expand All @@ -347,17 +359,14 @@ async def do_function_call(self, call: FunctionCall) -> bool:
)
msg = self._auto_truncate_message(msg, max_len=f.auto_truncate)
log.debug(f"Auto truncate returned {self.message_token_len(msg)} tokens.")
# save the result to the chat history
await self.add_to_history(msg)
# yield whose turn it is
return f.after == ChatRole.ASSISTANT
return FunctionCallResult(is_model_turn=f.after == ChatRole.ASSISTANT, message=msg)

async def handle_function_call_exception(
self, call: FunctionCall, err: FunctionCallException, attempt: int
) -> bool:
) -> ExceptionHandleResult:
"""Called when a function call raises an exception.
By default, this adds a message to the chat telling the LM about the error and allows a retry if the error
By default, returns a message telling the LM about the error and allows a retry if the error
is recoverable and there are remaining retry attempts.
You may implement an override to customize the error prompt, log the error, or use custom retry logic.
Expand All @@ -367,21 +376,20 @@ async def handle_function_call_exception(
:param err: The error the call raised. Usually this is :class:`.NoSuchFunction` or
:class:`.WrappedCallException`, although it may be any exception raised by :meth:`do_function_call`.
:param attempt: The attempt number for the current call (0-indexed).
:returns: True if the model should retry the call; False if not.
:returns: A :class:`.ExceptionHandleResult` detailing whether the model should retry and the message to add to
the chat history.
"""
# log the exception here
log.debug(f"Call to {call.name} raised an exception: {err}")
# tell the model what went wrong
if isinstance(err, NoSuchFunction):
await self.add_to_history(
ChatMessage.system(f"The function {err.name!r} is not defined. Only use the provided functions.")
)
msg = ChatMessage.system(f"The function {err.name!r} is not defined. Only use the provided functions.")
else:
# but if it's a user function error, we want to raise it
log.error(f"Call to {call.name} raised an exception: {err}", exc_info=err)
await self.add_to_history(ChatMessage.function(call.name, str(err)))
msg = ChatMessage.function(call.name, str(err))

return attempt < self.retry_attempts and err.retry
return ExceptionHandleResult(should_retry=attempt < self.retry_attempts and err.retry, message=msg)

async def add_to_history(self, message: ChatMessage):
"""Add the given message to the chat history.
Expand Down
8 changes: 2 additions & 6 deletions kani/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
import os

from kani.kani import Kani
from kani.models import ChatMessage


def _function_formatter(message: ChatMessage):
return f"Thinking ({message.function_call.name})..."
from kani.utils.message_formatters import assistant_message_contents_thinking


async def chat_in_terminal_async(kani: Kani, rounds: int = 0, stopword: str = None):
Expand All @@ -25,7 +21,7 @@ async def chat_in_terminal_async(kani: Kani, rounds: int = 0, stopword: str = No
query = input("USER: ")
if stopword and query == stopword:
break
async for msg in kani.full_round_str(query, function_call_formatter=_function_formatter):
async for msg in kani.full_round_str(query, message_formatter=assistant_message_contents_thinking):
print(f"AI: {msg}")
except KeyboardInterrupt:
pass
Expand Down
28 changes: 28 additions & 0 deletions kani/utils/message_formatters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
A couple convenience formatters to customize :meth:`.Kani.full_round_str`.
You can pass any of these functions in with, e.g., ``Kani.full_round_str(..., message_formatter=all_message_contents)``.
"""
from kani.models import ChatMessage, ChatRole


def all_message_contents(msg: ChatMessage):
"""Return the content of any message."""
return msg.content


def assistant_message_contents(msg: ChatMessage):
"""Return the content of any assistant message; otherwise don't return anything."""
if msg.role == ChatRole.ASSISTANT:
return msg.content


def assistant_message_contents_thinking(msg: ChatMessage):
"""Return the content of any assistant message, and "Thinking..." on function calls."""
if msg.role == ChatRole.ASSISTANT:
content = msg.content
if msg.function_call and content:
return f"{content}\n Thinking ({msg.function_call.name})..."
elif msg.function_call:
return f"Thinking ({msg.function_call.name})..."
return content

0 comments on commit 940e098

Please sign in to comment.