diff --git a/erniebot/src/erniebot/resources/chat_completion.py b/erniebot/src/erniebot/resources/chat_completion.py index e10c990e..5c9dd7ff 100644 --- a/erniebot/src/erniebot/resources/chat_completion.py +++ b/erniebot/src/erniebot/resources/chat_completion.py @@ -55,6 +55,9 @@ class ChatCompletion(EBResource, CreatableWithStreaming): "ernie-3.5": { "model_id": "completions", }, + "ernie-3.5-8k": { + "model_id": "completions", + }, "ernie-turbo": { "model_id": "eb-instant", }, @@ -62,7 +65,8 @@ class ChatCompletion(EBResource, CreatableWithStreaming): "model_id": "completions_pro", }, "ernie-longtext": { - "model_id": "ernie_bot_8k", + # ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11 + "model_id": "completions", }, "ernie-speed": { "model_id": "ernie_speed", @@ -84,6 +88,9 @@ class ChatCompletion(EBResource, CreatableWithStreaming): "ernie-3.5": { "model_id": "completions", }, + "ernie-3.5-8k": { + "model_id": "completions", + }, "ernie-turbo": { "model_id": "eb-instant", }, @@ -91,7 +98,8 @@ class ChatCompletion(EBResource, CreatableWithStreaming): "model_id": "completions_pro", }, "ernie-longtext": { - "model_id": "ernie_bot_8k", + # ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11 + "model_id": "completions", }, "ernie-speed": { "model_id": "ernie_speed", @@ -117,7 +125,7 @@ class ChatCompletion(EBResource, CreatableWithStreaming): "model_id": "completions_pro", }, "ernie-longtext": { - "model_id": "ernie_bot_8k", + "model_id": "completions", }, "ernie-speed": { "model_id": "ernie_speed", @@ -126,6 +134,11 @@ class ChatCompletion(EBResource, CreatableWithStreaming): }, } + @staticmethod + def check_models(model: str): + if model in ["ernie-longtext", "ernie-bot-8k"]: + logging.warning(f"{model} will be deprecated after 2024.4.11, so we will automatically map it to ernie-3.5-8k") + @overload @classmethod def create( @@ -279,6 +292,8 @@ def create( kwargs["headers"] = headers if request_timeout is not None: kwargs["request_timeout"] = request_timeout + + cls.check_models(model) resp = resource.create_resource(**kwargs) return transform(ChatCompletionResponse.from_mapping, resp) @@ -435,6 +450,8 @@ async def acreate( kwargs["headers"] = headers if request_timeout is not None: kwargs["request_timeout"] = request_timeout + + cls.check_models(model) resp = await resource.acreate_resource(**kwargs) return transform(ChatCompletionResponse.from_mapping, resp)