From 85076534a16e0499a75765fd7534fd64a1d734d0 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Tue, 7 Nov 2023 15:32:11 -0500 Subject: [PATCH 1/9] refactor: tool calls --- docs/api_reference.rst | 5 ++ docs/function_calling.rst | 15 ++++++ kani/engines/openai/client.py | 38 ++++++++++++--- kani/engines/openai/engine.py | 50 +++++++++++++++++--- kani/engines/openai/models.py | 70 +++++++++++++++++++++++++--- kani/exceptions.py | 4 ++ kani/models.py | 87 ++++++++++++++++++++++++++++++++--- kani/utils/deprecation.py | 73 +++++++++++++++++++++++++++++ tests/test_chatmessage.py | 19 +++++++- tests/test_saveload.py | 30 +++++++++++- 10 files changed, 363 insertions(+), 28 deletions(-) create mode 100644 kani/utils/deprecation.py diff --git a/docs/api_reference.rst b/docs/api_reference.rst index 5456ad1..62c26cb 100644 --- a/docs/api_reference.rst +++ b/docs/api_reference.rst @@ -16,6 +16,11 @@ Common Models :exclude-members: model_config, model_fields :class-doc-from: class +.. autoclass:: kani.ToolCall + :members: + :exclude-members: model_config, model_fields + :class-doc-from: class + .. autoclass:: kani.MessagePart :members: :exclude-members: model_config, model_fields diff --git a/docs/function_calling.rst b/docs/function_calling.rst index be5a847..8a1cade 100644 --- a/docs/function_calling.rst +++ b/docs/function_calling.rst @@ -255,3 +255,18 @@ error in a message to the model by default, allowing it up to *retry_attempts* t call. In the next section, we'll discuss how to customize this behaviour, along with other parts of the kani interface. + +.. _functioncall_v_toolcall: + +Internal Representation +----------------------- + +.. versionchanged:: v0.6.0 + +As of Nov 6, 2023, OpenAI added the ability for a single assistant message to request calling multiple functions in +parallel, and wrapped all function calls in a :class:`.ToolCall` wrapper. In order to add support for this in kani while +maintaining backwards compatibility with OSS function calling models, a :class:`.ChatMessage` actually maintains the +following internal representation: + +:attr:`.ChatMessage.function_call` is actually an alias for ``ChatMessage.tool_calls[0].function``. If there is more +than one tool call in the message, kani will raise an exception. diff --git a/kani/engines/openai/client.py b/kani/engines/openai/client.py index fe6e8cb..8b5d8ee 100644 --- a/kani/engines/openai/client.py +++ b/kani/engines/openai/client.py @@ -1,10 +1,20 @@ import asyncio +import warnings from typing import Literal, overload import aiohttp import pydantic -from .models import ChatCompletion, Completion, FunctionSpec, OpenAIChatMessage, SpecificFunctionCall +from .models import ( + ChatCompletion, + Completion, + FunctionSpec, + OpenAIChatMessage, + ResponseFormat, + SpecificFunctionCall, + ToolChoice, + ToolSpec, +) from ..httpclient import BaseClient, HTTPException, HTTPStatusException, HTTPTimeout @@ -77,6 +87,7 @@ async def request(self, method: str, route: str, headers=None, retry=None, **kwa async def create_completion( self, model: str, + *, prompt: str = "<|endoftext|>", suffix: str = None, max_tokens: int = 16, @@ -85,6 +96,7 @@ async def create_completion( n: int = 1, logprobs: int = None, echo: bool = False, + seed: int | None = None, stop: str | list[str] = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, @@ -107,24 +119,31 @@ async def create_chat_completion( self, model: str, messages: list[OpenAIChatMessage], - functions: list[FunctionSpec] | None = None, - function_call: SpecificFunctionCall | Literal["auto"] | Literal["none"] | None = None, + *, + tools: list[ToolSpec] | None = None, + tool_choice: ToolChoice | Literal["auto"] | Literal["none"] | None = None, temperature: float = 1.0, top_p: float = 1.0, n: int = 1, + response_format: ResponseFormat | None = None, + seed: int | None = None, stop: str | list[str] | None = None, max_tokens: int | None = None, presence_penalty: float = 0.0, frequency_penalty: float = 0.0, logit_bias: dict | None = None, user: str | None = None, + # deprecated + functions: list[FunctionSpec] | None = None, + function_call: SpecificFunctionCall | Literal["auto"] | Literal["none"] | None = None, ) -> ChatCompletion: ... async def create_chat_completion( self, model: str, messages: list[OpenAIChatMessage], - functions: list[FunctionSpec] | None = None, + *, + tools: list[ToolSpec] | None = None, **kwargs, ) -> ChatCompletion: """Create a chat completion. @@ -132,9 +151,16 @@ async def create_chat_completion( See https://platform.openai.com/docs/api-reference/chat/create. """ # transform pydantic models - if functions: - kwargs["functions"] = [f.model_dump(exclude_unset=True) for f in functions] + if tools: + kwargs["tools"] = [t.model_dump(exclude_unset=True) for t in tools] + if "tool_choice" in kwargs and isinstance(kwargs["tool_choice"], SpecificFunctionCall): + kwargs["tool_choice"] = kwargs["tool_choice"].model_dump(exclude_unset=True) + # deprecated function calling + if "functions" in kwargs: + warnings.warn("The functions parameter is deprecated. Use tools instead.", DeprecationWarning) + kwargs["functions"] = [f.model_dump(exclude_unset=True) for f in kwargs["functions"]] if "function_call" in kwargs and isinstance(kwargs["function_call"], SpecificFunctionCall): + warnings.warn("The function_call parameter is deprecated. Use tool_choice instead.", DeprecationWarning) kwargs["function_call"] = kwargs["function_call"].model_dump(exclude_unset=True) # call API data = await self.post( diff --git a/kani/engines/openai/engine.py b/kani/engines/openai/engine.py index 9e068df..933604e 100644 --- a/kani/engines/openai/engine.py +++ b/kani/engines/openai/engine.py @@ -2,11 +2,11 @@ import os from kani.ai_function import AIFunction -from kani.exceptions import MissingModelDependencies -from kani.models import ChatMessage +from kani.exceptions import MissingModelDependencies, ToolCallError +from kani.models import ChatMessage, ChatRole from . import function_calling from .client import OpenAIClient -from .models import ChatCompletion, FunctionSpec, OpenAIChatMessage +from .models import ChatCompletion, FunctionSpec, OpenAIChatMessage, ToolSpec from ..base import BaseEngine try: @@ -114,12 +114,48 @@ async def predict( self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams ) -> ChatCompletion: if functions: - function_spec = [FunctionSpec(name=f.name, description=f.desc, parameters=f.json_schema) for f in functions] + tool_specs = [ + ToolSpec.from_function(FunctionSpec(name=f.name, description=f.desc, parameters=f.json_schema)) + for f in functions + ] else: - function_spec = None - translated_messages = [OpenAIChatMessage.from_chatmessage(m) for m in messages] + tool_specs = None + # translate to openai spec - group any tool messages together and ensure all free ToolCall IDs are bound + translated_messages = [] + free_toolcall_ids = set() + for m in messages: + # if this is not a function result and there are free tool call IDs, raise + if m.role != ChatRole.FUNCTION and free_toolcall_ids: + raise ToolCallError( + f"Encountered a {m.role.value!r} message but expected a FUNCTION message to satisfy the pending" + f" tool call(s): {free_toolcall_ids}" + ) + # asst: add tool call IDs to freevars + if m.role == ChatRole.ASSISTANT and m.tool_calls: + for tc in m.tool_calls: + free_toolcall_ids.add(tc.id) + # func: bind freevars + elif m.role == ChatRole.FUNCTION: + # has ID: bind it + if m.tool_call_id is not None and m.tool_call_id in free_toolcall_ids: + free_toolcall_ids.remove(m.tool_call_id) + # no ID: bind if unambiguous, otherwise cry + elif m.tool_call_id is None: + if len(free_toolcall_ids) == 1: + m = m.copy_with(tool_call_id=free_toolcall_ids.pop()) + elif len(free_toolcall_ids) > 1: + raise ToolCallError( + "Got a FUNCTION message with no tool_call_id but multiple tool calls are pending" + f" ({free_toolcall_ids})! Set the tool_call_id to resolve the pending tool requests." + ) + translated_messages.append(OpenAIChatMessage.from_chatmessage(m)) + # if the translated messages start with a hanging TOOL call, strip it (openai limitation) + # though hanging FUNCTION messages are OK + while translated_messages and translated_messages[0].role == "tool": + translated_messages.pop(0) + # make API call completion = await self.client.create_chat_completion( - model=self.model, messages=translated_messages, functions=function_spec, **self.hyperparams, **hyperparams + model=self.model, messages=translated_messages, tools=tool_specs, **self.hyperparams, **hyperparams ) return completion diff --git a/kani/engines/openai/models.py b/kani/engines/openai/models.py index e8dacf7..69fa5af 100644 --- a/kani/engines/openai/models.py +++ b/kani/engines/openai/models.py @@ -1,6 +1,6 @@ from typing import Literal -from kani.models import BaseModel, ChatMessage, ChatRole, FunctionCall +from kani.models import BaseModel, ChatMessage, ChatRole, FunctionCall, ToolCall from ..base import BaseCompletion @@ -30,6 +30,7 @@ class Completion(BaseModel): object: Literal["text_completion"] created: int model: str + system_fingerprint: str | None = None choices: list[CompletionChoice] usage: CompletionUsage @@ -46,25 +47,79 @@ class FunctionSpec(BaseModel): parameters: dict +class ToolSpec(BaseModel): + type: str + function: FunctionSpec + + @classmethod + def from_function(cls, spec: FunctionSpec): + return cls(type="function", function=spec) + + class SpecificFunctionCall(BaseModel): name: str +class ToolChoice(BaseModel): + type: str + function: SpecificFunctionCall + + @classmethod + def from_function(cls, name: str): + return cls(type="function", function=SpecificFunctionCall(name=name)) + + +class ResponseFormat(BaseModel): + type: str + + @classmethod + def text(cls): + return cls(type="text") + + @classmethod + def json_object(cls): + return cls(type="json_object") + + class OpenAIChatMessage(BaseModel): - role: ChatRole + role: str content: str | list[BaseModel | str] | None name: str | None = None + tool_call_id: str | None = None + tool_calls: list[ToolCall] | None = None + # deprecated function_call: FunctionCall | None = None @classmethod def from_chatmessage(cls, m: ChatMessage): - return cls(role=m.role, content=m.text, name=m.name, function_call=m.function_call) + # translate tool responses to a function to the right openai format + if m.role == ChatRole.FUNCTION: + if m.tool_call_id is not None: + return cls(role="tool", content=m.text, name=m.name, tool_call_id=m.tool_call_id) + return cls(role=m.role.value, content=m.text, name=m.name) + return cls(role=m.role.value, content=m.text, name=m.name, tool_call_id=m.tool_call_id, tool_calls=m.tool_calls) + + def to_chatmessage(self) -> ChatMessage: + # translate tool role to function role + if self.role == "tool": + role = ChatRole.FUNCTION + else: + role = ChatRole(self.role) + # translate FunctionCall to singular ToolCall + if self.tool_calls: + tool_calls = self.tool_calls + elif self.function_call: + tool_calls = [ToolCall.from_function_call(self.function_call)] + else: + tool_calls = None + return ChatMessage( + role=role, content=self.content, name=self.name, tool_call_id=self.tool_call_id, tool_calls=tool_calls + ) # ---- response ---- class ChatCompletionChoice(BaseModel): - # this is a ChatMessage rather than an OpenAIChatMessage because all engines need to return the kani model - message: ChatMessage + message: OpenAIChatMessage index: int finish_reason: str | None = None @@ -74,12 +129,13 @@ class ChatCompletion(BaseCompletion, BaseModel): object: Literal["chat.completion"] created: int model: str + system_fingerprint: str | None = None usage: CompletionUsage choices: list[ChatCompletionChoice] @property - def message(self): - return self.choices[0].message + def message(self) -> ChatMessage: + return self.choices[0].message.to_chatmessage() @property def prompt_tokens(self): diff --git a/kani/exceptions.py b/kani/exceptions.py index fffc6a3..23531d9 100644 --- a/kani/exceptions.py +++ b/kani/exceptions.py @@ -63,6 +63,10 @@ class MissingModelDependencies(KaniException): """You are trying to use an engine but do not have engine-specific packages installed.""" +class ToolCallError(KaniException): + """Something went wrong with tool calls.""" + + # ==== serdes ==== class MissingMessagePartType(KaniException): """During loading a saved kani, a message part has a type which is not currently defined in the runtime.""" diff --git a/kani/models.py b/kani/models.py index 35d2a9d..ccbb77e 100644 --- a/kani/models.py +++ b/kani/models.py @@ -3,12 +3,13 @@ import abc import enum import json +import uuid import warnings from typing import ClassVar, Sequence, Type, TypeAlias, Union from pydantic import BaseModel as PydanticBase, ConfigDict, model_serializer, model_validator -from .exceptions import MissingMessagePartType +from .exceptions import MissingMessagePartType, ToolCallError # ==== constants ==== MESSAGEPART_TYPE_KEY = "__kani_messagepart_type__" # used for serdes of MessageParts @@ -47,7 +48,7 @@ class ChatRole(enum.Enum): class FunctionCall(BaseModel): - """Represents a model's request to call a function.""" + """Represents a model's request to call a single function.""" model_config = ConfigDict(frozen=True) @@ -68,6 +69,45 @@ def with_args(cls, name: str, **kwargs): return cls(name=name, arguments=json.dumps(kwargs)) +class ToolCall(BaseModel): + """Represents a model's request to call a tool with a unique request ID. + + See :ref:`functioncall_v_toolcall` for more information about tool calls vs function calls. + """ + + model_config = ConfigDict(frozen=True) + + id: str + """The request ID created by the engine. + This should be passed back to the engine in :attr:`.ChatMessage.tool_call_id` in order to associate a TOOL message + with this request. + """ + + type: str + """The type of tool requested (currently only "function").""" + + function: FunctionCall + """The requested function call.""" + + @classmethod + def from_function(cls, name: str, *, call_id_: str = None, **kwargs): + """Create a tool call request for a function with the given name and arguments. + + :param call_id_: The ID to assign to the request. If not passed, generates a random ID. + """ + call_id = call_id_ or str(uuid.uuid4()) + return cls(id=call_id, type="function", function=FunctionCall.with_args(name, **kwargs)) + + @classmethod + def from_function_call(cls, call: FunctionCall, call_id_: str = None): + """Create a tool call request from an existing FunctionCall. + + :param call_id_: The ID to assign to the request. If not passed, generates a random ID. + """ + call_id = call_id_ or str(uuid.uuid4()) + return cls(id=call_id, type="function", function=call) + + class MessagePart(BaseModel, abc.ABC): """Base class for a part of a message. Engines should inherit from this class to tag substrings with metadata or provide multimodality to an engine. @@ -146,6 +186,14 @@ class ChatMessage(BaseModel): model_config = ConfigDict(frozen=True) + def __init__(self, **kwargs): + # translate a function_call into tool_calls + if "function_call" in kwargs: + if "tool_calls" in kwargs: + raise ValueError("Only one of function_call or tool_calls may be provided.") + kwargs["tool_calls"] = (ToolCall.from_function_call(kwargs.pop("function_call")),) + super().__init__(**kwargs) + role: ChatRole """Who said the message?""" @@ -183,8 +231,27 @@ def parts(self) -> list[MessagePart | str]: name: str | None = None """The name of the user who sent the message, if set (user/function messages only).""" - function_call: FunctionCall | None = None - """The function requested by the model (assistant messages only).""" + tool_call_id: str | None = None + """The ID for a requested :class:`.ToolCall` which this message is a response to (function messages only).""" + + tool_calls: tuple[ToolCall] | None = None + """The tool calls requested by the model (assistant messages only).""" + + @property + def function_call(self) -> FunctionCall | None: + """If there is exactly one tool call to a function, return that tool call's requested function. + + This is mostly provided for backwards-compatibility purposes; iterating over :attr:`tool_calls` should be + preferred. + """ + if not self.tool_calls: + return None + if len(self.tool_calls) > 1: + raise ToolCallError( + "This message contains multiple tool calls; iterate over `.tool_calls` instead of using" + " `.function_call`." + ) + return self.tool_calls[0].function @classmethod def system(cls, content: str | Sequence[MessagePart | str], **kwargs): @@ -202,9 +269,9 @@ def assistant(cls, content: str | Sequence[MessagePart | str] | None, **kwargs): return cls(role=ChatRole.ASSISTANT, content=content, **kwargs) @classmethod - def function(cls, name: str, content: str | Sequence[MessagePart | str], **kwargs): + def function(cls, name: str, content: str | Sequence[MessagePart | str], tool_call_id: str = None, **kwargs): """Create a new function message.""" - return cls(role=ChatRole.FUNCTION, content=content, name=name, **kwargs) + return cls(role=ChatRole.FUNCTION, content=content, name=name, tool_call_id=tool_call_id, **kwargs) def copy_with(self, **new_values): """Make a shallow copy of this object, updating the passed attributes (if any) to new values. @@ -213,7 +280,10 @@ def copy_with(self, **new_values): This is mostly just a convenience wrapper around ``.model_copy``. Only one of (content, text, parts) may be passed and will update the other two attributes accordingly. + + Only one of (tool_calls, function_call) may be passed and will update the other accordingly. """ + # === content === # ensure the content is immutable if "content" in new_values and not isinstance(new_values["content"], str): new_values["content"] = tuple(new_values["content"]) @@ -226,4 +296,9 @@ def copy_with(self, **new_values): if "content" in new_values: raise ValueError("At most one of ('content', 'text', 'parts') can be set.") new_values["content"] = tuple(new_values.pop("parts")) + # === tool calls === + if "function_call" in new_values: + if "tool_calls" in new_values: + raise ValueError("Only one of function_call or tool_calls may be provided.") + new_values["tool_calls"] = (ToolCall.from_function_call(new_values.pop("function_call")),) return super().copy_with(**new_values) diff --git a/kani/utils/deprecation.py b/kani/utils/deprecation.py new file mode 100644 index 0000000..991ff28 --- /dev/null +++ b/kani/utils/deprecation.py @@ -0,0 +1,73 @@ +import functools +import inspect +import warnings + +string_types = (type(b""), type("")) + + +# deprecated wrapper from https://stackoverflow.com/a/40301488 +def deprecated(reason): + """ + This is a decorator which can be used to mark functions + as deprecated. It will result in a warning being emitted + when the function is used. + """ + + if isinstance(reason, string_types): + + # The @deprecated is used with a 'reason'. + # + # .. code-block:: python + # + # @deprecated("please, use another function") + # def old_function(x, y): + # pass + + def decorator(func1): + + if inspect.isclass(func1): + fmt1 = "Call to deprecated class {name} ({reason})." + else: + fmt1 = "Call to deprecated function {name} ({reason})." + + @functools.wraps(func1) + def new_func1(*args, **kwargs): + warnings.simplefilter("always", DeprecationWarning) + warnings.warn( + fmt1.format(name=func1.__name__, reason=reason), category=DeprecationWarning, stacklevel=2 + ) + warnings.simplefilter("default", DeprecationWarning) + return func1(*args, **kwargs) + + return new_func1 + + return decorator + + elif inspect.isclass(reason) or inspect.isfunction(reason): + + # The @deprecated is used without any 'reason'. + # + # .. code-block:: python + # + # @deprecated + # def old_function(x, y): + # pass + + func2 = reason + + if inspect.isclass(func2): + fmt2 = "Call to deprecated class {name}." + else: + fmt2 = "Call to deprecated function {name}." + + @functools.wraps(func2) + def new_func2(*args, **kwargs): + warnings.simplefilter("always", DeprecationWarning) + warnings.warn(fmt2.format(name=func2.__name__), category=DeprecationWarning, stacklevel=2) + warnings.simplefilter("default", DeprecationWarning) + return func2(*args, **kwargs) + + return new_func2 + + else: + raise TypeError(repr(type(reason))) diff --git a/tests/test_chatmessage.py b/tests/test_chatmessage.py index 8037e68..0f79223 100644 --- a/tests/test_chatmessage.py +++ b/tests/test_chatmessage.py @@ -2,6 +2,7 @@ import pytest from kani import ChatMessage, ChatRole, MessagePart +from kani.models import ToolCall class TestMessagePart(MessagePart): @@ -41,7 +42,7 @@ def test_parts(): assert msg.parts == ["Hello world", part] -def test_copy(): +def test_copy_parts(): part = TestMessagePart() msg = ChatMessage(role=ChatRole.USER, content=["Hello world", part]) @@ -58,3 +59,19 @@ def test_copy(): with pytest.raises(ValueError): msg.copy_with(text="foo", parts=[]) + + +def test_copy_tools(): + msg = ChatMessage(role=ChatRole.ASSISTANT, content=None) + + bar_call = ToolCall.from_function("bar") + calls_copy = msg.copy_with(tool_calls=[bar_call]) + assert calls_copy.tool_calls == [bar_call] + assert calls_copy.function_call == bar_call.function + + func_copy = msg.copy_with(function_call=bar_call.function) + assert len(func_copy.tool_calls) == 1 + assert func_copy.function_call == bar_call.function + + with pytest.raises(ValueError): + msg.copy_with(tool_calls=[bar_call], function_call=bar_call.function) diff --git a/tests/test_saveload.py b/tests/test_saveload.py index 576cd43..cafc2e4 100644 --- a/tests/test_saveload.py +++ b/tests/test_saveload.py @@ -3,7 +3,7 @@ import random import string -from kani import ChatMessage, Kani, MessagePart +from kani import ChatMessage, FunctionCall, Kani, MessagePart from tests.engine import TestEngine engine = TestEngine() @@ -28,6 +28,34 @@ async def test_saveload_str(tmp_path): assert ai.chat_history == loaded.chat_history +async def test_saveload_tool_calls(tmp_path): + """Test that tool calls are saved.""" + fewshot = [ + ChatMessage.user("What's the weather in Philadelphia?"), + ChatMessage.assistant( + content=None, + function_call=FunctionCall.with_args("get_weather", location="Philadelphia, PA", unit="fahrenheit"), + ), + ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit."), + ChatMessage.assistant( + content=None, + function_call=FunctionCall.with_args("get_weather", location="Philadelphia, PA", unit="celsius"), + ), + ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius."), + ChatMessage.assistant("It's currently 85F (29C) and partly cloudy in Philadelphia."), + ] + ai = Kani(engine, chat_history=fewshot) + + # save and load + ai.save(tmp_path / "pytest.json") + loaded = Kani(engine) + loaded.load(tmp_path / "pytest.json") + + # assert equality + assert ai.always_included_messages == loaded.always_included_messages + assert ai.chat_history == loaded.chat_history + + class TestMessagePart1(MessagePart): data: str From 1050a524c0491772c1734da3fbaa1222cf796b33 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Tue, 7 Nov 2023 15:55:24 -0500 Subject: [PATCH 2/9] refactor!: Kani.do_function_call takes tool_call_id --- docs/customization/function_call.rst | 4 +- docs/function_calling.rst | 103 ++++++++++++------ examples/2_function_calling_fewshot.py | 13 ++- examples/3_customization_exception_prompt.py | 4 +- .../3_customization_track_function_calls.py | 4 +- examples/colab_examples.ipynb | 12 +- kani/kani.py | 19 +++- kani/models.py | 2 +- 8 files changed, 105 insertions(+), 56 deletions(-) diff --git a/docs/customization/function_call.rst b/docs/customization/function_call.rst index 955dba1..35aefd2 100644 --- a/docs/customization/function_call.rst +++ b/docs/customization/function_call.rst @@ -39,9 +39,9 @@ during a conversation, and how often it was successful: self.successful_calls = collections.Counter() self.failed_calls = collections.Counter() - async def do_function_call(self, call): + async def do_function_call(self, call, *args, **kwargs): try: - result = await super().do_function_call(call) + result = await super().do_function_call(call, *args, **kwargs) self.successful_calls[call.name] += 1 return result except FunctionCallException: diff --git a/docs/function_calling.rst b/docs/function_calling.rst index 8a1cade..a55f748 100644 --- a/docs/function_calling.rst +++ b/docs/function_calling.rst @@ -182,38 +182,77 @@ prompt a model, we can mock these returns in the chat history using :meth:`.Chat For example, here's how you might prompt the model to give the temperature in both Fahrenheit and Celsius without the user having to ask: -.. code-block:: python - - from kani import ChatMessage, FunctionCall - fewshot = [ - ChatMessage.user("What's the weather in Philadelphia?"), - # first, the model should ask for the weather in fahrenheit - ChatMessage.assistant( - content=None, - function_call=FunctionCall.with_args( - "get_weather", location="Philadelphia, PA", unit="fahrenheit" - ) - ), - # and we mock the function's response to the model - ChatMessage.function( - "get_weather", - "Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit.", - ), - # repeat in celsius - ChatMessage.assistant( - content=None, - function_call=FunctionCall.with_args( - "get_weather", location="Philadelphia, PA", unit="celsius" - ) - ), - ChatMessage.function( - "get_weather", - "Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius.", - ), - # finally, give the result to the user - ChatMessage.assistant("It's currently 85F (29C) and partly cloudy in Philadelphia."), - ] - ai = MyKani(engine, chat_history=fewshot) +.. tab:: ToolCall API + + .. code-block:: python + + # build the chat history with examples + fewshot = [ + ChatMessage.user("What's the weather in Philadelphia?"), + ChatMessage.assistant( + content=None, + # use a walrus operator to save a reference to the tool call here... + tool_calls=[ + tc := ToolCall.from_function("get_weather", location="Philadelphia, PA", unit="fahrenheit") + ], + ), + ChatMessage.function( + "get_weather", + "Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit.", + # ...so this function result knows which call it's responding to + tc.id + ), + # and repeat for the other unit + ChatMessage.assistant( + content=None, + tool_calls=[ + tc2 := ToolCall.from_function("get_weather", location="Philadelphia, PA", unit="celsius") + ], + ), + ChatMessage.function( + "get_weather", + "Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius.", + tc2.id + ), + ChatMessage.assistant("It's currently 85F (29C) and partly cloudy in Philadelphia."), + ] + # and give it to the kani when you initialize it + ai = MyKani(engine, chat_history=fewshot) + +.. tab:: FunctionCall API (deprecated) + + .. code-block:: python + + from kani import ChatMessage, FunctionCall + fewshot = [ + ChatMessage.user("What's the weather in Philadelphia?"), + # first, the model should ask for the weather in fahrenheit + ChatMessage.assistant( + content=None, + function_call=FunctionCall.with_args( + "get_weather", location="Philadelphia, PA", unit="fahrenheit" + ) + ), + # and we mock the function's response to the model + ChatMessage.function( + "get_weather", + "Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit.", + ), + # repeat in celsius + ChatMessage.assistant( + content=None, + function_call=FunctionCall.with_args( + "get_weather", location="Philadelphia, PA", unit="celsius" + ) + ), + ChatMessage.function( + "get_weather", + "Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius.", + ), + # finally, give the result to the user + ChatMessage.assistant("It's currently 85F (29C) and partly cloudy in Philadelphia."), + ] + ai = MyKani(engine, chat_history=fewshot) .. code-block:: pycon diff --git a/examples/2_function_calling_fewshot.py b/examples/2_function_calling_fewshot.py index d1f5b29..672464a 100644 --- a/examples/2_function_calling_fewshot.py +++ b/examples/2_function_calling_fewshot.py @@ -2,8 +2,9 @@ import os from typing import Annotated -from kani import AIParam, ChatMessage, FunctionCall, Kani, ai_function, chat_in_terminal +from kani import AIParam, ChatMessage, Kani, ai_function, chat_in_terminal from kani.engines.openai import OpenAIEngine +from kani.models import ToolCall api_key = os.getenv("OPENAI_API_KEY") engine = OpenAIEngine(api_key, model="gpt-3.5-turbo") @@ -35,14 +36,16 @@ def get_weather( ChatMessage.user("What's the weather in Philadelphia?"), ChatMessage.assistant( content=None, - function_call=FunctionCall.with_args("get_weather", location="Philadelphia, PA", unit="fahrenheit"), + # use a walrus operator to save a reference to the tool call here... + tool_calls=[tc := ToolCall.from_function("get_weather", location="Philadelphia, PA", unit="fahrenheit")], ), - ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit."), + # so this function result knows which call it's responding to + ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 85 degrees fahrenheit.", tc.id), ChatMessage.assistant( content=None, - function_call=FunctionCall.with_args("get_weather", location="Philadelphia, PA", unit="celsius"), + tool_calls=[tc2 := ToolCall.from_function("get_weather", location="Philadelphia, PA", unit="celsius")], ), - ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius."), + ChatMessage.function("get_weather", "Weather in Philadelphia, PA: Partly cloudy, 29 degrees celsius.", tc2.id), ChatMessage.assistant("It's currently 85F (29C) and partly cloudy in Philadelphia."), ] # and give it to the kani when you initialize it diff --git a/examples/3_customization_exception_prompt.py b/examples/3_customization_exception_prompt.py index 3d93212..bf330a6 100644 --- a/examples/3_customization_exception_prompt.py +++ b/examples/3_customization_exception_prompt.py @@ -13,9 +13,9 @@ class CustomExceptionPromptKani(Kani): - async def handle_function_call_exception(self, call, err, attempt): + async def handle_function_call_exception(self, call, err, attempt, *args, **kwargs): # get the standard retry logic... - result = await super().handle_function_call_exception(call, err, attempt) + result = await super().handle_function_call_exception(call, err, attempt, *args, **kwargs) # but override the returned message with our own result.message = ChatMessage.system( f"The call encountered an error. Relay this error message to the user in a sarcastic manner: {err}" diff --git a/examples/3_customization_track_function_calls.py b/examples/3_customization_track_function_calls.py index 1d05e66..72a0e55 100644 --- a/examples/3_customization_track_function_calls.py +++ b/examples/3_customization_track_function_calls.py @@ -22,9 +22,9 @@ def __init__(self, *args, **kwargs): self.successful_calls = collections.Counter() self.failed_calls = collections.Counter() - async def do_function_call(self, call): + async def do_function_call(self, call, *args, **kwargs): try: - result = await super().do_function_call(call) + result = await super().do_function_call(call, *args, **kwargs) self.successful_calls[call.name] += 1 return result except FunctionCallException: diff --git a/examples/colab_examples.ipynb b/examples/colab_examples.ipynb index acfc68f..ac170c4 100644 --- a/examples/colab_examples.ipynb +++ b/examples/colab_examples.ipynb @@ -380,9 +380,9 @@ "\n", "\n", "class CustomExceptionKani(Kani):\n", - " async def handle_function_call_exception(self, call, err, attempt):\n", + " async def handle_function_call_exception(self, call, err, attempt, *args, **kwargs):\n", " # get the standard retry logic...\n", - " result = await super().handle_function_call_exception(call, err, attempt)\n", + " result = await super().handle_function_call_exception(call, err, attempt, *args, **kwargs)\n", " # but override the returned message with our own\n", " result.message = ChatMessage.system(\n", " f\"The call encountered an error. Relay this error message to the user in a sarcastic manner: {err}\"\n", @@ -677,13 +677,13 @@ " self.successful_calls = collections.Counter()\n", " self.failed_calls = collections.Counter()\n", "\n", - " async def handle_function_call_exception(self, call, err, attempt):\n", - " msg = ChatMessage.system(str(err))\n", + " async def handle_function_call_exception(self, call, err, attempt, tool_call_id=None):\n", + " msg = ChatMessage.function(name=call.name, content=str(err), tool_call_id=tool_call_id)\n", " return ExceptionHandleResult(should_retry=attempt < self.retry_attempts, message=msg)\n", "\n", - " async def do_function_call(self, call):\n", + " async def do_function_call(self, call, *args, **kwargs):\n", " try:\n", - " res = await super().do_function_call(call)\n", + " res = await super().do_function_call(call, *args, **kwargs)\n", " self.successful_calls[call.name] += 1\n", " return res\n", " except FunctionCallException:\n", diff --git a/kani/kani.py b/kani/kani.py index 55269ab..6bf3965 100644 --- a/kani/kani.py +++ b/kani/kani.py @@ -178,7 +178,7 @@ async def full_round(self, query: QueryType, **kwargs) -> AsyncIterable[ChatMess yield message # if function call, do it and attempt retry if it's wrong - if not message.function_call: + if not message.tool_calls: return try: @@ -323,7 +323,7 @@ async def get_prompt(self) -> list[ChatMessage]: return self.always_included_messages return self.always_included_messages + self.chat_history[-to_keep:] - async def do_function_call(self, call: FunctionCall) -> FunctionCallResult: + async def do_function_call(self, call: FunctionCall, tool_call_id: str = None) -> FunctionCallResult: """Resolve a single function call. By default, any exception raised from this method will be an instance of a :class:`.FunctionCallException`. @@ -331,6 +331,8 @@ async def do_function_call(self, call: FunctionCall) -> FunctionCallResult: You may implement an override to add instrumentation around function calls (e.g. tracking success counts for varying prompts). See :doc:`/customization/function_call`. + :param call: The name of the function to call and arguments to call it with. + :param tool_call_id: The ``tool_call_id`` to set in the returned FUNCTION message. :returns: A :class:`.FunctionCallResult` including whose turn it is next and the message with the result of the function call. :raises NoSuchFunction: The requested function does not exist. @@ -348,7 +350,7 @@ async def do_function_call(self, call: FunctionCall) -> FunctionCallResult: log.debug(f"{f.name} responded with data: {result_str!r}") except Exception as e: raise WrappedCallException(f.auto_retry, e) from e - msg = ChatMessage.function(f.name, result_str) + msg = ChatMessage.function(f.name, result_str, tool_call_id=tool_call_id) # if we are auto truncating, check and see if we need to if f.auto_truncate is not None: message_len = self.message_token_len(msg) @@ -362,7 +364,7 @@ async def do_function_call(self, call: FunctionCall) -> FunctionCallResult: return FunctionCallResult(is_model_turn=f.after == ChatRole.ASSISTANT, message=msg) async def handle_function_call_exception( - self, call: FunctionCall, err: FunctionCallException, attempt: int + self, call: FunctionCall, err: FunctionCallException, attempt: int, tool_call_id: str = None ) -> ExceptionHandleResult: """Called when a function call raises an exception. @@ -376,6 +378,7 @@ async def handle_function_call_exception( :param err: The error the call raised. Usually this is :class:`.NoSuchFunction` or :class:`.WrappedCallException`, although it may be any exception raised by :meth:`do_function_call`. :param attempt: The attempt number for the current call (0-indexed). + :param tool_call_id: The ``tool_call_id`` to set in the returned FUNCTION message. :returns: A :class:`.ExceptionHandleResult` detailing whether the model should retry and the message to add to the chat history. """ @@ -383,11 +386,15 @@ async def handle_function_call_exception( log.debug(f"Call to {call.name} raised an exception: {err}") # tell the model what went wrong if isinstance(err, NoSuchFunction): - msg = ChatMessage.system(f"The function {err.name!r} is not defined. Only use the provided functions.") + msg = ChatMessage.function( + name=None, + content=f"The function {err.name!r} is not defined. Only use the provided functions.", + tool_call_id=tool_call_id, + ) else: # but if it's a user function error, we want to raise it log.error(f"Call to {call.name} raised an exception: {err}", exc_info=err) - msg = ChatMessage.function(call.name, str(err)) + msg = ChatMessage.function(call.name, str(err), tool_call_id=tool_call_id) return ExceptionHandleResult(should_retry=attempt < self.retry_attempts and err.retry, message=msg) diff --git a/kani/models.py b/kani/models.py index ccbb77e..c8be955 100644 --- a/kani/models.py +++ b/kani/models.py @@ -269,7 +269,7 @@ def assistant(cls, content: str | Sequence[MessagePart | str] | None, **kwargs): return cls(role=ChatRole.ASSISTANT, content=content, **kwargs) @classmethod - def function(cls, name: str, content: str | Sequence[MessagePart | str], tool_call_id: str = None, **kwargs): + def function(cls, name: str | None, content: str | Sequence[MessagePart | str], tool_call_id: str = None, **kwargs): """Create a new function message.""" return cls(role=ChatRole.FUNCTION, content=content, name=name, tool_call_id=tool_call_id, **kwargs) From b9a035f1a7b6ba63f5a0f8db2ecf98f48882da77 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Tue, 7 Nov 2023 15:55:36 -0500 Subject: [PATCH 3/9] docs: kwargs when overriding --- docs/customization/chat_history.rst | 4 ++-- docs/customization/function_exception.rst | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/customization/chat_history.rst b/docs/customization/chat_history.rst index d5ac244..25ba2aa 100644 --- a/docs/customization/chat_history.rst +++ b/docs/customization/chat_history.rst @@ -31,8 +31,8 @@ For example, here's how you might extend :meth:`.Kani.add_to_history` to log eve super().__init__(*args, **kwargs) self.log_file = open("kani-log.jsonl", "w") - async def add_to_history(self, message): - await super().add_to_history(message) + async def add_to_history(self, message, *args, **kwargs): + await super().add_to_history(message, *args, **kwargs) self.log_file.write(message.model_dump_json()) self.log_file.write("\n") diff --git a/docs/customization/function_exception.rst b/docs/customization/function_exception.rst index d8f3099..9536063 100644 --- a/docs/customization/function_exception.rst +++ b/docs/customization/function_exception.rst @@ -41,9 +41,9 @@ Here's an example of providing custom prompts on an exception: :emphasize-lines: 2-10 class CustomExceptionPromptKani(Kani): - async def handle_function_call_exception(self, call, err, attempt): + async def handle_function_call_exception(self, call, err, attempt, *args, **kwargs): # get the standard retry logic... - result = await super().handle_function_call_exception(call, err, attempt) + result = await super().handle_function_call_exception(call, err, attempt, *args, **kwargs) # but override the returned message with our own result.message = ChatMessage.system( "The call encountered an error. " From ee560103f276432d6b09424808b03824e9a3cc05 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Tue, 7 Nov 2023 17:00:47 -0500 Subject: [PATCH 4/9] refactor: handle parallel toolcalls --- docs/function_calling.rst | 8 +++++++ kani/internal.py | 10 +++++++-- kani/kani.py | 46 ++++++++++++++++++++++++--------------- 3 files changed, 44 insertions(+), 20 deletions(-) diff --git a/docs/function_calling.rst b/docs/function_calling.rst index a55f748..bf881c7 100644 --- a/docs/function_calling.rst +++ b/docs/function_calling.rst @@ -137,6 +137,10 @@ Next Actor After a function call returns, kani will hand control back to the LM to generate a response by default. If instead control should be given to the human (i.e. return from the chat round), set ``after=ChatRole.USER``. +.. note:: + If the model calls multiple tools in parallel, the model will be allowed to generate a response if *any* function + allows it. + Complete Example ---------------- Here's the full example of how you might implement a function to get weather that we built in the last few steps: @@ -293,6 +297,10 @@ passing params with invalid, non-coercible types) or the function raises an exce error in a message to the model by default, allowing it up to *retry_attempts* to correct itself and retry the call. +.. note:: + If the model calls multiple tools in parallel, the model will be allowed a retry if *any* exception handler + allows it. + In the next section, we'll discuss how to customize this behaviour, along with other parts of the kani interface. .. _functioncall_v_toolcall: diff --git a/kani/internal.py b/kani/internal.py index 79500fe..6d2faf6 100644 --- a/kani/internal.py +++ b/kani/internal.py @@ -1,7 +1,13 @@ +import abc + from .models import ChatMessage -class FunctionCallResult: +class HasMessage(abc.ABC): + message: ChatMessage + + +class FunctionCallResult(HasMessage): """A model requested a function call, and the kani runtime resolved it.""" def __init__(self, is_model_turn: bool, message: ChatMessage): @@ -13,7 +19,7 @@ def __init__(self, is_model_turn: bool, message: ChatMessage): self.message = message -class ExceptionHandleResult: +class ExceptionHandleResult(HasMessage): """A function call raised an exception, and the kani runtime has prompted the model with exception information.""" def __init__(self, should_retry: bool, message: ChatMessage): diff --git a/kani/kani.py b/kani/kani.py index 6bf3965..4899194 100644 --- a/kani/kani.py +++ b/kani/kani.py @@ -9,7 +9,7 @@ from .engines.base import BaseCompletion, BaseEngine from .exceptions import FunctionCallException, MessageTooLong, NoSuchFunction, WrappedCallException from .internal import ExceptionHandleResult, FunctionCallResult -from .models import ChatMessage, ChatRole, FunctionCall, QueryType +from .models import ChatMessage, ChatRole, FunctionCall, QueryType, ToolCall from .utils.message_formatters import assistant_message_contents from .utils.typing import PathLike, SavedKani @@ -181,27 +181,37 @@ async def full_round(self, query: QueryType, **kwargs) -> AsyncIterable[ChatMess if not message.tool_calls: return - try: - function_call_result = await self.do_function_call(message.function_call) - is_model_turn = function_call_result.is_model_turn - call_message = function_call_result.message + # run each tool call in parallel + async def _do_tool_call(tc: ToolCall): + try: + return await self.do_function_call(tc.function, tool_call_id=tc.id) + except FunctionCallException as e: + return await self.handle_function_call_exception(tc.function, e, retry, tool_call_id=tc.id) + + # and update results after they are completed + is_model_turn = False + should_retry_call = False + n_errs = 0 + results = await asyncio.gather(*(_do_tool_call(tc) for tc in message.tool_calls)) + for result in results: # save the result to the chat history - await self.add_to_history(call_message) - yield call_message - except FunctionCallException as e: - exception_handling_result = await self.handle_function_call_exception( - message.function_call, e, retry - ) - # save the result to the chat history - exc_message = exception_handling_result.message - await self.add_to_history(exc_message) - yield exc_message - # retry if we have retry attempts left + await self.add_to_history(result.message) + yield result.message + if isinstance(result, ExceptionHandleResult): + is_model_turn = True + n_errs += 1 + # retry if any function says so + should_retry_call = should_retry_call or result.should_retry + else: + # allow model to generate response if any function says so + is_model_turn = is_model_turn or result.is_model_turn + + # if we encountered an error, increment the retry counter and allow the model to generate a response + if n_errs: retry += 1 - if not exception_handling_result.should_retry: + if not should_retry_call: # disable function calling on the next go kwargs = {**kwargs, "include_functions": False} - continue else: retry = 0 From 7b69f0db5c2a64d4af9a04fbbb0b47497179c628 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Tue, 7 Nov 2023 17:04:28 -0500 Subject: [PATCH 5/9] chore: remove unused deprecation wrapper --- kani/utils/deprecation.py | 73 --------------------------------------- 1 file changed, 73 deletions(-) delete mode 100644 kani/utils/deprecation.py diff --git a/kani/utils/deprecation.py b/kani/utils/deprecation.py deleted file mode 100644 index 991ff28..0000000 --- a/kani/utils/deprecation.py +++ /dev/null @@ -1,73 +0,0 @@ -import functools -import inspect -import warnings - -string_types = (type(b""), type("")) - - -# deprecated wrapper from https://stackoverflow.com/a/40301488 -def deprecated(reason): - """ - This is a decorator which can be used to mark functions - as deprecated. It will result in a warning being emitted - when the function is used. - """ - - if isinstance(reason, string_types): - - # The @deprecated is used with a 'reason'. - # - # .. code-block:: python - # - # @deprecated("please, use another function") - # def old_function(x, y): - # pass - - def decorator(func1): - - if inspect.isclass(func1): - fmt1 = "Call to deprecated class {name} ({reason})." - else: - fmt1 = "Call to deprecated function {name} ({reason})." - - @functools.wraps(func1) - def new_func1(*args, **kwargs): - warnings.simplefilter("always", DeprecationWarning) - warnings.warn( - fmt1.format(name=func1.__name__, reason=reason), category=DeprecationWarning, stacklevel=2 - ) - warnings.simplefilter("default", DeprecationWarning) - return func1(*args, **kwargs) - - return new_func1 - - return decorator - - elif inspect.isclass(reason) or inspect.isfunction(reason): - - # The @deprecated is used without any 'reason'. - # - # .. code-block:: python - # - # @deprecated - # def old_function(x, y): - # pass - - func2 = reason - - if inspect.isclass(func2): - fmt2 = "Call to deprecated class {name}." - else: - fmt2 = "Call to deprecated function {name}." - - @functools.wraps(func2) - def new_func2(*args, **kwargs): - warnings.simplefilter("always", DeprecationWarning) - warnings.warn(fmt2.format(name=func2.__name__), category=DeprecationWarning, stacklevel=2) - warnings.simplefilter("default", DeprecationWarning) - return func2(*args, **kwargs) - - return new_func2 - - else: - raise TypeError(repr(type(reason))) From efbdfe5b00509ce39c768113c2c09a5c5dd6ad5b Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Tue, 7 Nov 2023 17:06:00 -0500 Subject: [PATCH 6/9] chore: kani.ToolCall --- kani/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kani/__init__.py b/kani/__init__.py index 20b748f..ebfbcb9 100644 --- a/kani/__init__.py +++ b/kani/__init__.py @@ -2,7 +2,7 @@ from .ai_function import AIFunction, AIParam, ai_function from .internal import ExceptionHandleResult, FunctionCallResult from .kani import Kani -from .models import ChatMessage, ChatRole, FunctionCall, MessagePart +from .models import ChatMessage, ChatRole, FunctionCall, MessagePart, ToolCall from .utils.cli import chat_in_terminal, chat_in_terminal_async # declare that kani is also a namespace package From 5dce06bc7a5d9393024ff25394028860854d2e58 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Tue, 7 Nov 2023 17:17:04 -0500 Subject: [PATCH 7/9] docs: fix doc on ToolCall.id --- kani/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/kani/models.py b/kani/models.py index c8be955..170c6d9 100644 --- a/kani/models.py +++ b/kani/models.py @@ -79,8 +79,8 @@ class ToolCall(BaseModel): id: str """The request ID created by the engine. - This should be passed back to the engine in :attr:`.ChatMessage.tool_call_id` in order to associate a TOOL message - with this request. + This should be passed back to the engine in :attr:`.ChatMessage.tool_call_id` in order to associate a FUNCTION + message with this request. """ type: str From 6a27a3ada104537bef71f23384c35f2840443b60 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Wed, 8 Nov 2023 11:32:37 -0500 Subject: [PATCH 8/9] docs: note on toolcalls in implementing engines --- docs/engines/implementing.rst | 9 ++++++++- docs/function_calling.rst | 3 +++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/docs/engines/implementing.rst b/docs/engines/implementing.rst index 470bcdd..d6a8670 100644 --- a/docs/engines/implementing.rst +++ b/docs/engines/implementing.rst @@ -48,6 +48,10 @@ You'll need to implement two methods: :meth:`.BaseEngine.predict` and :meth:`.Ba build such a prompt. :meth:`.BaseEngine.function_token_reserve` tells kani how many tokens that prompt takes, so the context window management can ensure it never sends too many tokens. +You'll also need to add previous function calls into the prompt (e.g. in the few-shot function calling example). +When you're building the prompt, you'll need to iterate over :attr:`.ChatMessage.tool_calls` if it exists, and add +your model's appropriate function calling prompt. + To parse the model's requests to call a function, you also do this in :meth:`.BaseEngine.predict`. After generating the model's completion (usually a string, or a list of token IDs that decodes into a string), separate the model's conversational content from the structured function call: @@ -56,4 +60,7 @@ conversational content from the structured function call: :align: center Finally, return a :class:`.Completion` with the ``.message`` attribute set to a :class:`.ChatMessage` with the -appropriate :attr:`.ChatMessage.content` and :attr:`.ChatMessage.function_call`. +appropriate :attr:`.ChatMessage.content` and :attr:`.ChatMessage.tool_calls`. + +.. note:: + See :ref:`functioncall_v_toolcall` for more information about ToolCalls vs FunctionCalls. diff --git a/docs/function_calling.rst b/docs/function_calling.rst index bf881c7..6cc1239 100644 --- a/docs/function_calling.rst +++ b/docs/function_calling.rst @@ -317,3 +317,6 @@ following internal representation: :attr:`.ChatMessage.function_call` is actually an alias for ``ChatMessage.tool_calls[0].function``. If there is more than one tool call in the message, kani will raise an exception. + +A ToolCall is effectively a named wrapper around a :class:`.FunctionCall`, associating the request with a generated +ID so that its response can be linked to the request in future rounds of prompting. From 6afe6483f513d2349b4730854cb633b93b00e8a3 Mon Sep 17 00:00:00 2001 From: Andrew Zhu Date: Wed, 8 Nov 2023 12:48:12 -0500 Subject: [PATCH 9/9] chore: PR comments --- docs/function_calling.rst | 4 ++-- examples/2_function_calling_fewshot.py | 3 +-- kani/kani.py | 4 ++-- kani/models.py | 4 ++-- tests/test_chatmessage.py | 3 +-- 5 files changed, 8 insertions(+), 10 deletions(-) diff --git a/docs/function_calling.rst b/docs/function_calling.rst index 6cc1239..94404f3 100644 --- a/docs/function_calling.rst +++ b/docs/function_calling.rst @@ -139,7 +139,7 @@ control should be given to the human (i.e. return from the chat round), set ``af .. note:: If the model calls multiple tools in parallel, the model will be allowed to generate a response if *any* function - allows it. + has ``after=ChatRole.ASSISTANT`` (the default) once all function calls are complete. Complete Example ---------------- @@ -299,7 +299,7 @@ call. .. note:: If the model calls multiple tools in parallel, the model will be allowed a retry if *any* exception handler - allows it. + allows it. This will only count as 1 retry attempt regardless of the number of functions that raised an exception. In the next section, we'll discuss how to customize this behaviour, along with other parts of the kani interface. diff --git a/examples/2_function_calling_fewshot.py b/examples/2_function_calling_fewshot.py index 672464a..981b00f 100644 --- a/examples/2_function_calling_fewshot.py +++ b/examples/2_function_calling_fewshot.py @@ -2,9 +2,8 @@ import os from typing import Annotated -from kani import AIParam, ChatMessage, Kani, ai_function, chat_in_terminal +from kani import AIParam, ChatMessage, Kani, ToolCall, ai_function, chat_in_terminal from kani.engines.openai import OpenAIEngine -from kani.models import ToolCall api_key = os.getenv("OPENAI_API_KEY") engine = OpenAIEngine(api_key, model="gpt-3.5-turbo") diff --git a/kani/kani.py b/kani/kani.py index 4899194..b6fe03a 100644 --- a/kani/kani.py +++ b/kani/kani.py @@ -80,7 +80,7 @@ def __init__( Use ``chat_history=mykani.chat_history.copy()`` to pass a copy. :param functions: A list of :class:`.AIFunction` to expose to the model (for dynamic function calling). Use :func:`.ai_function` to define static functions (see :doc:`function_calling`). - :param retry_attempts: How many attempts the LM may take if a function call raises an exception. + :param retry_attempts: How many attempts the LM may take per full round if any tool call raises an exception. """ self.engine = engine self.system_prompt = system_prompt.strip() if system_prompt else None @@ -211,7 +211,7 @@ async def _do_tool_call(tc: ToolCall): retry += 1 if not should_retry_call: # disable function calling on the next go - kwargs = {**kwargs, "include_functions": False} + kwargs["include_functions"] = False else: retry = 0 diff --git a/kani/models.py b/kani/models.py index 170c6d9..75e6b99 100644 --- a/kani/models.py +++ b/kani/models.py @@ -9,7 +9,7 @@ from pydantic import BaseModel as PydanticBase, ConfigDict, model_serializer, model_validator -from .exceptions import MissingMessagePartType, ToolCallError +from .exceptions import MissingMessagePartType # ==== constants ==== MESSAGEPART_TYPE_KEY = "__kani_messagepart_type__" # used for serdes of MessageParts @@ -247,7 +247,7 @@ def function_call(self) -> FunctionCall | None: if not self.tool_calls: return None if len(self.tool_calls) > 1: - raise ToolCallError( + warnings.warn( "This message contains multiple tool calls; iterate over `.tool_calls` instead of using" " `.function_call`." ) diff --git a/tests/test_chatmessage.py b/tests/test_chatmessage.py index 0f79223..01650db 100644 --- a/tests/test_chatmessage.py +++ b/tests/test_chatmessage.py @@ -1,8 +1,7 @@ import pydantic import pytest -from kani import ChatMessage, ChatRole, MessagePart -from kani.models import ToolCall +from kani import ChatMessage, ChatRole, MessagePart, ToolCall class TestMessagePart(MessagePart):