Skip to content

Commit

Permalink
refactor: split ChatMessage and OpenAIChatMessage
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed Oct 24, 2023
1 parent 4e5735d commit c9ff0a2
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
9 changes: 4 additions & 5 deletions kani/engines/openai/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import aiohttp
import pydantic

from kani.models import ChatMessage
from .models import ChatCompletion, Completion, FunctionSpec, SpecificFunctionCall
from .models import ChatCompletion, Completion, FunctionSpec, OpenAIChatMessage, SpecificFunctionCall
from ..httpclient import BaseClient, HTTPException, HTTPStatusException, HTTPTimeout


Expand Down Expand Up @@ -107,7 +106,7 @@ async def create_completion(self, model: str, **kwargs) -> Completion:
async def create_chat_completion(
self,
model: str,
messages: list[ChatMessage],
messages: list[OpenAIChatMessage],
functions: list[FunctionSpec] | None = None,
function_call: SpecificFunctionCall | Literal["auto"] | Literal["none"] | None = None,
temperature: float = 1.0,
Expand All @@ -124,7 +123,7 @@ async def create_chat_completion(
async def create_chat_completion(
self,
model: str,
messages: list[ChatMessage],
messages: list[OpenAIChatMessage],
functions: list[FunctionSpec] | None = None,
**kwargs,
) -> ChatCompletion:
Expand All @@ -142,7 +141,7 @@ async def create_chat_completion(
"/chat/completions",
json={
"model": model,
"messages": [cm.model_dump(exclude_unset=True, mode="json") for cm in messages],
"messages": [cm.model_dump(exclude_defaults=True, mode="json") for cm in messages],
**kwargs,
},
)
Expand Down
5 changes: 3 additions & 2 deletions kani/engines/openai/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from kani.models import ChatMessage
from . import function_calling
from .client import OpenAIClient
from .models import ChatCompletion, FunctionSpec
from .models import ChatCompletion, FunctionSpec, OpenAIChatMessage
from ..base import BaseEngine

try:
Expand Down Expand Up @@ -114,8 +114,9 @@ async def predict(
function_spec = [FunctionSpec(name=f.name, description=f.desc, parameters=f.json_schema) for f in functions]
else:
function_spec = None
translated_messages = [OpenAIChatMessage.from_chatmessage(m) for m in messages]
completion = await self.client.create_chat_completion(
model=self.model, messages=messages, functions=function_spec, **self.hyperparams, **hyperparams
model=self.model, messages=translated_messages, functions=function_spec, **self.hyperparams, **hyperparams
)
return completion

Expand Down
14 changes: 13 additions & 1 deletion kani/engines/openai/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Literal

from kani.models import BaseModel, ChatMessage
from kani.models import BaseModel, ChatMessage, ChatRole, FunctionCall
from ..base import BaseCompletion


Expand Down Expand Up @@ -50,8 +50,20 @@ class SpecificFunctionCall(BaseModel):
name: str


class OpenAIChatMessage(BaseModel):
    """A chat message in the wire format the OpenAI chat completions API expects.

    This is distinct from the engine-agnostic :class:`kani.models.ChatMessage`:
    instances of this model are what get serialized into the request body of
    ``/chat/completions`` by the OpenAI client.
    """

    # who authored the message (system/user/assistant/etc. — see ChatRole)
    role: ChatRole
    # the message payload; may also be a list of content parts, or None
    # (NOTE(review): None presumably occurs for assistant messages that only
    # carry a function_call — confirm against the OpenAI API contract)
    content: str | list[BaseModel | str] | None
    # optional author name (e.g. the function name for role=function messages —
    # TODO confirm)
    name: str | None = None
    # set when the assistant is requesting a function invocation
    function_call: FunctionCall | None = None

    @classmethod
    def from_chatmessage(cls, m: ChatMessage) -> "OpenAIChatMessage":
        """Translate a kani ChatMessage into the OpenAI wire format.

        Uses ``m.text`` for the content field, so any richer content on the
        kani message is flattened to its text form.
        """
        return cls(role=m.role, content=m.text, name=m.name, function_call=m.function_call)


# ---- response ----
class ChatCompletionChoice(BaseModel):
# this is a ChatMessage rather than an OpenAIChatMessage because all engines need to return the kani model
message: ChatMessage
index: int
finish_reason: str | None = None
Expand Down

0 comments on commit c9ff0a2

Please sign in to comment.