From 2bfa1b1649d15c89b92794289341e6881a9656d7 Mon Sep 17 00:00:00 2001 From: w5688414 Date: Mon, 26 Feb 2024 14:35:44 +0800 Subject: [PATCH] [aistudio api] Update weipu api (#319) * Update weipu api * Updare erniebot api * remove unused comments * restore erniebot * reformat * update * remove lines * fix ci * Add ernieb speed * suport no access token config * Fix unitest --- .../tools/test_llama_index_retrieval_tool.py | 2 +- erniebot/src/erniebot/backends/bce.py | 1 - erniebot/src/erniebot/backends/custom.py | 17 +++++++++++++++++ erniebot/src/erniebot/http_client.py | 1 - .../src/erniebot/resources/chat_completion.py | 11 ++++++++++- 5 files changed, 28 insertions(+), 4 deletions(-) diff --git a/erniebot-agent/tests/unit_tests/tools/test_llama_index_retrieval_tool.py b/erniebot-agent/tests/unit_tests/tools/test_llama_index_retrieval_tool.py index 164f1f279..88afee00d 100644 --- a/erniebot-agent/tests/unit_tests/tools/test_llama_index_retrieval_tool.py +++ b/erniebot-agent/tests/unit_tests/tools/test_llama_index_retrieval_tool.py @@ -1,5 +1,5 @@ import pytest -from llama_index.schema import NodeWithScore, TextNode +from llama_index.core.schema import NodeWithScore, TextNode from erniebot_agent.tools.llama_index_retrieval_tool import LlamaIndexRetrievalTool diff --git a/erniebot/src/erniebot/backends/bce.py b/erniebot/src/erniebot/backends/bce.py index 9f311b45e..4099a0514 100644 --- a/erniebot/src/erniebot/backends/bce.py +++ b/erniebot/src/erniebot/backends/bce.py @@ -349,7 +349,6 @@ def handle_response(cls, resp: EBResponse) -> EBResponse: if "error_code" in resp and "error_msg" in resp: ecode = resp["error_code"] emsg = resp["error_msg"] - print(ecode) if ecode in (4, 17): raise errors.RequestLimitError(emsg, ecode=ecode) elif ecode in (13, 15, 18): diff --git a/erniebot/src/erniebot/backends/custom.py b/erniebot/src/erniebot/backends/custom.py index 90df4a3d7..cd708751a 100644 --- a/erniebot/src/erniebot/backends/custom.py +++ b/erniebot/src/erniebot/backends/custom.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, Optional, Union +import erniebot.utils.logging as logging from erniebot.api_types import APIType from erniebot.backends.bce import QianfanLegacyBackend from erniebot.response import EBResponse @@ -29,6 +31,10 @@ class CustomBackend(EBBackend): def __init__(self, config_dict: Dict[str, Any]) -> None: super().__init__(config_dict=config_dict) + access_token = self._cfg.get("access_token", None) + if access_token is None: + access_token = os.environ.get("AISTUDIO_ACCESS_TOKEN", None) + self._access_token = access_token def request( self, @@ -71,6 +77,8 @@ async def arequest( supplied_headers=headers, params=params, ) + if self._access_token is not None: + headers = self._add_aistudio_fields_to_headers(headers) return await self._client.asend_request( method, url, @@ -83,3 +91,12 @@ async def arequest( @classmethod def handle_response(cls, resp: EBResponse) -> EBResponse: return QianfanLegacyBackend.handle_response(resp) + + def _add_aistudio_fields_to_headers(self, headers: HeadersType) -> HeadersType: + if "Authorization" in headers: + logging.warning( + "Key 'Authorization' already exists in `headers`: %r", + headers["Authorization"], + ) + headers["Authorization"] = f"{self._access_token}" + return headers diff --git a/erniebot/src/erniebot/http_client.py b/erniebot/src/erniebot/http_client.py index 96245d714..f7d0390f9 100644 --- a/erniebot/src/erniebot/http_client.py +++ b/erniebot/src/erniebot/http_client.py @@ -411,7 +411,6 @@ def _interpret_response_line( logging.debug("Decoded response body: %r", decoded_rbody) response = EBResponse(rcode=rcode, rbody=decoded_rbody, rheaders=dict(rheaders)) - if rcode != http.HTTPStatus.OK: raise errors.HTTPRequestError( f"The status code is not {http.HTTPStatus.OK}.", diff --git a/erniebot/src/erniebot/resources/chat_completion.py b/erniebot/src/erniebot/resources/chat_completion.py index ff3bfe2de..605b0bd4a 100644 --- a/erniebot/src/erniebot/resources/chat_completion.py +++ b/erniebot/src/erniebot/resources/chat_completion.py @@ -92,6 +92,15 @@ class ChatCompletion(EBResource, CreatableWithStreaming): "ernie-3.5": { "model_id": "completions", }, + "ernie-4.0": { + "model_id": "completions_pro", + }, + "ernie-longtext": { + "model_id": "ernie_bot_8k", + }, + "ernie-speed": { + "model_id": "ernie_speed", + }, }, }, } @@ -514,7 +523,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None: # headers headers: HeadersType = {} - if self.api_type is APIType.AISTUDIO: + if self.api_type is APIType.AISTUDIO or self.api_type is APIType.CUSTOM: headers["Content-Type"] = "application/json" if "headers" in kwargs: headers.update(kwargs["headers"])