Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aistudio api] Update weipu api #319

Merged
merged 11 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from llama_index.schema import NodeWithScore, TextNode
from llama_index.core.schema import NodeWithScore, TextNode

from erniebot_agent.tools.llama_index_retrieval_tool import LlamaIndexRetrievalTool

Expand Down
1 change: 0 additions & 1 deletion erniebot/src/erniebot/backends/bce.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,6 @@ def handle_response(cls, resp: EBResponse) -> EBResponse:
if "error_code" in resp and "error_msg" in resp:
ecode = resp["error_code"]
emsg = resp["error_msg"]
print(ecode)
if ecode in (4, 17):
raise errors.RequestLimitError(emsg, ecode=ecode)
elif ecode in (13, 15, 18):
Expand Down
17 changes: 17 additions & 0 deletions erniebot/src/erniebot/backends/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, Optional, Union

import erniebot.utils.logging as logging
from erniebot.api_types import APIType
from erniebot.backends.bce import QianfanLegacyBackend
from erniebot.response import EBResponse
Expand All @@ -29,6 +31,10 @@ class CustomBackend(EBBackend):

def __init__(self, config_dict: Dict[str, Any]) -> None:
super().__init__(config_dict=config_dict)
access_token = self._cfg.get("access_token", None)
if access_token is None:
access_token = os.environ.get("AISTUDIO_ACCESS_TOKEN", None)
self._access_token = access_token

def request(
self,
Expand Down Expand Up @@ -71,6 +77,8 @@ async def arequest(
supplied_headers=headers,
params=params,
)
if self._access_token is not None:
headers = self._add_aistudio_fields_to_headers(headers)
return await self._client.asend_request(
method,
url,
Expand All @@ -83,3 +91,12 @@ async def arequest(
@classmethod
def handle_response(cls, resp: EBResponse) -> EBResponse:
return QianfanLegacyBackend.handle_response(resp)

def _add_aistudio_fields_to_headers(self, headers: HeadersType) -> HeadersType:
if "Authorization" in headers:
logging.warning(
"Key 'Authorization' already exists in `headers`: %r",
headers["Authorization"],
)
headers["Authorization"] = f"{self._access_token}"
return headers
1 change: 0 additions & 1 deletion erniebot/src/erniebot/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,6 @@ def _interpret_response_line(
logging.debug("Decoded response body: %r", decoded_rbody)

response = EBResponse(rcode=rcode, rbody=decoded_rbody, rheaders=dict(rheaders))

if rcode != http.HTTPStatus.OK:
raise errors.HTTPRequestError(
f"The status code is not {http.HTTPStatus.OK}.",
Expand Down
11 changes: 10 additions & 1 deletion erniebot/src/erniebot/resources/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
"ernie-3.5": {
"model_id": "completions",
},
"ernie-4.0": {
"model_id": "completions_pro",
},
"ernie-longtext": {
"model_id": "ernie_bot_8k",
},
"ernie-speed": {
"model_id": "ernie_speed",
},
},
},
}
Expand Down Expand Up @@ -514,7 +523,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:

# headers
headers: HeadersType = {}
if self.api_type is APIType.AISTUDIO:
if self.api_type is APIType.AISTUDIO or self.api_type is APIType.CUSTOM:
headers["Content-Type"] = "application/json"
if "headers" in kwargs:
headers.update(kwargs["headers"])
Expand Down
Loading