chore(openai): break out translation to methods for easier extension
zhudotexe committed Nov 8, 2023
1 parent c016148 commit 8e15a4b
Showing 2 changed files with 25 additions and 13 deletions.
2 changes: 1 addition & 1 deletion kani/engines/openai/client.py
@@ -167,7 +167,7 @@ async def create_chat_completion(
             "/chat/completions",
             json={
                 "model": model,
-                "messages": [cm.model_dump(exclude_defaults=True, mode="json") for cm in messages],
+                "messages": [cm.model_dump(exclude_none=True, mode="json") for cm in messages],
                 **kwargs,
             },
         )
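The client-side change swaps exclude_defaults=True for exclude_none=True when serializing messages into the request body. In pydantic v2, exclude_defaults drops any field whose value equals its declared default (even if it was set explicitly), while exclude_none only drops fields whose value is None, so explicitly set fields that happen to match their defaults now stay in the JSON payload. A minimal sketch of the difference, using a hypothetical Msg model rather than kani's actual ChatMessage:

from pydantic import BaseModel

class Msg(BaseModel):
    role: str = "user"          # non-None default
    content: str | None = None
    name: str | None = None

m = Msg(role="user", content="hi")

# exclude_defaults drops any field equal to its default, even though role was
# set explicitly -- only content survives.
print(m.model_dump(exclude_defaults=True, mode="json"))  # {'content': 'hi'}

# exclude_none only drops None-valued fields (name), so the explicit role is kept.
print(m.model_dump(exclude_none=True, mode="json"))  # {'role': 'user', 'content': 'hi'}
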
36 changes: 24 additions & 12 deletions kani/engines/openai/engine.py
@@ -110,17 +110,18 @@ def message_len(self, message: ChatMessage) -> int:
             mlen += len(self.tokenizer.encode(message.function_call.arguments))
         return mlen
 
-    async def predict(
-        self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
-    ) -> ChatCompletion:
-        if functions:
-            tool_specs = [
-                ToolSpec.from_function(FunctionSpec(name=f.name, description=f.desc, parameters=f.json_schema))
-                for f in functions
-            ]
-        else:
-            tool_specs = None
-        # translate to openai spec - group any tool messages together and ensure all free ToolCall IDs are bound
+    # translation helpers
+    @staticmethod
+    def translate_functions(functions: list[AIFunction], cls: type[ToolSpec] = ToolSpec) -> list[ToolSpec]:
+        return [
+            cls.from_function(FunctionSpec(name=f.name, description=f.desc, parameters=f.json_schema))
+            for f in functions
+        ]
+
+    @staticmethod
+    def translate_messages(
+        messages: list[ChatMessage], cls: type[OpenAIChatMessage] = OpenAIChatMessage
+    ) -> list[OpenAIChatMessage]:
         translated_messages = []
         free_toolcall_ids = set()
         for m in messages:
@@ -148,11 +149,22 @@ async def predict
                         "Got a FUNCTION message with no tool_call_id but multiple tool calls are pending"
                         f" ({free_toolcall_ids})! Set the tool_call_id to resolve the pending tool requests."
                     )
-            translated_messages.append(OpenAIChatMessage.from_chatmessage(m))
+            translated_messages.append(cls.from_chatmessage(m))
         # if the translated messages start with a hanging TOOL call, strip it (openai limitation)
         # though hanging FUNCTION messages are OK
         while translated_messages and translated_messages[0].role == "tool":
             translated_messages.pop(0)
+        return translated_messages
+
+    async def predict(
+        self, messages: list[ChatMessage], functions: list[AIFunction] | None = None, **hyperparams
+    ) -> ChatCompletion:
+        if functions:
+            tool_specs = self.translate_functions(functions)
+        else:
+            tool_specs = None
+        # translate to openai spec - group any tool messages together and ensure all free ToolCall IDs are bound
+        translated_messages = self.translate_messages(messages)
         # make API call
         completion = await self.client.create_chat_completion(
             model=self.model, messages=translated_messages, tools=tool_specs, **self.hyperparams, **hyperparams
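The point of the refactor is the new extension seam: translate_functions and translate_messages are now static methods with a pluggable cls parameter, so a subclass can reuse the ToolCall grouping and binding logic while emitting its own spec or message types, and the inherited predict picks the override up automatically. A hypothetical sketch of that pattern (the MyChatMessage/MyEngine names are invented, and the import paths are assumptions about kani's module layout, not part of this commit):

from kani.models import ChatMessage
from kani.engines.openai.engine import OpenAIEngine
from kani.engines.openai.models import OpenAIChatMessage


class MyChatMessage(OpenAIChatMessage):
    """Hypothetical message model with provider-specific tweaks."""


class MyEngine(OpenAIEngine):
    @staticmethod
    def translate_messages(
        messages: list[ChatMessage], cls: type[OpenAIChatMessage] = MyChatMessage
    ) -> list[OpenAIChatMessage]:
        # Reuse the base class's grouping/ToolCall-binding logic, but build
        # MyChatMessage instances; the inherited predict() calls
        # self.translate_messages(messages), so it uses this override.
        return OpenAIEngine.translate_messages(messages, cls=cls)

translate_functions works the same way: a subclass can pass its own ToolSpec-compatible class through the cls argument without re-deriving the FunctionSpec construction.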
