From 940e0983d341c84cd9ac7a81ab751e077bce2e72 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Wed, 4 Oct 2023 16:17:13 -0400 Subject: [PATCH] refactor!: return every message in full_round; richer returns in function call methods --- docs/api_reference.rst | 13 ++++ docs/customization.rst | 13 ++-- examples/3_customization_exception_prompt.py | 11 ++-- examples/colab_examples.ipynb | 17 +++--- kani/__init__.py | 1 + kani/internal.py | 26 ++++++++ kani/kani.py | 62 +++++++++++--------- kani/utils/cli.py | 8 +-- kani/utils/message_formatters.py | 28 +++++++++ 9 files changed, 129 insertions(+), 50 deletions(-) create mode 100644 kani/internal.py create mode 100644 kani/utils/message_formatters.py diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 74f1238..254a600 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -36,6 +36,14 @@ AI Function .. autoclass:: kani.AIParam :members: +Internals +--------- +.. autoclass:: kani.FunctionCallResult + :members: + +.. autoclass:: kani.ExceptionHandleResult + :members: + Engines ------- See :doc:`engine_reference`. @@ -45,3 +53,8 @@ Utilities .. autofunction:: kani.chat_in_terminal .. autofunction:: kani.chat_in_terminal_async + +Message Formatters +^^^^^^^^^^^^^^^^^^ +.. automodule:: kani.utils.message_formatters + :members: diff --git a/docs/customization.rst b/docs/customization.rst index 4188647..8f6a7b9 100644 --- a/docs/customization.rst +++ b/docs/customization.rst @@ -237,7 +237,7 @@ A requested function call can error out for a variety of reasons: By default, kani will add a :class:`.ChatMessage` to the chat history, giving the model feedback on what occurred. The model can then retry the call up to *retry_attempts* times. -:meth:`.Kani.handle_function_call_exception` controls this behaviour, adding the message and returning whether or not +:meth:`.Kani.handle_function_call_exception` controls this behaviour, returning the message to add and whether or not the model should be allowed to retry. By overriding this method, you can control the error prompt, log the error, or implement custom retry logic. @@ -258,15 +258,18 @@ Here's an example of providing custom prompts on an exception: `GitHub repo `__. .. code-block:: python - :emphasize-lines: 2-7 + :emphasize-lines: 2-10 class CustomExceptionPromptKani(Kani): async def handle_function_call_exception(self, call, err, attempt): - self.chat_history.append(ChatMessage.system( + # get the standard retry logic... + result = await super().handle_function_call_exception(call, err, attempt) + # but override the returned message with our own + result.message = ChatMessage.system( "The call encountered an error. " f"Relay this error message to the user in a sarcastic manner: {err}" - )) - return attempt < self.retry_attempts and err.retry + ) + return result @ai_function() def get_time(self): diff --git a/examples/3_customization_exception_prompt.py b/examples/3_customization_exception_prompt.py index a898e76..c04588e 100644 --- a/examples/3_customization_exception_prompt.py +++ b/examples/3_customization_exception_prompt.py @@ -13,12 +13,13 @@ class CustomExceptionPromptKani(Kani): async def handle_function_call_exception(self, call, err, attempt): - self.chat_history.append( - ChatMessage.system( - f"The call encountered an error. Relay this error message to the user in a sarcastic manner: {err}" - ) + # get the standard retry logic... + result = await super().handle_function_call_exception(call, err, attempt) + # but override the returned message with our own + result.message = ChatMessage.system( + f"The call encountered an error. Relay this error message to the user in a sarcastic manner: {err}" ) - return attempt < self.retry_attempts and err.retry + return result @ai_function() def get_time(self): diff --git a/examples/colab_examples.ipynb b/examples/colab_examples.ipynb index 5bff7c3..acfc68f 100644 --- a/examples/colab_examples.ipynb +++ b/examples/colab_examples.ipynb @@ -381,10 +381,13 @@ "\n", "class CustomExceptionKani(Kani):\n", " async def handle_function_call_exception(self, call, err, attempt):\n", - " self.chat_history.append(\n", - " ChatMessage.system(f\"The call encountered an error. Relay it to the user sarcastically: {err}\")\n", + " # get the standard retry logic...\n", + " result = await super().handle_function_call_exception(call, err, attempt)\n", + " # but override the returned message with our own\n", + " result.message = ChatMessage.system(\n", + " f\"The call encountered an error. Relay this error message to the user in a sarcastic manner: {err}\"\n", " )\n", - " return attempt < self.retry_attempts\n", + " return result\n", "\n", " @ai_function()\n", " def get_time(self):\n", @@ -662,7 +665,7 @@ "import collections\n", "import datetime\n", "\n", - "from kani import Kani, chat_in_terminal, ai_function, ChatMessage\n", + "from kani import Kani, chat_in_terminal, ai_function, ChatMessage, ExceptionHandleResult\n", "from kani.engines.openai import OpenAIEngine\n", "\n", "from kani.exceptions import FunctionCallException\n", @@ -675,8 +678,8 @@ " self.failed_calls = collections.Counter()\n", "\n", " async def handle_function_call_exception(self, call, err, attempt):\n", - " self.chat_history.append(ChatMessage.system(str(err)))\n", - " return attempt < self.retry_attempts\n", + " msg = ChatMessage.system(str(err))\n", + " return ExceptionHandleResult(should_retry=attempt < self.retry_attempts, message=msg)\n", "\n", " async def do_function_call(self, call):\n", " try:\n", @@ -728,4 +731,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/kani/__init__.py b/kani/__init__.py index 49b8cea..f11149b 100644 --- a/kani/__init__.py +++ b/kani/__init__.py @@ -1,5 +1,6 @@ from . import engines, exceptions, utils from .ai_function import AIFunction, AIParam, ai_function +from .internal import ExceptionHandleResult, FunctionCallResult from .kani import Kani from .models import ChatMessage, ChatRole, FunctionCall from .utils.cli import chat_in_terminal, chat_in_terminal_async diff --git a/kani/internal.py b/kani/internal.py new file mode 100644 index 0000000..79500fe --- /dev/null +++ b/kani/internal.py @@ -0,0 +1,26 @@ +from .models import ChatMessage + + +class FunctionCallResult: + """A model requested a function call, and the kani runtime resolved it.""" + + def __init__(self, is_model_turn: bool, message: ChatMessage): + """ + :param is_model_turn: True if the model should immediately react; False if the user speaks next. + :param message: The message containing the result of the function call, to add to the chat history. + """ + self.is_model_turn = is_model_turn + self.message = message + + +class ExceptionHandleResult: + """A function call raised an exception, and the kani runtime has prompted the model with exception information.""" + + def __init__(self, should_retry: bool, message: ChatMessage): + """ + :param should_retry: Whether the model should be allowed to retry the call that caused this exception. + :param message: The message containing details about the exception and/or instructions to retry, to add to the + chat history. + """ + self.should_retry = should_retry + self.message = message diff --git a/kani/kani.py b/kani/kani.py index 0e2e846..75b76fa 100644 --- a/kani/kani.py +++ b/kani/kani.py @@ -8,7 +8,9 @@ from .ai_function import AIFunction from .engines.base import BaseCompletion, BaseEngine from .exceptions import FunctionCallException, MessageTooLong, NoSuchFunction, WrappedCallException +from .internal import ExceptionHandleResult, FunctionCallResult from .models import ChatMessage, ChatRole, FunctionCall +from .utils.message_formatters import assistant_message_contents from .utils.typing import PathLike, SavedKani log = logging.getLogger("kani") @@ -22,11 +24,11 @@ class Kani: ``chat_round(query: str, **kwargs) -> ChatMessage`` - ``full_round(query: str, function_call_formatter: Callable[[ChatMessage], str], **kwargs) -> AsyncIterable[ChatMessage]`` + ``full_round(query: str, **kwargs) -> AsyncIterable[ChatMessage]`` ``chat_round_str(query: str, **kwargs) -> str`` - ``full_round_str(query: str, function_call_formatter: Callable[[ChatMessage], str], **kwargs) -> AsyncIterable[str]`` + ``full_round_str(query: str, message_formatter: Callable[[ChatMessage], str], **kwargs) -> AsyncIterable[str]`` **Function Calling** @@ -152,7 +154,8 @@ async def chat_round_str(self, query: str, **kwargs) -> str: async def full_round(self, query: str, **kwargs) -> AsyncIterable[ChatMessage]: """Perform a full chat round (user -> model [-> function -> model -> ...] -> user). - Yields each of the model's ChatMessages. A ChatMessage must have at least one of (content, function_call). + Yields each non-user ChatMessage created during the round. + A ChatMessage will have at least one of (content, function_call). Use this in an async for loop, like so:: @@ -179,12 +182,23 @@ async def full_round(self, query: str, **kwargs) -> AsyncIterable[ChatMessage]: return try: - is_model_turn = await self.do_function_call(message.function_call) + function_call_result = await self.do_function_call(message.function_call) + is_model_turn = function_call_result.is_model_turn + call_message = function_call_result.message + # save the result to the chat history + await self.add_to_history(call_message) + yield call_message except FunctionCallException as e: - should_retry = await self.handle_function_call_exception(message.function_call, e, retry) + exception_handling_result = await self.handle_function_call_exception( + message.function_call, e, retry + ) + # save the result to the chat history + exc_message = exception_handling_result.message + await self.add_to_history(exc_message) + yield exc_message # retry if we have retry attempts left retry += 1 - if not should_retry: + if not exception_handling_result.should_retry: # disable function calling on the next go kwargs = {**kwargs, "include_functions": False} continue @@ -194,23 +208,20 @@ async def full_round(self, query: str, **kwargs) -> AsyncIterable[ChatMessage]: async def full_round_str( self, query: str, - function_call_formatter: Callable[[ChatMessage], str | None] = lambda _: None, + message_formatter: Callable[[ChatMessage], str | None] = assistant_message_contents, **kwargs, ) -> AsyncIterable[str]: """Like :meth:`full_round`, but each yielded element is a str rather than a ChatMessage. :param query: The content of the user's chat message. - :param function_call_formatter: A function that returns a string to yield when the model decides to call a - function (or None to yield nothing). By default, ``full_round_str`` does not yield on a function call. + :param message_formatter: A function that returns a string to yield for each message. By default, ` + `full_round_str`` yields the content of each assistant message. :param kwargs: Additional arguments to pass to the model engine (e.g. hyperparameters). """ async for message in self.full_round(query, **kwargs): - if text := message.content: + if text := message_formatter(message): yield text - if message.function_call and (fn_msg := function_call_formatter(message)): - yield fn_msg - # ==== helpers ==== @property def always_len(self) -> int: @@ -312,7 +323,7 @@ async def get_prompt(self) -> list[ChatMessage]: return self.always_included_messages return self.always_included_messages + self.chat_history[-to_keep:] - async def do_function_call(self, call: FunctionCall) -> bool: + async def do_function_call(self, call: FunctionCall) -> FunctionCallResult: """Resolve a single function call. By default, any exception raised from this method will be an instance of a :class:`.FunctionCallException`. @@ -320,7 +331,8 @@ async def do_function_call(self, call: FunctionCall) -> bool: You may implement an override to add instrumentation around function calls (e.g. tracking success counts for varying prompts). See :ref:`do_function_call`. - :returns: True (default) if the model should immediately react; False if the user speaks next. + :returns: A :class:`.FunctionCallResult` including whose turn it is next and the message with the result of the + function call. :raises NoSuchFunction: The requested function does not exist. :raises WrappedCallException: The function raised an exception. """ @@ -347,17 +359,14 @@ async def do_function_call(self, call: FunctionCall) -> bool: ) msg = self._auto_truncate_message(msg, max_len=f.auto_truncate) log.debug(f"Auto truncate returned {self.message_token_len(msg)} tokens.") - # save the result to the chat history - await self.add_to_history(msg) - # yield whose turn it is - return f.after == ChatRole.ASSISTANT + return FunctionCallResult(is_model_turn=f.after == ChatRole.ASSISTANT, message=msg) async def handle_function_call_exception( self, call: FunctionCall, err: FunctionCallException, attempt: int - ) -> bool: + ) -> ExceptionHandleResult: """Called when a function call raises an exception. - By default, this adds a message to the chat telling the LM about the error and allows a retry if the error + By default, returns a message telling the LM about the error and allows a retry if the error is recoverable and there are remaining retry attempts. You may implement an override to customize the error prompt, log the error, or use custom retry logic. @@ -367,21 +376,20 @@ async def handle_function_call_exception( :param err: The error the call raised. Usually this is :class:`.NoSuchFunction` or :class:`.WrappedCallException`, although it may be any exception raised by :meth:`do_function_call`. :param attempt: The attempt number for the current call (0-indexed). - :returns: True if the model should retry the call; False if not. + :returns: A :class:`.ExceptionHandleResult` detailing whether the model should retry and the message to add to + the chat history. """ # log the exception here log.debug(f"Call to {call.name} raised an exception: {err}") # tell the model what went wrong if isinstance(err, NoSuchFunction): - await self.add_to_history( - ChatMessage.system(f"The function {err.name!r} is not defined. Only use the provided functions.") - ) + msg = ChatMessage.system(f"The function {err.name!r} is not defined. Only use the provided functions.") else: # but if it's a user function error, we want to raise it log.error(f"Call to {call.name} raised an exception: {err}", exc_info=err) - await self.add_to_history(ChatMessage.function(call.name, str(err))) + msg = ChatMessage.function(call.name, str(err)) - return attempt < self.retry_attempts and err.retry + return ExceptionHandleResult(should_retry=attempt < self.retry_attempts and err.retry, message=msg) async def add_to_history(self, message: ChatMessage): """Add the given message to the chat history. diff --git a/kani/utils/cli.py b/kani/utils/cli.py index a0e3e10..07f18d9 100644 --- a/kani/utils/cli.py +++ b/kani/utils/cli.py @@ -4,11 +4,7 @@ import os from kani.kani import Kani -from kani.models import ChatMessage - - -def _function_formatter(message: ChatMessage): - return f"Thinking ({message.function_call.name})..." +from kani.utils.message_formatters import assistant_message_contents_thinking async def chat_in_terminal_async(kani: Kani, rounds: int = 0, stopword: str = None): @@ -25,7 +21,7 @@ async def chat_in_terminal_async(kani: Kani, rounds: int = 0, stopword: str = No query = input("USER: ") if stopword and query == stopword: break - async for msg in kani.full_round_str(query, function_call_formatter=_function_formatter): + async for msg in kani.full_round_str(query, message_formatter=assistant_message_contents_thinking): print(f"AI: {msg}") except KeyboardInterrupt: pass diff --git a/kani/utils/message_formatters.py b/kani/utils/message_formatters.py new file mode 100644 index 0000000..91ead92 --- /dev/null +++ b/kani/utils/message_formatters.py @@ -0,0 +1,28 @@ +""" +A couple convenience formatters to customize :meth:`.Kani.full_round_str`. + +You can pass any of these functions in with, e.g., ``Kani.full_round_str(..., message_formatter=all_message_contents)``. +""" +from kani.models import ChatMessage, ChatRole + + +def all_message_contents(msg: ChatMessage): + """Return the content of any message.""" + return msg.content + + +def assistant_message_contents(msg: ChatMessage): + """Return the content of any assistant message; otherwise don't return anything.""" + if msg.role == ChatRole.ASSISTANT: + return msg.content + + +def assistant_message_contents_thinking(msg: ChatMessage): + """Return the content of any assistant message, and "Thinking..." on function calls.""" + if msg.role == ChatRole.ASSISTANT: + content = msg.content + if msg.function_call and content: + return f"{content}\n Thinking ({msg.function_call.name})..." + elif msg.function_call: + return f"Thinking ({msg.function_call.name})..." + return content