From f10ad872cc95e79a0a3b56a342d49abf18ca19d3 Mon Sep 17 00:00:00 2001 From: Southpika <122620817+Southpika@users.noreply.github.com> Date: Tue, 23 Jan 2024 18:59:53 +0800 Subject: [PATCH] [Enhancement][AIStudio]Add json response (#308) * add json res * add warning * add warning * reformat --- .../erniebot_agent/chat_models/erniebot.py | 31 ++++++++++++++++++- .../src/erniebot/resources/chat_completion.py | 12 +++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/erniebot-agent/src/erniebot_agent/chat_models/erniebot.py b/erniebot-agent/src/erniebot_agent/chat_models/erniebot.py index beda26b0e..6a1b5c04c 100644 --- a/erniebot-agent/src/erniebot_agent/chat_models/erniebot.py +++ b/erniebot-agent/src/erniebot_agent/chat_models/erniebot.py @@ -13,6 +13,7 @@ # limitations under the License. import json +import logging from typing import ( Any, AsyncIterator, @@ -44,6 +45,9 @@ _T = TypeVar("_T", AIMessage, AIMessageChunk) +_logger = logging.getLogger(__name__) + + class BaseERNIEBot(ChatModel): @overload async def chat( @@ -215,7 +219,15 @@ def _generate_config(self, messages: List[Message], functions, **kwargs) -> dict if functions is not None: cfg_dict["functions"] = functions - name_list = ["top_p", "temperature", "penalty_score", "system", "plugins", "tool_choice"] + name_list = [ + "top_p", + "temperature", + "penalty_score", + "system", + "plugins", + "tool_choice", + "response_format", + ] for name in name_list: if name in kwargs: cfg_dict[name] = kwargs[name] @@ -227,6 +239,23 @@ def _generate_config(self, messages: List[Message], functions, **kwargs) -> dict # rm blank dict if not cfg_dict["tool_choice"]: cfg_dict.pop("tool_choice") + + if "response_format" in cfg_dict: + if cfg_dict["response_format"] not in ("json_object", "text"): + if "json" in cfg_dict["response_format"]: + cfg_dict["response_format"] = "json_object" + _logger.warning( + f"`response_format` has invalid value:`{cfg_dict['response_format']}`, " + "use `json_object` instead. " + ) + else: + # It will not raise error in request + _logger.warning( + f"`response_format` has invalid value:`{cfg_dict['response_format']}`, " + "use default value: `text`. " + "You can only choose `json_object` or `text`. " + ) + return cfg_dict def _maybe_validate_qianfan_auth(self) -> None: diff --git a/erniebot/src/erniebot/resources/chat_completion.py b/erniebot/src/erniebot/resources/chat_completion.py index ee0982836..8890e7f7e 100644 --- a/erniebot/src/erniebot/resources/chat_completion.py +++ b/erniebot/src/erniebot/resources/chat_completion.py @@ -115,6 +115,7 @@ def create( extra_params: Optional[dict] = ..., headers: Optional[HeadersType] = ..., request_timeout: Optional[float] = ..., + response_format: Optional[Literal["json_object", "text"]] = ..., _config_: Optional[ConfigDictType] = ..., ) -> "ChatCompletionResponse": ... @@ -141,6 +142,7 @@ def create( extra_params: Optional[dict] = ..., headers: Optional[HeadersType] = ..., request_timeout: Optional[float] = ..., + response_format: Optional[Literal["json_object", "text"]] = ..., _config_: Optional[ConfigDictType] = ..., ) -> Iterator["ChatCompletionResponse"]: ... @@ -167,6 +169,7 @@ def create( extra_params: Optional[dict] = ..., headers: Optional[HeadersType] = ..., request_timeout: Optional[float] = ..., + response_format: Optional[Literal["json_object", "text"]] = ..., _config_: Optional[ConfigDictType] = ..., ) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]: ... @@ -192,6 +195,7 @@ def create( extra_params: Optional[dict] = None, headers: Optional[HeadersType] = None, request_timeout: Optional[float] = None, + response_format: Optional[Literal["json_object", "text"]] = None, _config_: Optional[ConfigDictType] = None, ) -> Union["ChatCompletionResponse", Iterator["ChatCompletionResponse"]]: """Creates a model response for the given conversation. @@ -238,6 +242,7 @@ def create( user_id=user_id, tool_choice=tool_choice, stream=stream, + response_format=response_format, ) kwargs["validate_functions"] = validate_functions if extra_params is not None: @@ -271,6 +276,7 @@ async def acreate( extra_params: Optional[dict] = ..., headers: Optional[HeadersType] = ..., request_timeout: Optional[float] = ..., + response_format: Optional[Literal["json_object", "text"]] = None, _config_: Optional[ConfigDictType] = ..., ) -> EBResponse: ... @@ -297,6 +303,7 @@ async def acreate( extra_params: Optional[dict] = ..., headers: Optional[HeadersType] = ..., request_timeout: Optional[float] = ..., + response_format: Optional[Literal["json_object", "text"]] = None, _config_: Optional[ConfigDictType] = ..., ) -> AsyncIterator["ChatCompletionResponse"]: ... @@ -323,6 +330,7 @@ async def acreate( extra_params: Optional[dict] = ..., headers: Optional[HeadersType] = ..., request_timeout: Optional[float] = ..., + response_format: Optional[Literal["json_object", "text"]] = None, _config_: Optional[ConfigDictType] = ..., ) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]: ... @@ -348,6 +356,7 @@ async def acreate( extra_params: Optional[dict] = None, headers: Optional[HeadersType] = None, request_timeout: Optional[float] = None, + response_format: Optional[Literal["json_object", "text"]] = None, _config_: Optional[ConfigDictType] = None, ) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]: """Creates a model response for the given conversation. @@ -394,6 +403,7 @@ async def acreate( user_id=user_id, tool_choice=tool_choice, stream=stream, + response_format=response_format, ) kwargs["validate_functions"] = validate_functions if extra_params is not None: @@ -438,6 +448,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None: "extra_params", "headers", "request_timeout", + "response_format", } invalid_keys = kwargs.keys() - valid_keys @@ -500,6 +511,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None: _set_val_if_key_exists(kwargs, params, "user_id") _set_val_if_key_exists(kwargs, params, "tool_choice") _set_val_if_key_exists(kwargs, params, "stream") + _set_val_if_key_exists(kwargs, params, "response_format") if "extra_params" in kwargs: params.update(kwargs["extra_params"])