From e95f0da74d1d7af20ce3941eb90bfc4d593c79a4 Mon Sep 17 00:00:00 2001
From: Timor Morrien
Date: Sat, 4 May 2024 14:29:38 +0200
Subject: [PATCH] `LLM`: Add support for JSON mode (#102)

---
 app/llm/completion_arguments.py             | 12 +++++++++++-
 app/llm/external/ollama.py                  |  2 ++
 app/llm/external/openai_chat.py             |  5 +++++
 app/pipeline/chat/file_selector_pipeline.py |  6 +++++-
 4 files changed, 23 insertions(+), 2 deletions(-)

diff --git a/app/llm/completion_arguments.py b/app/llm/completion_arguments.py
index a540e144..29beaa9f 100644
--- a/app/llm/completion_arguments.py
+++ b/app/llm/completion_arguments.py
@@ -1,9 +1,19 @@
+from enum import Enum
+
+CompletionArgumentsResponseFormat = Enum("CompletionArgumentsResponseFormat", ["TEXT", "JSON"])
+
+
 class CompletionArguments:
     """Arguments for the completion request"""
 
     def __init__(
-        self, max_tokens: int = None, temperature: float = None, stop: list[str] = None
+        self,
+        max_tokens: int = None,
+        temperature: float = None,
+        stop: list[str] = None,
+        response_format: CompletionArgumentsResponseFormat = "TEXT",
     ):
         self.max_tokens = max_tokens
         self.temperature = temperature
         self.stop = stop
+        self.response_format = response_format
diff --git a/app/llm/external/ollama.py b/app/llm/external/ollama.py
index d6946e31..81b02d32 100644
--- a/app/llm/external/ollama.py
+++ b/app/llm/external/ollama.py
@@ -93,6 +93,7 @@ def complete(
             model=self.model,
             prompt=prompt,
             images=[image.base64] if image else None,
+            format="json" if arguments.response_format == "JSON" else "",
             options=self.options,
         )
         return response["response"]
@@ -103,6 +104,7 @@ def chat(
         response = self._client.chat(
             model=self.model,
             messages=convert_to_ollama_messages(messages),
+            format="json" if arguments.response_format == "JSON" else "",
             options=self.options,
         )
         return convert_to_iris_message(response["message"])
diff --git a/app/llm/external/openai_chat.py b/app/llm/external/openai_chat.py
index 894b3b18..dde7d3f0 100644
--- a/app/llm/external/openai_chat.py
+++ b/app/llm/external/openai_chat.py
@@ -4,6 +4,7 @@
 from openai import OpenAI
 from openai.lib.azure import AzureOpenAI
 from openai.types.chat import ChatCompletionMessage, ChatCompletionMessageParam
+from openai.types.chat.completion_create_params import ResponseFormat
 
 from ...common.message_converters import map_str_to_role, map_role_to_str
 from app.domain.data.text_message_content_dto import TextMessageContentDTO
@@ -76,11 +77,15 @@ class OpenAIChatModel(ChatModel):
     def chat(
         self, messages: list[PyrisMessage], arguments: CompletionArguments
     ) -> PyrisMessage:
+        # noinspection PyTypeChecker
         response = self._client.chat.completions.create(
             model=self.model,
             messages=convert_to_open_ai_messages(messages),
             temperature=arguments.temperature,
             max_tokens=arguments.max_tokens,
+            response_format=ResponseFormat(
+                type=("json_object" if arguments.response_format == "JSON" else "text")
+            ),
         )
         return convert_to_iris_message(response.choices[0].message)
 
diff --git a/app/pipeline/chat/file_selector_pipeline.py b/app/pipeline/chat/file_selector_pipeline.py
index 129a241f..a50e2c69 100644
--- a/app/pipeline/chat/file_selector_pipeline.py
+++ b/app/pipeline/chat/file_selector_pipeline.py
@@ -46,9 +46,13 @@ def __init__(self, callback: Optional[StatusCallback] = None):
             requirements=RequirementList(
                 gpt_version_equivalent=3.5,
                 context_length=4096,
+                vendor="OpenAI",
+                json_mode=True,
             )
         )
-        completion_args = CompletionArguments(temperature=0, max_tokens=500)
+        completion_args = CompletionArguments(
+            temperature=0, max_tokens=500, response_format="JSON"
+        )
         self.llm = IrisLangchainChatModel(
             request_handler=request_handler, completion_args=completion_args
         )
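
Usage sketch: a minimal example of how the new `response_format` argument is
intended to be consumed, assuming only what the diff above shows (the import
path and argument names come from app/llm/completion_arguments.py); the
surrounding pipeline wiring is elided.

    from app.llm.completion_arguments import CompletionArguments

    # Request JSON mode. Downstream, the Ollama client passes format="json"
    # and the OpenAI chat client passes response_format={"type": "json_object"}.
    # Caveat: OpenAI's json_object mode additionally requires the word "JSON"
    # to appear somewhere in the messages, or the API rejects the request.
    args = CompletionArguments(
        temperature=0,
        max_tokens=500,
        response_format="JSON",  # the default, "TEXT", leaves output unconstrained
    )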