From aa1894f4c396a1cbec438206638f1dbe6a1b4c22 Mon Sep 17 00:00:00 2001 From: Bobholamovic Date: Thu, 15 Feb 2024 17:53:22 +0800 Subject: [PATCH] Fix bugs --- .../extensions/langchain/chat_models/erniebot.py | 8 ++++++-- .../extensions/langchain/embeddings/ernie.py | 3 ++- .../extensions/langchain/llms/erniebot.py | 14 +++++++------- .../src/erniebot/resources/chat_completion.py | 16 ++++++++-------- 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/erniebot-agent/src/erniebot_agent/extensions/langchain/chat_models/erniebot.py b/erniebot-agent/src/erniebot_agent/extensions/langchain/chat_models/erniebot.py index f1104c5ad..e8252fb31 100644 --- a/erniebot-agent/src/erniebot_agent/extensions/langchain/chat_models/erniebot.py +++ b/erniebot-agent/src/erniebot_agent/extensions/langchain/chat_models/erniebot.py @@ -28,7 +28,8 @@ class ErnieBotChat(BaseChatModel): """ERNIE Bot Chat large language models API. To use, you should have the ``erniebot`` python package installed, and the - environment variable ``AISTUDIO_ACCESS_TOKEN`` set with your AI Studio access token. + environment variable ``AISTUDIO_ACCESS_TOKEN`` set with your AI Studio + access token. Example: .. code-block:: python @@ -133,6 +134,8 @@ def _generate( system_prompt = self._build_system_prompt_from_messages(messages) if system_prompt is not None: params["system"] = system_prompt + if stop is not None: + params["stop"] = stop params["stream"] = False response = self.client.create(**params) return self._build_chat_result_from_response(response) @@ -161,6 +164,8 @@ async def _agenerate( system_prompt = self._build_system_prompt_from_messages(messages) if system_prompt is not None: params["system"] = system_prompt + if stop is not None: + params["stop"] = stop params["stream"] = False response = await self.client.acreate(**params) return self._build_chat_result_from_response(response) @@ -213,7 +218,6 @@ def _build_chat_result_from_response(self, response: Mapping[str, Any]) -> ChatR message_dict = self._build_dict_from_response(response) generation = ChatGeneration( message=self._convert_dict_to_message(message_dict), - generation_info=dict(finish_reason="stop"), ) token_usage = response.get("usage", {}) llm_output = {"token_usage": token_usage, "model_name": self.model} diff --git a/erniebot-agent/src/erniebot_agent/extensions/langchain/embeddings/ernie.py b/erniebot-agent/src/erniebot_agent/extensions/langchain/embeddings/ernie.py index 3d5a662d3..13e61d45f 100644 --- a/erniebot-agent/src/erniebot_agent/extensions/langchain/embeddings/ernie.py +++ b/erniebot-agent/src/erniebot_agent/extensions/langchain/embeddings/ernie.py @@ -11,7 +11,8 @@ class ErnieEmbeddings(BaseModel, Embeddings): """ERNIE embedding models. To use, you should have the ``erniebot`` python package installed, and the - environment variable ``AISTUDIO_ACCESS_TOKEN`` set with your AI Studio access token. + environment variable ``AISTUDIO_ACCESS_TOKEN`` set with your AI Studio + access token. Example: .. code-block:: python diff --git a/erniebot-agent/src/erniebot_agent/extensions/langchain/llms/erniebot.py b/erniebot-agent/src/erniebot_agent/extensions/langchain/llms/erniebot.py index a17827f62..1a3c45ca9 100644 --- a/erniebot-agent/src/erniebot_agent/extensions/langchain/llms/erniebot.py +++ b/erniebot-agent/src/erniebot_agent/extensions/langchain/llms/erniebot.py @@ -7,7 +7,6 @@ CallbackManagerForLLMRun, ) from langchain.llms.base import LLM -from langchain.llms.utils import enforce_stop_tokens from langchain.pydantic_v1 import Field, root_validator from langchain.schema.output import GenerationChunk from langchain.utils import get_from_dict_or_env @@ -17,7 +16,8 @@ class ErnieBot(LLM): """ERNIE Bot large language models. To use, you should have the ``erniebot`` python package installed, and the - environment variable ``AISTUDIO_ACCESS_TOKEN`` set with your AI Studio access token. + environment variable ``AISTUDIO_ACCESS_TOKEN`` set with your AI Studio + access token. Example: .. code-block:: python @@ -111,11 +111,11 @@ def _call( params = self._invocation_params params.update(kwargs) params["messages"] = [self._build_user_message_from_prompt(prompt)] + if stop is not None: + params["stop"] = stop params["stream"] = False response = self.client.create(**params) text = response["result"] - if stop is not None: - text = enforce_stop_tokens(text, stop) return text async def _acall( @@ -134,11 +134,11 @@ async def _acall( params = self._invocation_params params.update(kwargs) params["messages"] = [self._build_user_message_from_prompt(prompt)] + if stop is not None: + params["stop"] = stop params["stream"] = False response = await self.client.acreate(**params) text = response["result"] - if stop is not None: - text = enforce_stop_tokens(text, stop) return text def _stream( @@ -168,7 +168,7 @@ async def _astream( **kwargs: Any, ) -> AsyncIterator[GenerationChunk]: if stop is not None: - raise TypeError("Currently, `stop` is not supported when streaming is enabled.") + raise ValueError("Currently, `stop` is not supported when streaming is enabled.") params = self._invocation_params params.update(kwargs) params["messages"] = [self._build_user_message_from_prompt(prompt)] diff --git a/erniebot/src/erniebot/resources/chat_completion.py b/erniebot/src/erniebot/resources/chat_completion.py index ee0982836..d5c6dc4ab 100644 --- a/erniebot/src/erniebot/resources/chat_completion.py +++ b/erniebot/src/erniebot/resources/chat_completion.py @@ -105,7 +105,7 @@ def create( top_p: Union[float, NotGiven] = ..., penalty_score: Union[float, NotGiven] = ..., system: Union[str, NotGiven] = ..., - stop: Union[str, NotGiven] = ..., + stop: Union[List[str], NotGiven] = ..., disable_search: Union[bool, NotGiven] = ..., enable_citation: Union[bool, NotGiven] = ..., user_id: Union[str, NotGiven] = ..., @@ -131,7 +131,7 @@ def create( top_p: Union[float, NotGiven] = ..., penalty_score: Union[float, NotGiven] = ..., system: Union[str, NotGiven] = ..., - stop: Union[str, NotGiven] = ..., + stop: Union[List[str], NotGiven] = ..., disable_search: Union[bool, NotGiven] = ..., enable_citation: Union[bool, NotGiven] = ..., user_id: Union[str, NotGiven] = ..., @@ -157,7 +157,7 @@ def create( top_p: Union[float, NotGiven] = ..., penalty_score: Union[float, NotGiven] = ..., system: Union[str, NotGiven] = ..., - stop: Union[str, NotGiven] = ..., + stop: Union[List[str], NotGiven] = ..., disable_search: Union[bool, NotGiven] = ..., enable_citation: Union[bool, NotGiven] = ..., user_id: Union[str, NotGiven] = ..., @@ -182,7 +182,7 @@ def create( top_p: Union[float, NotGiven] = NOT_GIVEN, penalty_score: Union[float, NotGiven] = NOT_GIVEN, system: Union[str, NotGiven] = NOT_GIVEN, - stop: Union[str, NotGiven] = NOT_GIVEN, + stop: Union[List[str], NotGiven] = NOT_GIVEN, disable_search: Union[bool, NotGiven] = NOT_GIVEN, enable_citation: Union[bool, NotGiven] = NOT_GIVEN, user_id: Union[str, NotGiven] = NOT_GIVEN, @@ -261,7 +261,7 @@ async def acreate( top_p: Union[float, NotGiven] = ..., penalty_score: Union[float, NotGiven] = ..., system: Union[str, NotGiven] = ..., - stop: Union[str, NotGiven] = ..., + stop: Union[List[str], NotGiven] = ..., disable_search: Union[bool, NotGiven] = ..., enable_citation: Union[bool, NotGiven] = ..., user_id: Union[str, NotGiven] = ..., @@ -287,7 +287,7 @@ async def acreate( top_p: Union[float, NotGiven] = ..., penalty_score: Union[float, NotGiven] = ..., system: Union[str, NotGiven] = ..., - stop: Union[str, NotGiven] = ..., + stop: Union[List[str], NotGiven] = ..., disable_search: Union[bool, NotGiven] = ..., enable_citation: Union[bool, NotGiven] = ..., user_id: Union[str, NotGiven] = ..., @@ -313,7 +313,7 @@ async def acreate( top_p: Union[float, NotGiven] = ..., penalty_score: Union[float, NotGiven] = ..., system: Union[str, NotGiven] = ..., - stop: Union[str, NotGiven] = ..., + stop: Union[List[str], NotGiven] = ..., disable_search: Union[bool, NotGiven] = ..., enable_citation: Union[bool, NotGiven] = ..., user_id: Union[str, NotGiven] = ..., @@ -338,7 +338,7 @@ async def acreate( top_p: Union[float, NotGiven] = NOT_GIVEN, penalty_score: Union[float, NotGiven] = NOT_GIVEN, system: Union[str, NotGiven] = NOT_GIVEN, - stop: Union[str, NotGiven] = NOT_GIVEN, + stop: Union[List[str], NotGiven] = NOT_GIVEN, disable_search: Union[bool, NotGiven] = NOT_GIVEN, enable_citation: Union[bool, NotGiven] = NOT_GIVEN, user_id: Union[str, NotGiven] = NOT_GIVEN,