Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add qianfan models #337

Merged
merged 9 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions erniebot/src/erniebot/intro.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ def list() -> List[Tuple[str, str]]:
("ernie-turbo", "文心大模型(ernie-turbo)"),
("ernie-4.0", "文心大模型(ernie-4.0)"),
("ernie-longtext", "文心大模型(ernie-longtext)"),
("ernie-speed", " 文心大模型(ernie-speed)"),
("ernie-speed-128k", " 文心大模型(ernie-speed-128k)"),
("ernie-tiny-8k", " 文心大模型(ernie-tiny-8k)"),
("ernie-char-8k", " 文心大模型(ernie-char-8k)"),
("ernie-text-embedding", "文心百中语义模型"),
("ernie-vilg-v2", "文心一格模型"),
]
74 changes: 60 additions & 14 deletions erniebot/src/erniebot/resources/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,18 +55,31 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
"ernie-3.5": {
"model_id": "completions",
},
"ernie-3.5-8k": {
"model_id": "completions",
},
"ernie-turbo": {
"model_id": "eb-instant",
},
"ernie-4.0": {
"model_id": "completions_pro",
},
"ernie-longtext": {
"model_id": "ernie_bot_8k",
# ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11
"model_id": "completions",
},
"ernie-speed": {
"model_id": "ernie_speed",
},
"ernie-speed-128k": {
"model_id": "ernie-speed-128k",
},
"ernie-tiny-8k": {
"model_id": "ernie-tiny-8k",
},
"ernie-char-8k": {
"model_id": "ernie-char-8k",
},
},
},
APIType.AISTUDIO: {
Expand All @@ -75,14 +88,30 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
"ernie-3.5": {
"model_id": "completions",
},
"ernie-3.5-8k": {
"model_id": "completions",
},
"ernie-turbo": {
"model_id": "eb-instant",
},
"ernie-4.0": {
"model_id": "completions_pro",
},
"ernie-longtext": {
"model_id": "ernie_bot_8k",
# ernie-longtext(ernie_bot_8k) will be deprecated in 2024.4.11
"model_id": "completions",
},
"ernie-speed": {
"model_id": "ernie_speed",
},
"ernie-speed-128k": {
"model_id": "ernie-speed-128k",
},
"ernie-tiny-8k": {
"model_id": "ernie-tiny-8k",
},
"ernie-char-8k": {
"model_id": "ernie-char-8k",
},
},
},
Expand All @@ -96,7 +125,7 @@ class ChatCompletion(EBResource, CreatableWithStreaming):
"model_id": "completions_pro",
},
"ernie-longtext": {
"model_id": "ernie_bot_8k",
"model_id": "completions",
},
"ernie-speed": {
"model_id": "ernie_speed",
Expand Down Expand Up @@ -258,6 +287,7 @@ def create(
kwargs["headers"] = headers
if request_timeout is not None:
kwargs["request_timeout"] = request_timeout

resp = resource.create_resource(**kwargs)
return transform(ChatCompletionResponse.from_mapping, resp)

Expand Down Expand Up @@ -414,9 +444,32 @@ async def acreate(
kwargs["headers"] = headers
if request_timeout is not None:
kwargs["request_timeout"] = request_timeout

resp = await resource.acreate_resource(**kwargs)
return transform(ChatCompletionResponse.from_mapping, resp)

def _check_model_kwargs(self, model_name: str, kwargs: Dict[str, Any]) -> None:
if model_name in ("ernie-turbo",):
for arg in (
"functions",
"stop",
"disable_search",
"enable_citation",
"tool_choice",
):
if arg in kwargs:
raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model_name} model.")

if model_name in ("ernie-speed", "ernie-speed-128k", "ernie-char-8k", "ernie-tiny-8k"):
for arg in (
"functions",
"disable_search",
"enable_citation",
"tool_choice",
):
if arg in kwargs:
raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model_name} model.")

def _prepare_create(self, kwargs: Dict[str, Any]) -> RequestWithStream:
def _update_model_name(given_name: str, old_name_to_new_name: Dict[str, str]) -> str:
if given_name in old_name_to_new_name:
Expand Down Expand Up @@ -468,7 +521,8 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
"ernie-bot": "ernie-3.5",
"ernie-bot-turbo": "ernie-turbo",
"ernie-bot-4": "ernie-4.0",
"ernie-bot-8k": "ernie-longtext",
"ernie-bot-8k": "ernie-3.5-8k",
"ernie-longtext": "ernie-3.5-8k",
},
)

Expand All @@ -490,16 +544,8 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:

# params
params = {}
if model in ("ernie-turbo", "ernie-speed"):
for arg in (
"functions",
"stop",
"disable_search",
"enable_citation",
"tool_choice",
):
if arg in kwargs:
raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model} model.")
self._check_model_kwargs(model, kwargs)

params["messages"] = messages
if "functions" in kwargs:
functions = kwargs["functions"]
Expand Down
Loading