Skip to content

Commit

Permalink
add check_models
Browse files Browse the repository at this point in the history
  • Loading branch information
wj-Mcat committed Apr 11, 2024
1 parent 6adf37a commit 00246d6
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions erniebot/src/erniebot/resources/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,18 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
"ernie-3.5": {
"model_id": "completions",
},
"ernie-3.5-8k": {
"model_id": "completions",
},
"ernie-turbo": {
"model_id": "eb-instant",
},
"ernie-4.0": {
"model_id": "completions_pro",
},
"ernie-longtext": {
"model_id": "ernie_bot_8k",
# ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11
"model_id": "completions",
},
"ernie-speed": {
"model_id": "ernie_speed",
Expand All @@ -84,14 +88,18 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
"ernie-3.5": {
"model_id": "completions",
},
"ernie-3.5-8k": {
"model_id": "completions",
},
"ernie-turbo": {
"model_id": "eb-instant",
},
"ernie-4.0": {
"model_id": "completions_pro",
},
"ernie-longtext": {
"model_id": "ernie_bot_8k",
# ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11
"model_id": "completions",
},
"ernie-speed": {
"model_id": "ernie_speed",
Expand All @@ -117,7 +125,7 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
"model_id": "completions_pro",
},
"ernie-longtext": {
"model_id": "ernie_bot_8k",
"model_id": "completions",
},
"ernie-speed": {
"model_id": "ernie_speed",
Expand All @@ -126,6 +134,11 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
},
}

@staticmethod
def check_models(model: str):
if model in ["ernie-longtext", "ernie-bot-8k"]:
logging.warning(f"{model} will be deprecated after 2024.4.11, so we will automatically map it to ernie-3.5-8k")

@overload
@classmethod
def create(
Expand Down Expand Up @@ -279,6 +292,8 @@ def create(
kwargs["headers"] = headers
if request_timeout is not None:
kwargs["request_timeout"] = request_timeout

cls.check_models(model)
resp = resource.create_resource(**kwargs)
return transform(ChatCompletionResponse.from_mapping, resp)

Expand Down Expand Up @@ -435,6 +450,8 @@ async def acreate(
kwargs["headers"] = headers
if request_timeout is not None:
kwargs["request_timeout"] = request_timeout

cls.check_models(model)
resp = await resource.acreate_resource(**kwargs)
return transform(ChatCompletionResponse.from_mapping, resp)

Expand Down

0 comments on commit 00246d6

Please sign in to comment.