Skip to content

Commit

Permalink
[aistudio api] Update weipu api (#319)
Browse files Browse the repository at this point in the history
* Update weipu api

* Updare erniebot api

* remove unused comments

* restore erniebot

* reformat

* update

* remove lines

* fix ci

* Add ernieb speed

* suport no access token config

* Fix unitest
  • Loading branch information
w5688414 committed Feb 26, 2024
1 parent 4aca819 commit 2bfa1b1
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from llama_index.schema import NodeWithScore, TextNode
from llama_index.core.schema import NodeWithScore, TextNode

from erniebot_agent.tools.llama_index_retrieval_tool import LlamaIndexRetrievalTool

Expand Down
1 change: 0 additions & 1 deletion erniebot/src/erniebot/backends/bce.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,6 @@ def handle_response(cls, resp: EBResponse) -> EBResponse:
if "error_code" in resp and "error_msg" in resp:
ecode = resp["error_code"]
emsg = resp["error_msg"]
print(ecode)
if ecode in (4, 17):
raise errors.RequestLimitError(emsg, ecode=ecode)
elif ecode in (13, 15, 18):
Expand Down
17 changes: 17 additions & 0 deletions erniebot/src/erniebot/backends/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, Optional, Union

import erniebot.utils.logging as logging
from erniebot.api_types import APIType
from erniebot.backends.bce import QianfanLegacyBackend
from erniebot.response import EBResponse
Expand All @@ -29,6 +31,10 @@ class CustomBackend(EBBackend):

def __init__(self, config_dict: Dict[str, Any]) -> None:
super().__init__(config_dict=config_dict)
access_token = self._cfg.get("access_token", None)
if access_token is None:
access_token = os.environ.get("AISTUDIO_ACCESS_TOKEN", None)
self._access_token = access_token

def request(
self,
Expand Down Expand Up @@ -71,6 +77,8 @@ async def arequest(
supplied_headers=headers,
params=params,
)
if self._access_token is not None:
headers = self._add_aistudio_fields_to_headers(headers)
return await self._client.asend_request(
method,
url,
Expand All @@ -83,3 +91,12 @@ async def arequest(
@classmethod
def handle_response(cls, resp: EBResponse) -> EBResponse:
return QianfanLegacyBackend.handle_response(resp)

def _add_aistudio_fields_to_headers(self, headers: HeadersType) -> HeadersType:
if "Authorization" in headers:
logging.warning(
"Key 'Authorization' already exists in `headers`: %r",
headers["Authorization"],
)
headers["Authorization"] = f"{self._access_token}"
return headers
1 change: 0 additions & 1 deletion erniebot/src/erniebot/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,6 @@ def _interpret_response_line(
logging.debug("Decoded response body: %r", decoded_rbody)

response = EBResponse(rcode=rcode, rbody=decoded_rbody, rheaders=dict(rheaders))

if rcode != http.HTTPStatus.OK:
raise errors.HTTPRequestError(
f"The status code is not {http.HTTPStatus.OK}.",
Expand Down
11 changes: 10 additions & 1 deletion erniebot/src/erniebot/resources/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
"ernie-3.5": {
"model_id": "completions",
},
"ernie-4.0": {
"model_id": "completions_pro",
},
"ernie-longtext": {
"model_id": "ernie_bot_8k",
},
"ernie-speed": {
"model_id": "ernie_speed",
},
},
},
}
Expand Down Expand Up @@ -514,7 +523,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:

# headers
headers: HeadersType = {}
if self.api_type is APIType.AISTUDIO:
if self.api_type is APIType.AISTUDIO or self.api_type is APIType.CUSTOM:
headers["Content-Type"] = "application/json"
if "headers" in kwargs:
headers.update(kwargs["headers"])
Expand Down

0 comments on commit 2bfa1b1

Please sign in to comment.