diff --git a/docs/api_reference.rst b/docs/api_reference.rst
index 5456ad1..62c26cb 100644
--- a/docs/api_reference.rst
+++ b/docs/api_reference.rst
@@ -16,6 +16,11 @@ Common Models
     :exclude-members: model_config, model_fields
     :class-doc-from: class
 
+.. autoclass:: kani.ToolCall
+    :members:
+    :exclude-members: model_config, model_fields
+    :class-doc-from: class
+
 .. autoclass:: kani.MessagePart
     :members:
     :exclude-members: model_config, model_fields
diff --git a/docs/customization/chat_history.rst b/docs/customization/chat_history.rst
index d5ac244..25ba2aa 100644
--- a/docs/customization/chat_history.rst
+++ b/docs/customization/chat_history.rst
@@ -31,8 +31,8 @@ For example, here's how you might extend :meth:`.Kani.add_to_history` to log eve
             super().__init__(*args, **kwargs)
             self.log_file = open("kani-log.jsonl", "w")
 
-        async def add_to_history(self, message):
-            await super().add_to_history(message)
+        async def add_to_history(self, message, *args, **kwargs):
+            await super().add_to_history(message, *args, **kwargs)
             self.log_file.write(message.model_dump_json())
             self.log_file.write("\n")
diff --git a/docs/customization/function_call.rst b/docs/customization/function_call.rst
index 955dba1..35aefd2 100644
--- a/docs/customization/function_call.rst
+++ b/docs/customization/function_call.rst
@@ -39,9 +39,9 @@ during a conversation, and how often it was successful:
             self.successful_calls = collections.Counter()
             self.failed_calls = collections.Counter()
 
-        async def do_function_call(self, call):
+        async def do_function_call(self, call, *args, **kwargs):
             try:
-                result = await super().do_function_call(call)
+                result = await super().do_function_call(call, *args, **kwargs)
                 self.successful_calls[call.name] += 1
                 return result
             except FunctionCallException:
diff --git a/docs/customization/function_exception.rst b/docs/customization/function_exception.rst
index d8f3099..9536063 100644
--- a/docs/customization/function_exception.rst
+++ b/docs/customization/function_exception.rst
@@ -41,9 +41,9 @@ Here's an example of providing custom prompts on an exception:
     :emphasize-lines: 2-10
 
     class CustomExceptionPromptKani(Kani):
-        async def handle_function_call_exception(self, call, err, attempt):
+        async def handle_function_call_exception(self, call, err, attempt, *args, **kwargs):
             # get the standard retry logic...
-            result = await super().handle_function_call_exception(call, err, attempt)
+            result = await super().handle_function_call_exception(call, err, attempt, *args, **kwargs)
             # but override the returned message with our own
             result.message = ChatMessage.system(
                 "The call encountered an error. "
diff --git a/docs/engines/implementing.rst b/docs/engines/implementing.rst
index 470bcdd..d6a8670 100644
--- a/docs/engines/implementing.rst
+++ b/docs/engines/implementing.rst
@@ -48,6 +48,10 @@ You'll need to implement two methods: :meth:`.BaseEngine.predict` and :meth:`.Ba
 build such a prompt. :meth:`.BaseEngine.function_token_reserve` tells kani how many tokens that prompt takes, so the
 context window management can ensure it never sends too many tokens.
 
+You'll also need to add previous function calls into the prompt (e.g. for few-shot function calling). When building
+the prompt, iterate over :attr:`.ChatMessage.tool_calls` if it is set, and add each requested call in your model's
+function calling format.
+
 To parse the model's requests to call a function, you also do this in :meth:`.BaseEngine.predict`. After generating
 the model's completion (usually a string, or a list of token IDs that decodes into a string), separate the model's
 conversational content from the structured function call:
@@ -56,4 +60,7 @@ conversational content from the structured function call:
     :align: center
 
 Finally, return a :class:`.Completion` with the ``.message`` attribute set to a :class:`.ChatMessage` with the
-appropriate :attr:`.ChatMessage.content` and :attr:`.ChatMessage.function_call`.
+appropriate :attr:`.ChatMessage.content` and :attr:`.ChatMessage.tool_calls`.
+
+.. note::
+    See :ref:`functioncall_v_toolcall` for more information about ToolCalls vs FunctionCalls.
diff --git a/docs/function_calling.rst b/docs/function_calling.rst
index be5a847..94404f3 100644
--- a/docs/function_calling.rst
+++ b/docs/function_calling.rst
@@ -137,6 +137,10 @@ Next Actor
 After a function call returns, kani will hand control back to the LM to generate a response by default. If instead
 control should be given to the human (i.e. return from the chat round), set ``after=ChatRole.USER``.
 
+.. note::
+    If the model calls multiple tools in parallel, the model will be allowed to generate a response if *any* function
+    has ``after=ChatRole.ASSISTANT`` (the default) once all function calls are complete.
+
 Complete Example
 ----------------
 Here's the full example of how you might implement a function to get weather that we built in the last few steps:
@@ -182,38 +186,77 @@ prompt a model, we can mock these returns in the chat history using :meth:`.Chat
 For example, here's how you might prompt the model to give the temperature in both Fahrenheit and Celsius without the
 user having to ask:
 
-.. code-block:: python
-
-    from kani import ChatMessage, FunctionCall
-    fewshot = [
-        ChatMessage.user("What's the weather in Philadelphia?"),
-        # first, the model should ask for the weather in fahrenheit
-        ChatMessage.assistant(
-            content=None,
-            function_call=FunctionCall.with_args(
-                "get_weather", location="Philadelphia, PA", unit="fahrenheit"
-            )
-        ),
-        # and we mock the function's response to the model
-        ChatMessage.function(
-            "get_weather",
-            "Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit.",
-        ),
-        # repeat in celsius
-        ChatMessage.assistant(
-            content=None,
-            function_call=FunctionCall.with_args(
-                "get_weather", location="Philadelphia, PA", unit="celsius"
-            )
-        ),
-        ChatMessage.function(
-            "get_weather",
-            "Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius.",
-        ),
-        # finally, give the result to the user
-        ChatMessage.assistant("It's currently 85F (29C) and partly cloudy in Philadelphia."),
-    ]
-    ai = MyKani(engine, chat_history=fewshot)
+.. tab:: ToolCall API
+
+    .. code-block:: python
+
+        from kani import ChatMessage, ToolCall
+
+        # build the chat history with examples
+        fewshot = [
+            ChatMessage.user("What's the weather in Philadelphia?"),
+            ChatMessage.assistant(
+                content=None,
+                # use a walrus operator to save a reference to the tool call here...
+                tool_calls=[
+                    tc := ToolCall.from_function("get_weather", location="Philadelphia, PA", unit="fahrenheit")
+                ],
+            ),
+            ChatMessage.function(
+                "get_weather",
+                "Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit.",
+                # ...so this function result knows which call it's responding to
+                tc.id,
+            ),
+            # and repeat for the other unit
+            ChatMessage.assistant(
+                content=None,
+                tool_calls=[
+                    tc2 := ToolCall.from_function("get_weather", location="Philadelphia, PA", unit="celsius")
+                ],
+            ),
+            ChatMessage.function(
+                "get_weather",
+                "Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius.",
+                tc2.id,
+            ),
+            ChatMessage.assistant("It's currently 85F (29C) and partly cloudy in Philadelphia."),
+        ]
+        # and give it to the kani when you initialize it
+        ai = MyKani(engine, chat_history=fewshot)
+
+.. tab:: FunctionCall API (deprecated)
+
+    .. code-block:: python
+
+        from kani import ChatMessage, FunctionCall
+        fewshot = [
+            ChatMessage.user("What's the weather in Philadelphia?"),
+            # first, the model should ask for the weather in fahrenheit
+            ChatMessage.assistant(
+                content=None,
+                function_call=FunctionCall.with_args(
+                    "get_weather", location="Philadelphia, PA", unit="fahrenheit"
+                )
+            ),
+            # and we mock the function's response to the model
+            ChatMessage.function(
+                "get_weather",
+                "Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit.",
+            ),
+            # repeat in celsius
+            ChatMessage.assistant(
+                content=None,
+                function_call=FunctionCall.with_args(
+                    "get_weather", location="Philadelphia, PA", unit="celsius"
+                )
+            ),
+            ChatMessage.function(
+                "get_weather",
+                "Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius.",
+            ),
+            # finally, give the result to the user
+            ChatMessage.assistant("It's currently 85F (29C) and partly cloudy in Philadelphia."),
+        ]
+        ai = MyKani(engine, chat_history=fewshot)
 
 .. code-block:: pycon
@@ -254,4 +297,26 @@ passing params with invalid, non-coercible types) or the function raises an exce
 error in a message to the model by default, allowing it up to *retry_attempts* to correct itself and retry the call.
 
+.. note::
+    If the model calls multiple tools in parallel, the model will be allowed a retry if *any* exception handler
+    allows it. This will only count as 1 retry attempt regardless of the number of functions that raised an exception.
+
 In the next section, we'll discuss how to customize this behaviour, along with other parts of the kani interface.
+
+.. _functioncall_v_toolcall:
+
+Internal Representation
+-----------------------
+
+.. versionchanged:: v0.6.0
+
+As of Nov 6, 2023, OpenAI added the ability for a single assistant message to request calling multiple functions in
+parallel, and wrapped all function calls in a :class:`.ToolCall` wrapper. In order to support this in kani while
+maintaining backwards compatibility with OSS function calling models, a :class:`.ChatMessage` maintains the following
+internal representation:
+
+:attr:`.ChatMessage.function_call` is an alias for ``ChatMessage.tool_calls[0].function``. If there is more than one
+tool call in the message, accessing it will log a warning and return the first tool call's function; iterate over
+:attr:`.ChatMessage.tool_calls` instead.
+
+A ToolCall is effectively a named wrapper around a :class:`.FunctionCall`, associating the request with a generated
+ID so that its response can be linked to the request in future rounds of prompting.
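The translation described in `docs/function_calling.rst` above can be exercised directly with the models this patch adds. A standalone sketch (no engine required; `get_weather` is just the running example name from the docs):

```python
# Sketch: how the deprecated function_call kwarg maps onto tool_calls.
from kani import ChatMessage, FunctionCall, ToolCall

# passing function_call wraps it in a single ToolCall with a generated ID...
msg = ChatMessage.assistant(
    content=None,
    function_call=FunctionCall.with_args("get_weather", location="Philadelphia, PA", unit="fahrenheit"),
)
assert len(msg.tool_calls) == 1
print(msg.tool_calls[0].id)  # a random UUID generated by ToolCall.from_function_call

# ...and the function_call property reads back through the alias
assert msg.function_call.name == "get_weather"

# with parallel requests, iterate over tool_calls instead of using the alias
parallel = ChatMessage.assistant(
    content=None,
    tool_calls=[
        ToolCall.from_function("get_weather", location="Philadelphia, PA", unit="fahrenheit"),
        ToolCall.from_function("get_weather", location="Philadelphia, PA", unit="celsius"),
    ],
)
for tc in parallel.tool_calls:
    print(tc.id, tc.function.name, tc.function.arguments)
```

Note that `FunctionCall.arguments` is a JSON string, so existing single-call code keeps working unchanged through the alias.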
diff --git a/examples/2_function_calling_fewshot.py b/examples/2_function_calling_fewshot.py
index d1f5b29..981b00f 100644
--- a/examples/2_function_calling_fewshot.py
+++ b/examples/2_function_calling_fewshot.py
@@ -2,7 +2,7 @@
 import os
 from typing import Annotated
 
-from kani import AIParam, ChatMessage, FunctionCall, Kani, ai_function, chat_in_terminal
+from kani import AIParam, ChatMessage, Kani, ToolCall, ai_function, chat_in_terminal
 from kani.engines.openai import OpenAIEngine
 
 api_key = os.getenv("OPENAI_API_KEY")
@@ -35,14 +35,16 @@ def get_weather(
     ChatMessage.user("What's the weather in Philadelphia?"),
     ChatMessage.assistant(
         content=None,
-        function_call=FunctionCall.with_args("get_weather", location="Philadelphia, PA", unit="fahrenheit"),
+        # use a walrus operator to save a reference to the tool call here...
+        tool_calls=[tc := ToolCall.from_function("get_weather", location="Philadelphia, PA", unit="fahrenheit")],
     ),
-    ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit."),
+    # so this function result knows which call it's responding to
+    ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit.", tc.id),
     ChatMessage.assistant(
         content=None,
-        function_call=FunctionCall.with_args("get_weather", location="Philadelphia, PA", unit="celsius"),
+        tool_calls=[tc2 := ToolCall.from_function("get_weather", location="Philadelphia, PA", unit="celsius")],
     ),
-    ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius."),
+    ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius.", tc2.id),
     ChatMessage.assistant("It's currently 85F (29C) and partly cloudy in Philadelphia."),
 ]
 # and give it to the kani when you initialize it
diff --git a/examples/3_customization_exception_prompt.py b/examples/3_customization_exception_prompt.py
index 3d93212..bf330a6 100644
--- a/examples/3_customization_exception_prompt.py
+++ b/examples/3_customization_exception_prompt.py
@@ -13,9 +13,9 @@
 
 
 class CustomExceptionPromptKani(Kani):
-    async def handle_function_call_exception(self, call, err, attempt):
+    async def handle_function_call_exception(self, call, err, attempt, *args, **kwargs):
         # get the standard retry logic...
-        result = await super().handle_function_call_exception(call, err, attempt)
+        result = await super().handle_function_call_exception(call, err, attempt, *args, **kwargs)
         # but override the returned message with our own
         result.message = ChatMessage.system(
             f"The call encountered an error. Relay this error message to the user in a sarcastic manner: {err}"
diff --git a/examples/3_customization_track_function_calls.py b/examples/3_customization_track_function_calls.py
index 1d05e66..72a0e55 100644
--- a/examples/3_customization_track_function_calls.py
+++ b/examples/3_customization_track_function_calls.py
@@ -22,9 +22,9 @@ def __init__(self, *args, **kwargs):
         self.successful_calls = collections.Counter()
         self.failed_calls = collections.Counter()
 
-    async def do_function_call(self, call):
+    async def do_function_call(self, call, *args, **kwargs):
         try:
-            result = await super().do_function_call(call)
+            result = await super().do_function_call(call, *args, **kwargs)
             self.successful_calls[call.name] += 1
             return result
         except FunctionCallException:
diff --git a/examples/colab_examples.ipynb b/examples/colab_examples.ipynb
index acfc68f..ac170c4 100644
--- a/examples/colab_examples.ipynb
+++ b/examples/colab_examples.ipynb
@@ -380,9 +380,9 @@
     "\n",
     "\n",
     "class CustomExceptionKani(Kani):\n",
-    "    async def handle_function_call_exception(self, call, err, attempt):\n",
+    "    async def handle_function_call_exception(self, call, err, attempt, *args, **kwargs):\n",
     "        # get the standard retry logic...\n",
-    "        result = await super().handle_function_call_exception(call, err, attempt)\n",
+    "        result = await super().handle_function_call_exception(call, err, attempt, *args, **kwargs)\n",
     "        # but override the returned message with our own\n",
     "        result.message = ChatMessage.system(\n",
     "            f\"The call encountered an error. Relay this error message to the user in a sarcastic manner: {err}\"\n",
@@ -677,13 +677,13 @@
     "        self.successful_calls = collections.Counter()\n",
     "        self.failed_calls = collections.Counter()\n",
     "\n",
-    "    async def handle_function_call_exception(self, call, err, attempt):\n",
-    "        msg = ChatMessage.system(str(err))\n",
+    "    async def handle_function_call_exception(self, call, err, attempt, tool_call_id=None):\n",
+    "        msg = ChatMessage.function(name=call.name, content=str(err), tool_call_id=tool_call_id)\n",
     "        return ExceptionHandleResult(should_retry=attempt < self.retry_attempts, message=msg)\n",
     "\n",
-    "    async def do_function_call(self, call):\n",
+    "    async def do_function_call(self, call, *args, **kwargs):\n",
     "        try:\n",
-    "            res = await super().do_function_call(call)\n",
+    "            res = await super().do_function_call(call, *args, **kwargs)\n",
     "            self.successful_calls[call.name] += 1\n",
     "            return res\n",
     "        except FunctionCallException:\n",
diff --git a/kani/__init__.py b/kani/__init__.py
index 20b748f..ebfbcb9 100644
--- a/kani/__init__.py
+++ b/kani/__init__.py
@@ -2,7 +2,7 @@
 from .ai_function import AIFunction, AIParam, ai_function
 from .internal import ExceptionHandleResult, FunctionCallResult
 from .kani import Kani
-from .models import ChatMessage, ChatRole, FunctionCall, MessagePart
+from .models import ChatMessage, ChatRole, FunctionCall, MessagePart, ToolCall
 from .utils.cli import chat_in_terminal, chat_in_terminal_async
 
 # declare that kani is also a namespace package
diff --git a/kani/engines/openai/client.py b/kani/engines/openai/client.py
index fe6e8cb..8b5d8ee 100644
--- a/kani/engines/openai/client.py
+++ b/kani/engines/openai/client.py
@@ -1,10 +1,20 @@
 import asyncio
+import warnings
 from typing import Literal, overload
 
 import aiohttp
 import pydantic
 
-from .models import ChatCompletion, Completion, FunctionSpec, OpenAIChatMessage, SpecificFunctionCall
+from .models import (
+    ChatCompletion,
+    Completion,
+    FunctionSpec,
+    OpenAIChatMessage,
+    ResponseFormat,
+    SpecificFunctionCall,
+    ToolChoice,
+    ToolSpec,
+)
 from ..httpclient import BaseClient, HTTPException, HTTPStatusException, HTTPTimeout
@@ -77,6 +87,7 @@ async def request(self, method: str, route: str, headers=None, retry=None, **kwa
     async def create_completion(
         self,
         model: str,
+        *,
         prompt: str = "<|endoftext|>",
         suffix: str = None,
         max_tokens: int = 16,
@@ -85,6 +96,7 @@ async def create_completion(
         n: int = 1,
         logprobs: int = None,
         echo: bool = False,
+        seed: int | None = None,
         stop: str | list[str] = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
@@ -107,24 +119,31 @@ async def create_chat_completion(
         self,
         model: str,
         messages: list[OpenAIChatMessage],
-        functions: list[FunctionSpec] | None = None,
-        function_call: SpecificFunctionCall | Literal["auto"] | Literal["none"] | None = None,
+        *,
+        tools: list[ToolSpec] | None = None,
+        tool_choice: ToolChoice | Literal["auto"] | Literal["none"] | None = None,
         temperature: float = 1.0,
         top_p: float = 1.0,
         n: int = 1,
+        response_format: ResponseFormat | None = None,
+        seed: int | None = None,
         stop: str | list[str] | None = None,
         max_tokens: int | None = None,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
         logit_bias: dict | None = None,
         user: str | None = None,
+        # deprecated
+        functions: list[FunctionSpec] | None = None,
+        function_call: SpecificFunctionCall | Literal["auto"] | Literal["none"] | None = None,
     ) -> ChatCompletion:
         ...
 
     async def create_chat_completion(
         self,
         model: str,
         messages: list[OpenAIChatMessage],
-        functions: list[FunctionSpec] | None = None,
+        *,
+        tools: list[ToolSpec] | None = None,
         **kwargs,
     ) -> ChatCompletion:
         """Create a chat completion.
@@ -132,9 +151,16 @@ async def create_chat_completion(
         See https://platform.openai.com/docs/api-reference/chat/create.
         """
         # transform pydantic models
-        if functions:
-            kwargs["functions"] = [f.model_dump(exclude_unset=True) for f in functions]
+        if tools:
+            kwargs["tools"] = [t.model_dump(exclude_unset=True) for t in tools]
+        if "tool_choice" in kwargs and isinstance(kwargs["tool_choice"], ToolChoice):
+            kwargs["tool_choice"] = kwargs["tool_choice"].model_dump(exclude_unset=True)
+        # deprecated function calling
+        if "functions" in kwargs:
+            warnings.warn("The functions parameter is deprecated. Use tools instead.", DeprecationWarning)
+            kwargs["functions"] = [f.model_dump(exclude_unset=True) for f in kwargs["functions"]]
         if "function_call" in kwargs and isinstance(kwargs["function_call"], SpecificFunctionCall):
+            warnings.warn("The function_call parameter is deprecated. Use tool_choice instead.", DeprecationWarning)
             kwargs["function_call"] = kwargs["function_call"].model_dump(exclude_unset=True)
         # call API
         data = await self.post(
diff --git a/kani/engines/openai/engine.py b/kani/engines/openai/engine.py
index 9e068df..933604e 100644
--- a/kani/engines/openai/engine.py
+++ b/kani/engines/openai/engine.py
@@ -2,11 +2,11 @@
 import os
 
 from kani.ai_function import AIFunction
-from kani.exceptions import MissingModelDependencies
-from kani.models import ChatMessage
+from kani.exceptions import MissingModelDependencies, ToolCallError
+from kani.models import ChatMessage, ChatRole
 from . import function_calling
 from .client import OpenAIClient
-from .models import ChatCompletion, FunctionSpec, OpenAIChatMessage
+from .models import ChatCompletion, FunctionSpec, OpenAIChatMessage, ToolSpec
 from ..base import BaseEngine
 
 try:
@@ -114,12 +114,48 @@ async def predict(
         self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
     ) -> ChatCompletion:
         if functions:
-            function_spec = [FunctionSpec(name=f.name, description=f.desc, parameters=f.json_schema) for f in functions]
+            tool_specs = [
+                ToolSpec.from_function(FunctionSpec(name=f.name, description=f.desc, parameters=f.json_schema))
+                for f in functions
+            ]
         else:
-            function_spec = None
-        translated_messages = [OpenAIChatMessage.from_chatmessage(m) for m in messages]
+            tool_specs = None
+        # translate to openai spec - group any tool messages together and ensure all free ToolCall IDs are bound
+        translated_messages = []
+        free_toolcall_ids = set()
+        for m in messages:
+            # if this is not a function result and there are free tool call IDs, raise
+            if m.role != ChatRole.FUNCTION and free_toolcall_ids:
+                raise ToolCallError(
+                    f"Encountered a {m.role.value!r} message but expected a FUNCTION message to satisfy the pending"
+                    f" tool call(s): {free_toolcall_ids}"
+                )
+            # asst: add tool call IDs to freevars
+            if m.role == ChatRole.ASSISTANT and m.tool_calls:
+                for tc in m.tool_calls:
+                    free_toolcall_ids.add(tc.id)
+            # func: bind freevars
+            elif m.role == ChatRole.FUNCTION:
+                # has ID: bind it
+                if m.tool_call_id is not None and m.tool_call_id in free_toolcall_ids:
+                    free_toolcall_ids.remove(m.tool_call_id)
+                # no ID: bind if unambiguous, otherwise cry
+                elif m.tool_call_id is None:
+                    if len(free_toolcall_ids) == 1:
+                        m = m.copy_with(tool_call_id=free_toolcall_ids.pop())
+                    elif len(free_toolcall_ids) > 1:
+                        raise ToolCallError(
+                            "Got a FUNCTION message with no tool_call_id but multiple tool calls are pending"
+                            f" ({free_toolcall_ids})! Set the tool_call_id to resolve the pending tool requests."
+                        )
+            translated_messages.append(OpenAIChatMessage.from_chatmessage(m))
+        # if the translated messages start with a hanging TOOL call, strip it (openai limitation)
+        # though hanging FUNCTION messages are OK
+        while translated_messages and translated_messages[0].role == "tool":
+            translated_messages.pop(0)
+        # make API call
         completion = await self.client.create_chat_completion(
-            model=self.model, messages=translated_messages, functions=function_spec, **self.hyperparams, **hyperparams
+            model=self.model, messages=translated_messages, tools=tool_specs, **self.hyperparams, **hyperparams
         )
         return completion
diff --git a/kani/engines/openai/models.py b/kani/engines/openai/models.py
index e8dacf7..69fa5af 100644
--- a/kani/engines/openai/models.py
+++ b/kani/engines/openai/models.py
@@ -1,6 +1,6 @@
 from typing import Literal
 
-from kani.models import BaseModel, ChatMessage, ChatRole, FunctionCall
+from kani.models import BaseModel, ChatMessage, ChatRole, FunctionCall, ToolCall
 from ..base import BaseCompletion
 
 
@@ -30,6 +30,7 @@ class Completion(BaseModel):
     object: Literal["text_completion"]
     created: int
     model: str
+    system_fingerprint: str | None = None
    choices: list[CompletionChoice]
     usage: CompletionUsage
 
@@ -46,25 +47,79 @@ class FunctionSpec(BaseModel):
     parameters: dict
 
 
+class ToolSpec(BaseModel):
+    type: str
+    function: FunctionSpec
+
+    @classmethod
+    def from_function(cls, spec: FunctionSpec):
+        return cls(type="function", function=spec)
+
+
 class SpecificFunctionCall(BaseModel):
     name: str
 
 
+class ToolChoice(BaseModel):
+    type: str
+    function: SpecificFunctionCall
+
+    @classmethod
+    def from_function(cls, name: str):
+        return cls(type="function", function=SpecificFunctionCall(name=name))
+
+
+class ResponseFormat(BaseModel):
+    type: str
+
+    @classmethod
+    def text(cls):
+        return cls(type="text")
+
+    @classmethod
+    def json_object(cls):
+        return cls(type="json_object")
+
+
 class OpenAIChatMessage(BaseModel):
-    role: ChatRole
+    role: str
     content: str | list[BaseModel | str] | None
     name: str | None = None
+    tool_call_id: str | None = None
+    tool_calls: list[ToolCall] | None = None
+    # deprecated
     function_call: FunctionCall | None = None
 
     @classmethod
     def from_chatmessage(cls, m: ChatMessage):
-        return cls(role=m.role, content=m.text, name=m.name, function_call=m.function_call)
+        # translate tool responses to a function to the right openai format
+        if m.role == ChatRole.FUNCTION:
+            if m.tool_call_id is not None:
+                return cls(role="tool", content=m.text, name=m.name, tool_call_id=m.tool_call_id)
+            return cls(role=m.role.value, content=m.text, name=m.name)
+        return cls(role=m.role.value, content=m.text, name=m.name, tool_call_id=m.tool_call_id, tool_calls=m.tool_calls)
+
+    def to_chatmessage(self) -> ChatMessage:
+        # translate tool role to function role
+        if self.role == "tool":
+            role = ChatRole.FUNCTION
+        else:
+            role = ChatRole(self.role)
+        # translate FunctionCall to singular ToolCall
+        if self.tool_calls:
+            tool_calls = self.tool_calls
+        elif self.function_call:
+            tool_calls = [ToolCall.from_function_call(self.function_call)]
+        else:
+            tool_calls = None
+        return ChatMessage(
+            role=role, content=self.content, name=self.name, tool_call_id=self.tool_call_id, tool_calls=tool_calls
+        )
 
 
 # ---- response ----
 class ChatCompletionChoice(BaseModel):
-    # this is a ChatMessage rather than an OpenAIChatMessage because all engines need to return the kani model
-    message: ChatMessage
+    message: OpenAIChatMessage
     index: int
     finish_reason: str | None = None
 
@@ -74,12 +129,13 @@ class ChatCompletion(BaseCompletion, BaseModel):
     object: Literal["chat.completion"]
     created: int
     model: str
+    system_fingerprint: str | None = None
     usage: CompletionUsage
     choices: list[ChatCompletionChoice]
 
     @property
-    def message(self):
-        return self.choices[0].message
+    def message(self) -> ChatMessage:
+        return self.choices[0].message.to_chatmessage()
 
     @property
     def prompt_tokens(self):
diff --git a/kani/exceptions.py b/kani/exceptions.py
index fffc6a3..23531d9 100644
--- a/kani/exceptions.py
+++ b/kani/exceptions.py
@@ -63,6 +63,10 @@ class MissingModelDependencies(KaniException):
     """You are trying to use an engine but do not have engine-specific packages installed."""
 
 
+class ToolCallError(KaniException):
+    """Something went wrong with tool calls."""
+
+
 # ==== serdes ====
 class MissingMessagePartType(KaniException):
     """During loading a saved kani, a message part has a type which is not currently defined in the runtime."""
diff --git a/kani/internal.py b/kani/internal.py
index 79500fe..6d2faf6 100644
--- a/kani/internal.py
+++ b/kani/internal.py
@@ -1,7 +1,13 @@
+import abc
+
 from .models import ChatMessage
 
 
-class FunctionCallResult:
+class HasMessage(abc.ABC):
+    message: ChatMessage
+
+
+class FunctionCallResult(HasMessage):
     """A model requested a function call, and the kani runtime resolved it."""
 
     def __init__(self, is_model_turn: bool, message: ChatMessage):
@@ -13,7 +19,7 @@ def __init__(self, is_model_turn: bool, message: ChatMessage):
         self.message = message
 
 
-class ExceptionHandleResult:
+class ExceptionHandleResult(HasMessage):
     """A function call raised an exception, and the kani runtime has prompted the model with exception information."""
 
     def __init__(self, should_retry: bool, message: ChatMessage):
diff --git a/kani/kani.py b/kani/kani.py
index 55269ab..b6fe03a 100644
--- a/kani/kani.py
+++ b/kani/kani.py
@@ -9,7 +9,7 @@
 from .engines.base import BaseCompletion, BaseEngine
 from .exceptions import FunctionCallException, MessageTooLong, NoSuchFunction, WrappedCallException
 from .internal import ExceptionHandleResult, FunctionCallResult
-from .models import ChatMessage, ChatRole, FunctionCall, QueryType
+from .models import ChatMessage, ChatRole, FunctionCall, QueryType, ToolCall
 from .utils.message_formatters import assistant_message_contents
 from .utils.typing import PathLike, SavedKani
 
@@ -80,7 +80,7 @@ def __init__(
             Use ``chat_history=mykani.chat_history.copy()`` to pass a copy.
         :param functions: A list of :class:`.AIFunction` to expose to the model (for dynamic function calling). Use
             :func:`.ai_function` to define static functions (see :doc:`function_calling`).
-        :param retry_attempts: How many attempts the LM may take if a function call raises an exception.
+        :param retry_attempts: How many attempts the LM may take per full round if any tool call raises an exception.
         """
         self.engine = engine
         self.system_prompt = system_prompt.strip() if system_prompt else None
@@ -178,30 +178,40 @@ async def full_round(self, query: QueryType, **kwargs) -> AsyncIterable[ChatMess
             yield message
 
             # if function call, do it and attempt retry if it's wrong
-            if not message.function_call:
+            if not message.tool_calls:
                 return
 
-            try:
-                function_call_result = await self.do_function_call(message.function_call)
-                is_model_turn = function_call_result.is_model_turn
-                call_message = function_call_result.message
+            # run each tool call in parallel
+            async def _do_tool_call(tc: ToolCall):
+                try:
+                    return await self.do_function_call(tc.function, tool_call_id=tc.id)
+                except FunctionCallException as e:
+                    return await self.handle_function_call_exception(tc.function, e, retry, tool_call_id=tc.id)
+
+            # and update results after they are completed
+            is_model_turn = False
+            should_retry_call = False
+            n_errs = 0
+            results = await asyncio.gather(*(_do_tool_call(tc) for tc in message.tool_calls))
+            for result in results:
                 # save the result to the chat history
-                await self.add_to_history(call_message)
-                yield call_message
-            except FunctionCallException as e:
-                exception_handling_result = await self.handle_function_call_exception(
-                    message.function_call, e, retry
-                )
-                # save the result to the chat history
-                exc_message = exception_handling_result.message
-                await self.add_to_history(exc_message)
-                yield exc_message
-                # retry if we have retry attempts left
+                await self.add_to_history(result.message)
+                yield result.message
+                if isinstance(result, ExceptionHandleResult):
+                    is_model_turn = True
+                    n_errs += 1
+                    # retry if any function says so
+                    should_retry_call = should_retry_call or result.should_retry
+                else:
+                    # allow model to generate response if any function says so
+                    is_model_turn = is_model_turn or result.is_model_turn
+
+            # if we encountered an error, increment the retry counter and allow the model to generate a response
+            if n_errs:
                 retry += 1
-                if not exception_handling_result.should_retry:
+                if not should_retry_call:
                     # disable function calling on the next go
-                    kwargs = {**kwargs, "include_functions": False}
-                continue
+                    kwargs["include_functions"] = False
             else:
                 retry = 0
@@ -323,7 +333,7 @@ async def get_prompt(self) -> list[ChatMessage]:
             return self.always_included_messages
         return self.always_included_messages + self.chat_history[-to_keep:]
 
-    async def do_function_call(self, call: FunctionCall) -> FunctionCallResult:
+    async def do_function_call(self, call: FunctionCall, tool_call_id: str = None) -> FunctionCallResult:
         """Resolve a single function call.
 
         By default, any exception raised from this method will be an instance of a :class:`.FunctionCallException`.
@@ -331,6 +341,8 @@ async def do_function_call(self, call: FunctionCall) -> FunctionCallResult:
         You may implement an override to add instrumentation around function calls (e.g. tracking success counts
         for varying prompts). See :doc:`/customization/function_call`.
 
+        :param call: The requested function call to resolve (the function name and arguments to call it with).
+        :param tool_call_id: The ``tool_call_id`` to set in the returned FUNCTION message.
         :returns: A :class:`.FunctionCallResult` including whose turn it is next and the message with the result
             of the function call.
         :raises NoSuchFunction: The requested function does not exist.
@@ -348,7 +360,7 @@ async def do_function_call(self, call: FunctionCall) -> FunctionCallResult:
             log.debug(f"{f.name} responded with data: {result_str!r}")
         except Exception as e:
             raise WrappedCallException(f.auto_retry, e) from e
-        msg = ChatMessage.function(f.name, result_str)
+        msg = ChatMessage.function(f.name, result_str, tool_call_id=tool_call_id)
         # if we are auto truncating, check and see if we need to
         if f.auto_truncate is not None:
             message_len = self.message_token_len(msg)
@@ -362,7 +374,7 @@ async def do_function_call(self, call: FunctionCall) -> FunctionCallResult:
         return FunctionCallResult(is_model_turn=f.after == ChatRole.ASSISTANT, message=msg)
 
     async def handle_function_call_exception(
-        self, call: FunctionCall, err: FunctionCallException, attempt: int
+        self, call: FunctionCall, err: FunctionCallException, attempt: int, tool_call_id: str = None
     ) -> ExceptionHandleResult:
         """Called when a function call raises an exception.
 
@@ -376,6 +388,7 @@ async def handle_function_call_exception(
         :param err: The error the call raised. Usually this is :class:`.NoSuchFunction` or
             :class:`.WrappedCallException`, although it may be any exception raised by :meth:`do_function_call`.
         :param attempt: The attempt number for the current call (0-indexed).
+        :param tool_call_id: The ``tool_call_id`` to set in the returned FUNCTION message.
         :returns: A :class:`.ExceptionHandleResult` detailing whether the model should retry and the message to add
             to the chat history.
         """
@@ -383,11 +396,15 @@ async def handle_function_call_exception(
         log.debug(f"Call to {call.name} raised an exception: {err}")
         # tell the model what went wrong
         if isinstance(err, NoSuchFunction):
-            msg = ChatMessage.system(f"The function {err.name!r} is not defined. Only use the provided functions.")
+            msg = ChatMessage.function(
+                name=None,
+                content=f"The function {err.name!r} is not defined. Only use the provided functions.",
+                tool_call_id=tool_call_id,
+            )
         else:
             # but if it's a user function error, we want to raise it
             log.error(f"Call to {call.name} raised an exception: {err}", exc_info=err)
-            msg = ChatMessage.function(call.name, str(err))
+            msg = ChatMessage.function(call.name, str(err), tool_call_id=tool_call_id)
 
         return ExceptionHandleResult(should_retry=attempt < self.retry_attempts and err.retry, message=msg)
diff --git a/kani/models.py b/kani/models.py
index 35d2a9d..75e6b99 100644
--- a/kani/models.py
+++ b/kani/models.py
@@ -3,6 +3,7 @@
 import abc
 import enum
 import json
+import uuid
 import warnings
 from typing import ClassVar, Sequence, Type, TypeAlias, Union
 
@@ -47,7 +48,7 @@ class ChatRole(enum.Enum):
 
 
 class FunctionCall(BaseModel):
-    """Represents a model's request to call a function."""
+    """Represents a model's request to call a single function."""
 
     model_config = ConfigDict(frozen=True)
 
@@ -68,6 +69,45 @@ def with_args(cls, name: str, **kwargs):
         return cls(name=name, arguments=json.dumps(kwargs))
 
 
+class ToolCall(BaseModel):
+    """Represents a model's request to call a tool with a unique request ID.
+
+    See :ref:`functioncall_v_toolcall` for more information about tool calls vs function calls.
+    """
+
+    model_config = ConfigDict(frozen=True)
+
+    id: str
+    """The request ID created by the engine.
+    This should be passed back to the engine in :attr:`.ChatMessage.tool_call_id` in order to associate a FUNCTION
+    message with this request.
+    """
+
+    type: str
+    """The type of tool requested (currently only "function")."""
+
+    function: FunctionCall
+    """The requested function call."""
+
+    @classmethod
+    def from_function(cls, name: str, *, call_id_: str = None, **kwargs):
+        """Create a tool call request for a function with the given name and arguments.
+
+        :param call_id_: The ID to assign to the request. If not passed, generates a random ID.
+        """
+        call_id = call_id_ or str(uuid.uuid4())
+        return cls(id=call_id, type="function", function=FunctionCall.with_args(name, **kwargs))
+
+    @classmethod
+    def from_function_call(cls, call: FunctionCall, call_id_: str = None):
+        """Create a tool call request from an existing FunctionCall.
+
+        :param call_id_: The ID to assign to the request. If not passed, generates a random ID.
+        """
+        call_id = call_id_ or str(uuid.uuid4())
+        return cls(id=call_id, type="function", function=call)
+
+
 class MessagePart(BaseModel, abc.ABC):
     """Base class for a part of a message.
     Engines should inherit from this class to tag substrings with metadata or provide multimodality to an engine.
@@ -146,6 +186,14 @@ class ChatMessage(BaseModel):
 
     model_config = ConfigDict(frozen=True)
 
+    def __init__(self, **kwargs):
+        # translate a function_call into tool_calls
+        if "function_call" in kwargs:
+            if "tool_calls" in kwargs:
+                raise ValueError("Only one of function_call or tool_calls may be provided.")
+            kwargs["tool_calls"] = (ToolCall.from_function_call(kwargs.pop("function_call")),)
+        super().__init__(**kwargs)
+
     role: ChatRole
     """Who said the message?"""
 
@@ -183,8 +231,27 @@ def parts(self) -> list[MessagePart | str]:
     name: str | None = None
     """The name of the user who sent the message, if set (user/function messages only)."""
 
-    function_call: FunctionCall | None = None
-    """The function requested by the model (assistant messages only)."""
+    tool_call_id: str | None = None
+    """The ID for a requested :class:`.ToolCall` which this message is a response to (function messages only)."""
+
+    tool_calls: list[ToolCall] | None = None
+    """The tool calls requested by the model (assistant messages only)."""
+
+    @property
+    def function_call(self) -> FunctionCall | None:
+        """If there is exactly one tool call to a function, return that tool call's requested function.
+
+        This is mostly provided for backwards-compatibility purposes; iterating over :attr:`tool_calls` should be
+        preferred.
+        """
+        if not self.tool_calls:
+            return None
+        if len(self.tool_calls) > 1:
+            warnings.warn(
+                "This message contains multiple tool calls; iterate over `.tool_calls` instead of using"
+                " `.function_call`."
+            )
+        return self.tool_calls[0].function
 
     @classmethod
     def system(cls, content: str | Sequence[MessagePart | str], **kwargs):
@@ -202,9 +269,9 @@ def assistant(cls, content: str | Sequence[MessagePart | str] | None, **kwargs):
         return cls(role=ChatRole.ASSISTANT, content=content, **kwargs)
 
     @classmethod
-    def function(cls, name: str, content: str | Sequence[MessagePart | str], **kwargs):
+    def function(cls, name: str | None, content: str | Sequence[MessagePart | str], tool_call_id: str = None, **kwargs):
         """Create a new function message."""
-        return cls(role=ChatRole.FUNCTION, content=content, name=name, **kwargs)
+        return cls(role=ChatRole.FUNCTION, content=content, name=name, tool_call_id=tool_call_id, **kwargs)
 
     def copy_with(self, **new_values):
         """Make a shallow copy of this object, updating the passed attributes (if any) to new values.
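For deterministic few-shot prompts or tests, the generated UUIDs can be pinned with the `call_id_` parameter added above; a short sketch:

```python
# Sketch: pinning tool call IDs explicitly instead of relying on generated UUIDs.
from kani import ChatMessage, ToolCall

tc = ToolCall.from_function("get_weather", call_id_="call_0", location="Philadelphia, PA")
request = ChatMessage.assistant(content=None, tool_calls=[tc])
# the FUNCTION message binds back to the request via tool_call_id
response = ChatMessage.function("get_weather", "Partly cloudy, 85 degrees fahrenheit.", tool_call_id="call_0")
assert response.tool_call_id == request.tool_calls[0].id
```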
@@ -213,7 +280,10 @@ def copy_with(self, **new_values):
 
         This is mostly just a convenience wrapper around ``.model_copy``.
 
         Only one of (content, text, parts) may be passed and will update the other two attributes accordingly.
+
+        Only one of (tool_calls, function_call) may be passed and will update the other accordingly.
         """
+        # === content ===
         # ensure the content is immutable
         if "content" in new_values and not isinstance(new_values["content"], str):
             new_values["content"] = tuple(new_values["content"])
@@ -226,4 +296,9 @@ def copy_with(self, **new_values):
             if "content" in new_values:
                 raise ValueError("At most one of ('content', 'text', 'parts') can be set.")
             new_values["content"] = tuple(new_values.pop("parts"))
+        # === tool calls ===
+        if "function_call" in new_values:
+            if "tool_calls" in new_values:
+                raise ValueError("Only one of function_call or tool_calls may be provided.")
+            new_values["tool_calls"] = (ToolCall.from_function_call(new_values.pop("function_call")),)
         return super().copy_with(**new_values)
diff --git a/tests/test_chatmessage.py b/tests/test_chatmessage.py
index 8037e68..01650db 100644
--- a/tests/test_chatmessage.py
+++ b/tests/test_chatmessage.py
@@ -1,7 +1,7 @@
 import pydantic
 import pytest
 
-from kani import ChatMessage, ChatRole, MessagePart
+from kani import ChatMessage, ChatRole, MessagePart, ToolCall
 
 
 class TestMessagePart(MessagePart):
@@ -41,7 +41,7 @@ def test_parts():
     assert msg.parts == ["Hello world", part]
 
 
-def test_copy():
+def test_copy_parts():
     part = TestMessagePart()
     msg = ChatMessage(role=ChatRole.USER, content=["Hello world", part])
 
@@ -58,3 +58,19 @@ def test_copy():
 
     with pytest.raises(ValueError):
         msg.copy_with(text="foo", parts=[])
+
+
+def test_copy_tools():
+    msg = ChatMessage(role=ChatRole.ASSISTANT, content=None)
+
+    bar_call = ToolCall.from_function("bar")
+    calls_copy = msg.copy_with(tool_calls=[bar_call])
+    assert calls_copy.tool_calls == [bar_call]
+    assert calls_copy.function_call == bar_call.function
+
+    func_copy = msg.copy_with(function_call=bar_call.function)
+    assert len(func_copy.tool_calls) == 1
+    assert func_copy.function_call == bar_call.function
+
+    with pytest.raises(ValueError):
+        msg.copy_with(tool_calls=[bar_call], function_call=bar_call.function)
diff --git a/tests/test_saveload.py b/tests/test_saveload.py
index 576cd43..cafc2e4 100644
--- a/tests/test_saveload.py
+++ b/tests/test_saveload.py
@@ -3,7 +3,7 @@
 import random
 import string
 
-from kani import ChatMessage, Kani, MessagePart
+from kani import ChatMessage, FunctionCall, Kani, MessagePart
 from tests.engine import TestEngine
 
 engine = TestEngine()
@@ -28,6 +28,34 @@ async def test_saveload_str(tmp_path):
     assert ai.chat_history == loaded.chat_history
 
 
+async def test_saveload_tool_calls(tmp_path):
+    """Test that tool calls are saved."""
+    fewshot = [
+        ChatMessage.user("What's the weather in Philadelphia?"),
+        ChatMessage.assistant(
+            content=None,
+            function_call=FunctionCall.with_args("get_weather", location="Philadelphia, PA", unit="fahrenheit"),
+        ),
+        ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit."),
+        ChatMessage.assistant(
+            content=None,
+            function_call=FunctionCall.with_args("get_weather", location="Philadelphia, PA", unit="celsius"),
+        ),
+        ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius."),
+        ChatMessage.assistant("It's currently 85F (29C) and partly cloudy in Philadelphia."),
+    ]
+    ai = Kani(engine, chat_history=fewshot)
+
+    # save and load
+    ai.save(tmp_path / "pytest.json")
+    loaded = Kani(engine)
+    loaded.load(tmp_path / "pytest.json")
+
+    # assert equality
+    assert ai.always_included_messages == loaded.always_included_messages
+    assert ai.chat_history == loaded.chat_history
+
+
 class TestMessagePart1(MessagePart):
     data: str
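As `test_saveload_tool_calls` exercises, a message built with the deprecated `function_call` kwarg serializes with `tool_calls` populated. A quick sketch of the shape, using the same pydantic serialization as the `docs/customization/chat_history.rst` logging example (the exact ID varies, since it's a generated UUID):

```python
# Sketch: the on-disk/on-wire shape after the FunctionCall -> ToolCall translation.
from kani import ChatMessage, FunctionCall

msg = ChatMessage.assistant(
    content=None,
    function_call=FunctionCall.with_args("get_weather", location="Philadelphia, PA", unit="celsius"),
)
print(msg.model_dump_json())
# e.g. {"role": "assistant", "content": null, ..., "tool_calls": [{"id": "...",
#       "type": "function", "function": {"name": "get_weather",
#       "arguments": "{\"location\": \"Philadelphia, PA\", \"unit\": \"celsius\"}"}}]}
```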