Skip to content

Commit

Permalink
fix flake
Browse files Browse the repository at this point in the history
  • Loading branch information
Southpika committed Apr 29, 2024
1 parent a5c44b8 commit bf66cb8
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions erniebot/src/erniebot/resources/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def create(
if request_timeout is not None:
kwargs["request_timeout"] = request_timeout

resp = resource.create_resource(**kwargs)
resp = resource.create_resource(**kwargs)
return transform(ChatCompletionResponse.from_mapping, resp)

@overload
Expand Down Expand Up @@ -344,7 +344,8 @@ async def acreate(
validate_functions: bool = ...,
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ..., max_output_tokens: Optional[int] = ...,
request_timeout: Optional[float] = ...,
max_output_tokens: Optional[int] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> AsyncIterator["ChatCompletionResponse"]:
...
Expand All @@ -370,7 +371,8 @@ async def acreate(
validate_functions: bool = ...,
extra_params: Optional[dict] = ...,
headers: Optional[HeadersType] = ...,
request_timeout: Optional[float] = ..., max_output_tokens: Optional[int] = ...,
request_timeout: Optional[float] = ...,
max_output_tokens: Optional[int] = ...,
_config_: Optional[ConfigDictType] = ...,
) -> Union["ChatCompletionResponse", AsyncIterator["ChatCompletionResponse"]]:
...
Expand Down Expand Up @@ -458,12 +460,7 @@ async def acreate(

def _check_model_kwargs(self, model_name: str, kwargs: Dict[str, Any]) -> None:
if model_name in ("ernie-speed", "ernie-speed-128k", "ernie-char-8k", "ernie-tiny-8k", "ernie-lite"):
for arg in (
"functions",
"disable_search",
"enable_citation",
"tool_choice",
):
for arg in ("functions", "disable_search", "enable_citation", "tool_choice"):
if arg in kwargs:
raise errors.InvalidArgumentError(f"`{arg}` is not supported by the {model_name} model.")

Expand Down Expand Up @@ -500,7 +497,7 @@ def _set_val_if_key_exists(src: dict, dst: dict, key: str) -> None:
"extra_params",
"headers",
"request_timeout",
"max_output_tokens"
"max_output_tokens",
}

invalid_keys = kwargs.keys() - valid_keys
Expand Down

0 comments on commit bf66cb8

Please sign in to comment.