diff --git a/kani/engines/openai/client.py b/kani/engines/openai/client.py index 8b5d8ee..cfc03b7 100644 --- a/kani/engines/openai/client.py +++ b/kani/engines/openai/client.py @@ -167,7 +167,7 @@ async def create_chat_completion( "/chat/completions", json={ "model": model, - "messages": [cm.model_dump(exclude_defaults=True, mode="json") for cm in messages], + "messages": [cm.model_dump(exclude_none=True, mode="json") for cm in messages], **kwargs, }, ) diff --git a/kani/engines/openai/engine.py b/kani/engines/openai/engine.py index 933604e..fca0074 100644 --- a/kani/engines/openai/engine.py +++ b/kani/engines/openai/engine.py @@ -110,17 +110,18 @@ def message_len(self, message: ChatMessage) -> int: mlen += len(self.tokenizer.encode(message.function_call.arguments)) return mlen - async def predict( - self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams - ) -> ChatCompletion: - if functions: - tool_specs = [ - ToolSpec.from_function(FunctionSpec(name=f.name, description=f.desc, parameters=f.json_schema)) - for f in functions - ] - else: - tool_specs = None - # translate to openai spec - group any tool messages together and ensure all free ToolCall IDs are bound + # translation helpers + @staticmethod + def translate_functions(functions: list[AIFunction], cls: type[ToolSpec] = ToolSpec) -> list[ToolSpec]: + return [ + cls.from_function(FunctionSpec(name=f.name, description=f.desc, parameters=f.json_schema)) + for f in functions + ] + + @staticmethod + def translate_messages( + messages: list[ChatMessage], cls: type[OpenAIChatMessage] = OpenAIChatMessage + ) -> list[OpenAIChatMessage]: translated_messages = [] free_toolcall_ids = set() for m in messages: @@ -148,11 +149,22 @@ async def predict( "Got a FUNCTION message with no tool_call_id but multiple tool calls are pending" f" ({free_toolcall_ids})! Set the tool_call_id to resolve the pending tool requests." ) - translated_messages.append(OpenAIChatMessage.from_chatmessage(m)) + translated_messages.append(cls.from_chatmessage(m)) # if the translated messages start with a hanging TOOL call, strip it (openai limitation) # though hanging FUNCTION messages are OK while translated_messages and translated_messages[0].role == "tool": translated_messages.pop(0) + return translated_messages + + async def predict( + self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams + ) -> ChatCompletion: + if functions: + tool_specs = self.translate_functions(functions) + else: + tool_specs = None + # translate to openai spec - group any tool messages together and ensure all free ToolCall IDs are bound + translated_messages = self.translate_messages(messages) # make API call completion = await self.client.create_chat_completion( model=self.model, messages=translated_messages, tools=tool_specs, **self.hyperparams, **hyperparams