Skip to content

Commit

Permalink
Updare erniebot api
Browse files Browse the repository at this point in the history
  • Loading branch information
w5688414 committed Feb 26, 2024
1 parent ddceb1e commit 5da2d8b
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 7 deletions.
10 changes: 10 additions & 0 deletions erniebot-agent/src/erniebot_agent/chat_models/erniebot.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,16 @@ def convert_response_to_output(response: ChatCompletionResponse, output_type: Ty
clarify=clarify,
)
else:

# return output_type(
# content=response.result,
# function_call=None,
# plugin_info=None,
# search_info=None,
# token_usage=response.usage,
# clarify=clarify,
# )

return output_type(
content=response.rbody,
function_call=None,
Expand Down
4 changes: 2 additions & 2 deletions erniebot/src/erniebot/backends/aistudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ async def arequest(
params=params,
)
headers = self._add_aistudio_fields_to_headers(headers)

return await self._client.asend_request(
method,
url,
Expand Down Expand Up @@ -117,6 +118,5 @@ def _add_aistudio_fields_to_headers(self, headers: HeadersType) -> HeadersType:
"Key 'Authorization' already exists in `headers`: %r",
headers["Authorization"],
)
# headers["Authorization"] = f"token {self._access_token}"
headers["Authorization"] = f"{self._access_token}"
headers["Authorization"] = f"token {self._access_token}"
return headers
3 changes: 1 addition & 2 deletions erniebot/src/erniebot/backends/bce.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def handle_response(cls, resp: EBResponse) -> EBResponse:
else:
raise errors.APIError(emsg, ecode=ecode)
else:
return resp
return EBResponse(resp.rcode, resp.result, resp.rheaders)


class YinianBackend(_BCELegacyBackend):
Expand All @@ -349,7 +349,6 @@ def handle_response(cls, resp: EBResponse) -> EBResponse:
if "error_code" in resp and "error_msg" in resp:
ecode = resp["error_code"]
emsg = resp["error_msg"]
print(ecode)
if ecode in (4, 17):
raise errors.RequestLimitError(emsg, ecode=ecode)
elif ecode in (13, 15, 18):
Expand Down
19 changes: 19 additions & 0 deletions erniebot/src/erniebot/backends/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import Any, AsyncIterator, ClassVar, Dict, Iterator, Optional, Union

import erniebot.utils.logging as logging
from erniebot.api_types import APIType
from erniebot.backends.bce import QianfanLegacyBackend
from erniebot.response import EBResponse
Expand All @@ -29,6 +31,12 @@ class CustomBackend(EBBackend):

def __init__(self, config_dict: Dict[str, Any]) -> None:
super().__init__(config_dict=config_dict)
access_token = self._cfg.get("access_token", None)
if access_token is None:
access_token = os.environ.get("AISTUDIO_ACCESS_TOKEN", None)
if access_token is None:
raise RuntimeError("No access token is configured.")
self._access_token = access_token

def request(
self,
Expand Down Expand Up @@ -71,6 +79,7 @@ async def arequest(
supplied_headers=headers,
params=params,
)
headers = self._add_aistudio_fields_to_headers(headers)
return await self._client.asend_request(
method,
url,
Expand All @@ -83,3 +92,13 @@ async def arequest(
@classmethod
def handle_response(cls, resp: EBResponse) -> EBResponse:
return QianfanLegacyBackend.handle_response(resp)

def _add_aistudio_fields_to_headers(self, headers: HeadersType) -> HeadersType:
if "Authorization" in headers:
logging.warning(
"Key 'Authorization' already exists in `headers`: %r",
headers["Authorization"],
)
# headers["Authorization"] = f"token {self._access_token}"
headers["Authorization"] = f"{self._access_token}"
return headers
3 changes: 1 addition & 2 deletions erniebot/src/erniebot/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,9 +409,8 @@ def _interpret_response_line(
)

logging.debug("Decoded response body: %r", decoded_rbody)

breakpoint()
response = EBResponse(rcode=rcode, rbody=decoded_rbody, rheaders=dict(rheaders))

if rcode != http.HTTPStatus.OK:
raise errors.HTTPRequestError(
f"The status code is not {http.HTTPStatus.OK}.",
Expand Down
11 changes: 10 additions & 1 deletion erniebot/src/erniebot/resources/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
"ernie-longtext": {
"model_id": "ernie_bot_8k",
},
"ernie-speed": {
"model_id": "ernie_speed",
},
},
},
APIType.CUSTOM: {
Expand All @@ -92,6 +95,12 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
"ernie-3.5": {
"model_id": "completions",
},
"ernie-4.0": {
"model_id": "completions_pro",
},
"ernie-longtext": {
"model_id": "ernie_bot_8k",
},
},
},
}
Expand Down Expand Up @@ -514,7 +523,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:

# headers
headers: HeadersType = {}
if self.api_type is APIType.AISTUDIO:
if self.api_type is APIType.AISTUDIO or self.api_type is APIType.CUSTOM:
headers["Content-Type"] = "application/json"
if "headers" in kwargs:
headers.update(kwargs["headers"])
Expand Down

0 comments on commit 5da2d8b

Please sign in to comment.