diff --git a/erniebot/src/erniebot/intro.py b/erniebot/src/erniebot/intro.py index aece8f8c..e19fe490 100644 --- a/erniebot/src/erniebot/intro.py +++ b/erniebot/src/erniebot/intro.py @@ -28,6 +28,10 @@ def list() -> List[Tuple[str, str]]: ("ernie-turbo", "文心大模型(ernie-turbo)"), ("ernie-4.0", "文心大模型(ernie-4.0)"), ("ernie-longtext", "文心大模型(ernie-longtext)"), + ("ernie-speed", " 文心大模型(ernie-speed)"), + ("ernie-speed-128k", " 文心大模型(ernie-speed-128k)"), + ("ernie-tiny-8k", " 文心大模型(ernie-tiny-8k)"), + ("ernie-char-8k", " 文心大模型(ernie-char-8k)"), ("ernie-text-embedding", "文心百中语义模型"), ("ernie-vilg-v2", "文心一格模型"), ] diff --git a/erniebot/src/erniebot/resources/chat_completion.py b/erniebot/src/erniebot/resources/chat_completion.py index 605b0bd4..ca3317a5 100644 --- a/erniebot/src/erniebot/resources/chat_completion.py +++ b/erniebot/src/erniebot/resources/chat_completion.py @@ -55,6 +55,9 @@ class ChatCompletion(EBResource, CreatableWithStreaming): "ernie-3.5": { "model_id": "completions", }, + "ernie-3.5-8k": { + "model_id": "completions", + }, "ernie-turbo": { "model_id": "eb-instant", }, @@ -62,11 +65,21 @@ class ChatCompletion(EBResource, CreatableWithStreaming): "model_id": "completions_pro", }, "ernie-longtext": { - "model_id": "ernie_bot_8k", + # ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11 + "model_id": "completions", }, "ernie-speed": { "model_id": "ernie_speed", }, + "ernie-speed-128k": { + "model_id": "ernie-speed-128k", + }, + "ernie-tiny-8k": { + "model_id": "ernie-tiny-8k", + }, + "ernie-char-8k": { + "model_id": "ernie-char-8k", + }, }, }, APIType.AISTUDIO: { @@ -75,6 +88,9 @@ class ChatCompletion(EBResource, CreatableWithStreaming): "ernie-3.5": { "model_id": "completions", }, + "ernie-3.5-8k": { + "model_id": "completions", + }, "ernie-turbo": { "model_id": "eb-instant", }, @@ -82,7 +98,20 @@ class ChatCompletion(EBResource, CreatableWithStreaming): "model_id": "completions_pro", }, "ernie-longtext": { - "model_id": "ernie_bot_8k", + # ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11 + "model_id": "completions", + }, + "ernie-speed": { + "model_id": "ernie_speed", + }, + "ernie-speed-128k": { + "model_id": "ernie-speed-128k", + }, + "ernie-tiny-8k": { + "model_id": "ernie-tiny-8k", + }, + "ernie-char-8k": { + "model_id": "ernie-char-8k", }, }, }, @@ -96,7 +125,7 @@ class ChatCompletion(EBResource, CreatableWithStreaming): "model_id": "completions_pro", }, "ernie-longtext": { - "model_id": "ernie_bot_8k", + "model_id": "completions", }, "ernie-speed": { "model_id": "ernie_speed", @@ -258,6 +287,7 @@ def create( kwargs["headers"] = headers if request_timeout is not None: kwargs["request_timeout"] = request_timeout + resp = resource.create_resource(**kwargs) return transform(ChatCompletionResponse.from_mapping, resp) @@ -414,9 +444,32 @@ async def acreate( kwargs["headers"] = headers if request_timeout is not None: kwargs["request_timeout"] = request_timeout + resp = await resource.acreate_resource(**kwargs) return transform(ChatCompletionResponse.from_mapping, resp) + def _check_model_kwargs(self, model_name: str, kwargs: Dict[str, Any]) -> None: + if model_name in ("ernie-turbo",): + for arg in ( + "functions", + "stop", + "disable_search", + "enable_citation", + "tool_choice", + ): + if arg in kwargs: + raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model_name} model.") + + if model_name in ("ernie-speed", "ernie-speed-128k", "ernie-char-8k", "ernie-tiny-8k"): + for arg in ( + "functions", + "disable_search", + "enable_citation", + "tool_choice", + ): + if arg in kwargs: + raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model_name} model.") + def _prepare_create(self, kwargs: Dict[str, Any]) -> RequestWithStream: def _update_model_name(given_name: str, old_name_to_new_name: Dict[str, str]) -> str: if given_name in old_name_to_new_name: @@ -468,7 +521,8 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None: "ernie-bot": "ernie-3.5", "ernie-bot-turbo": "ernie-turbo", "ernie-bot-4": "ernie-4.0", - "ernie-bot-8k": "ernie-longtext", + "ernie-bot-8k": "ernie-3.5-8k", + "ernie-longtext": "ernie-3.5-8k", }, ) @@ -490,16 +544,8 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None: # params params = {} - if model in ("ernie-turbo", "ernie-speed"): - for arg in ( - "functions", - "stop", - "disable_search", - "enable_citation", - "tool_choice", - ): - if arg in kwargs: - raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model} model.") + self._check_model_kwargs(model, kwargs) + params["messages"] = messages if "functions" in kwargs: functions = kwargs["functions"]