Skip to content

Commit

Permalink
LLM: Add support for JSON mode (#102)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hialus authored May 4, 2024
1 parent 7b01f18 commit e95f0da
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 2 deletions.
12 changes: 11 additions & 1 deletion app/llm/completion_arguments.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
from enum import Enum

# Allowed response formats for a completion request.
# BUG FIX: the original `Enum("TEXT", "JSON")` misused the functional API —
# the first argument is the *class name*, so it produced a class named "TEXT"
# with a single member JSON (and no TEXT member). The corrected form below
# yields members TEXT and JSON as intended.
# NOTE(review): callers in this commit pass plain strings ("TEXT"/"JSON")
# rather than enum members, and downstream code compares with == "JSON";
# the string default below is kept for backward compatibility.
CompletionArgumentsResponseFormat = Enum(
    "CompletionArgumentsResponseFormat", ["TEXT", "JSON"]
)


class CompletionArguments:
    """Arguments for the completion request.

    All arguments default to None / "TEXT", meaning "use the provider's
    default behavior".
    """

    def __init__(
        self,
        max_tokens: int = None,
        temperature: float = None,
        stop: list[str] = None,
        response_format: CompletionArgumentsResponseFormat = "TEXT",
    ):
        # Maximum number of tokens to generate (None = provider default).
        self.max_tokens = max_tokens
        # Sampling temperature (None = provider default).
        self.temperature = temperature
        # Stop sequences that end generation early (None = no stops).
        self.stop = stop
        # "TEXT" or "JSON"; "JSON" enables the provider's JSON mode.
        self.response_format = response_format
2 changes: 2 additions & 0 deletions app/llm/external/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def complete(
model=self.model,
prompt=prompt,
images=[image.base64] if image else None,
format="json" if arguments.response_format == "JSON" else "",
options=self.options,
)
return response["response"]
Expand All @@ -103,6 +104,7 @@ def chat(
response = self._client.chat(
model=self.model,
messages=convert_to_ollama_messages(messages),
format="json" if arguments.response_format == "JSON" else "",
options=self.options,
)
return convert_to_iris_message(response["message"])
Expand Down
5 changes: 5 additions & 0 deletions app/llm/external/openai_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from openai import OpenAI
from openai.lib.azure import AzureOpenAI
from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam
from openai.types.chat.completion_create_params import ResponseFormat

from ...common.message_converters import map_str_to_role, map_role_to_str
from app.domain.data.text_message_content_dto import TextMessageContentDTO
Expand Down Expand Up @@ -76,11 +77,15 @@ class OpenAIChatModel(ChatModel):
def chat(
self, messages: list[PyrisMessage], arguments: CompletionArguments
) -> PyrisMessage:
# noinspection PyTypeChecker
response = self._client.chat.completions.create(
model=self.model,
messages=convert_to_open_ai_messages(messages),
temperature=arguments.temperature,
max_tokens=arguments.max_tokens,
response_format=ResponseFormat(
type=("json_object" if arguments.response_format == "JSON" else "text")
),
)
return convert_to_iris_message(response.choices[0].message)

Expand Down
6 changes: 5 additions & 1 deletion app/pipeline/chat/file_selector_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,13 @@ def __init__(self, callback: Optional[StatusCallback] = None):
requirements=RequirementList(
gpt_version_equivalent=3.5,
context_length=4096,
vendor="OpenAI",
json_mode=True,
)
)
completion_args = CompletionArguments(temperature=0, max_tokens=500)
completion_args = CompletionArguments(
temperature=0, max_tokens=500, response_format="JSON"
)
self.llm = IrisLangchainChatModel(
request_handler=request_handler, completion_args=completion_args
)
Expand Down

0 comments on commit e95f0da

Please sign in to comment.