Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Bobholamovic committed Feb 15, 2024
1 parent 8c0ea11 commit aa1894f
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ class ErnieBotChat(BaseChatModel):
"""ERNIE Bot Chat large language models API.
To use, you should have the ``erniebot`` python package installed, and the
environment variable ``AISTUDIO_ACCESS_TOKEN`` set with your AI Studio access token.
environment variable ``AISTUDIO_ACCESS_TOKEN`` set with your AI Studio
access token.
Example:
.. code-block:: python
Expand Down Expand Up @@ -133,6 +134,8 @@ def _generate(
system_prompt = self._build_system_prompt_from_messages(messages)
if system_prompt is not None:
params["system"] = system_prompt
if stop is not None:
params["stop"] = stop
params["stream"] = False
response = self.client.create(**params)
return self._build_chat_result_from_response(response)
Expand Down Expand Up @@ -161,6 +164,8 @@ async def _agenerate(
system_prompt = self._build_system_prompt_from_messages(messages)
if system_prompt is not None:
params["system"] = system_prompt
if stop is not None:
params["stop"] = stop
params["stream"] = False
response = await self.client.acreate(**params)
return self._build_chat_result_from_response(response)
Expand Down Expand Up @@ -213,7 +218,6 @@ def _build_chat_result_from_response(self, response: Mapping[str, Any]) -> ChatR
message_dict = self._build_dict_from_response(response)
generation = ChatGeneration(
message=self._convert_dict_to_message(message_dict),
generation_info=dict(finish_reason="stop"),
)
token_usage = response.get("usage", {})
llm_output = {"token_usage": token_usage, "model_name": self.model}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ class ErnieEmbeddings(BaseModel, Embeddings):
"""ERNIE embedding models.
To use, you should have the ``erniebot`` python package installed, and the
environment variable ``AISTUDIO_ACCESS_TOKEN`` set with your AI Studio access token.
environment variable ``AISTUDIO_ACCESS_TOKEN`` set with your AI Studio
access token.
Example:
.. code-block:: python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env
Expand All @@ -17,7 +16,8 @@ class ErnieBot(LLM):
"""ERNIE Bot large language models.
To use, you should have the ``erniebot`` python package installed, and the
environment variable ``AISTUDIO_ACCESS_TOKEN`` set with your AI Studio access token.
environment variable ``AISTUDIO_ACCESS_TOKEN`` set with your AI Studio
access token.
Example:
.. code-block:: python
Expand Down Expand Up @@ -111,11 +111,11 @@ def _call(
params = self._invocation_params
params.update(kwargs)
params["messages"] = [self._build_user_message_from_prompt(prompt)]
if stop is not None:
params["stop"] = stop
params["stream"] = False
response = self.client.create(**params)
text = response["result"]
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text

async def _acall(
Expand All @@ -134,11 +134,11 @@ async def _acall(
params = self._invocation_params
params.update(kwargs)
params["messages"] = [self._build_user_message_from_prompt(prompt)]
if stop is not None:
params["stop"] = stop
params["stream"] = False
response = await self.client.acreate(**params)
text = response["result"]
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text

def _stream(
Expand Down Expand Up @@ -168,7 +168,7 @@ async def _astream(
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
if stop is not None:
raise TypeError("Currently, `stop` is not supported when streaming is enabled.")
raise ValueError("Currently, `stop` is not supported when streaming is enabled.")
params = self._invocation_params
params.update(kwargs)
params["messages"] = [self._build_user_message_from_prompt(prompt)]
Expand Down
16 changes: 8 additions & 8 deletions erniebot/src/erniebot/resources/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def create(
top_p: Union[float, NotGiven] = ...,
penalty_score: Union[float, NotGiven] = ...,
system: Union[str, NotGiven] = ...,
stop: Union[str, NotGiven] = ...,
stop: Union[List[str], NotGiven] = ...,
disable_search: Union[bool, NotGiven] = ...,
enable_citation: Union[bool, NotGiven] = ...,
user_id: Union[str, NotGiven] = ...,
Expand All @@ -131,7 +131,7 @@ def create(
top_p: Union[float, NotGiven] = ...,
penalty_score: Union[float, NotGiven] = ...,
system: Union[str, NotGiven] = ...,
stop: Union[str, NotGiven] = ...,
stop: Union[List[str], NotGiven] = ...,
disable_search: Union[bool, NotGiven] = ...,
enable_citation: Union[bool, NotGiven] = ...,
user_id: Union[str, NotGiven] = ...,
Expand All @@ -157,7 +157,7 @@ def create(
top_p: Union[float, NotGiven] = ...,
penalty_score: Union[float, NotGiven] = ...,
system: Union[str, NotGiven] = ...,
stop: Union[str, NotGiven] = ...,
stop: Union[List[str], NotGiven] = ...,
disable_search: Union[bool, NotGiven] = ...,
enable_citation: Union[bool, NotGiven] = ...,
user_id: Union[str, NotGiven] = ...,
Expand All @@ -182,7 +182,7 @@ def create(
top_p: Union[float, NotGiven] = NOT_GIVEN,
penalty_score: Union[float, NotGiven] = NOT_GIVEN,
system: Union[str, NotGiven] = NOT_GIVEN,
stop: Union[str, NotGiven] = NOT_GIVEN,
stop: Union[List[str], NotGiven] = NOT_GIVEN,
disable_search: Union[bool, NotGiven] = NOT_GIVEN,
enable_citation: Union[bool, NotGiven] = NOT_GIVEN,
user_id: Union[str, NotGiven] = NOT_GIVEN,
Expand Down Expand Up @@ -261,7 +261,7 @@ async def acreate(
top_p: Union[float, NotGiven] = ...,
penalty_score: Union[float, NotGiven] = ...,
system: Union[str, NotGiven] = ...,
stop: Union[str, NotGiven] = ...,
stop: Union[List[str], NotGiven] = ...,
disable_search: Union[bool, NotGiven] = ...,
enable_citation: Union[bool, NotGiven] = ...,
user_id: Union[str, NotGiven] = ...,
Expand All @@ -287,7 +287,7 @@ async def acreate(
top_p: Union[float, NotGiven] = ...,
penalty_score: Union[float, NotGiven] = ...,
system: Union[str, NotGiven] = ...,
stop: Union[str, NotGiven] = ...,
stop: Union[List[str], NotGiven] = ...,
disable_search: Union[bool, NotGiven] = ...,
enable_citation: Union[bool, NotGiven] = ...,
user_id: Union[str, NotGiven] = ...,
Expand All @@ -313,7 +313,7 @@ async def acreate(
top_p: Union[float, NotGiven] = ...,
penalty_score: Union[float, NotGiven] = ...,
system: Union[str, NotGiven] = ...,
stop: Union[str, NotGiven] = ...,
stop: Union[List[str], NotGiven] = ...,
disable_search: Union[bool, NotGiven] = ...,
enable_citation: Union[bool, NotGiven] = ...,
user_id: Union[str, NotGiven] = ...,
Expand All @@ -338,7 +338,7 @@ async def acreate(
top_p: Union[float, NotGiven] = NOT_GIVEN,
penalty_score: Union[float, NotGiven] = NOT_GIVEN,
system: Union[str, NotGiven] = NOT_GIVEN,
stop: Union[str, NotGiven] = NOT_GIVEN,
stop: Union[List[str], NotGiven] = NOT_GIVEN,
disable_search: Union[bool, NotGiven] = NOT_GIVEN,
enable_citation: Union[bool, NotGiven] = NOT_GIVEN,
user_id: Union[str, NotGiven] = NOT_GIVEN,
Expand Down

0 comments on commit aa1894f

Please sign in to comment.