2 changes: 2 additions & 0 deletions litellm/proxy/hooks/parallel_request_limiter_v3.py
@@ -20,6 +20,8 @@
cast,
)

from fastapi import HTTPException
Collaborator Author commented:

doc quality pointed out this was missing.


from litellm import DualCache
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_logger import CustomLogger
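
For context, the import supports the limiter's rejection path: a hook like this typically raises HTTPException so FastAPI returns an HTTP error (commonly 429) to the caller when a request exceeds its limit. A minimal sketch of that pattern, with hypothetical names (check_parallel_limit and its parameters are illustrative, not the hook's actual API):

from fastapi import HTTPException


def check_parallel_limit(current_requests: int, max_parallel_requests: int) -> int:
    """Reject the request once the parallel-request limit is reached."""
    if current_requests >= max_parallel_requests:
        # FastAPI turns this into a 429 response for the caller.
        raise HTTPException(
            status_code=429,
            detail=f"Max parallel requests reached: {max_parallel_requests}",
        )
    return current_requests + 1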
184 changes: 95 additions & 89 deletions litellm/router.py
@@ -5694,6 +5694,8 @@ def _set_model_group_info(  # noqa: PLR0915
total_tpm: Optional[int] = None
total_rpm: Optional[int] = None
configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
# Use set for O(1) deduplication
providers_set: set = set()
model_list = self.get_model_list(model_name=model_group)
if model_list is None:
return None
@@ -5792,99 +5794,100 @@
model_group_info = ModelGroupInfo( # type: ignore
**{
"model_group": user_facing_model_group_name,
"providers": [llm_provider],
"providers": [], # Will be set from providers_set after loop
**model_info,
}
)
else:
# if max_input_tokens > curr
# if max_output_tokens > curr
# if input_cost_per_token > curr
# if output_cost_per_token > curr
# supports_parallel_function_calling == True
# supports_vision == True
# supports_function_calling == True
if llm_provider not in model_group_info.providers:
model_group_info.providers.append(llm_provider)
if (
model_info.get("max_input_tokens", None) is not None
and model_info["max_input_tokens"] is not None
and (
model_group_info.max_input_tokens is None
or model_info["max_input_tokens"]
> model_group_info.max_input_tokens
)
):
model_group_info.max_input_tokens = model_info["max_input_tokens"]
if (
model_info.get("max_output_tokens", None) is not None
and model_info["max_output_tokens"] is not None
and (
model_group_info.max_output_tokens is None
or model_info["max_output_tokens"]
> model_group_info.max_output_tokens
)
):
model_group_info.max_output_tokens = model_info["max_output_tokens"]
if model_info.get("input_cost_per_token", None) is not None and (
model_group_info.input_cost_per_token is None
or model_info["input_cost_per_token"]
> model_group_info.input_cost_per_token
):
model_group_info.input_cost_per_token = model_info[
"input_cost_per_token"
]
if model_info.get("output_cost_per_token", None) is not None and (
model_group_info.output_cost_per_token is None
or model_info["output_cost_per_token"]
> model_group_info.output_cost_per_token
):
model_group_info.output_cost_per_token = model_info[
"output_cost_per_token"
]
if (
model_info.get("supports_parallel_function_calling", None)
is not None
and model_info["supports_parallel_function_calling"] is True # type: ignore
):
model_group_info.supports_parallel_function_calling = True
if (
model_info.get("supports_vision", None) is not None
and model_info["supports_vision"] is True # type: ignore
):
model_group_info.supports_vision = True
if (
model_info.get("supports_function_calling", None) is not None
and model_info["supports_function_calling"] is True # type: ignore
):
model_group_info.supports_function_calling = True
if (
model_info.get("supports_web_search", None) is not None
and model_info["supports_web_search"] is True # type: ignore
):
model_group_info.supports_web_search = True
if (
model_info.get("supports_url_context", None) is not None
and model_info["supports_url_context"] is True # type: ignore
):
model_group_info.supports_url_context = True

# Track provider in set for O(1) deduplication
providers_set.add(llm_provider)

# if max_input_tokens > curr
# if max_output_tokens > curr
# if input_cost_per_token > curr
# if output_cost_per_token > curr
# supports_parallel_function_calling == True
# supports_vision == True
# supports_function_calling == True
if (
model_info.get("max_input_tokens", None) is not None
and model_info["max_input_tokens"] is not None
and (
model_group_info.max_input_tokens is None
or model_info["max_input_tokens"]
> model_group_info.max_input_tokens
)
):
model_group_info.max_input_tokens = model_info["max_input_tokens"]
if (
model_info.get("max_output_tokens", None) is not None
and model_info["max_output_tokens"] is not None
and (
model_group_info.max_output_tokens is None
or model_info["max_output_tokens"]
> model_group_info.max_output_tokens
)
):
model_group_info.max_output_tokens = model_info["max_output_tokens"]
if model_info.get("input_cost_per_token", None) is not None and (
model_group_info.input_cost_per_token is None
or model_info["input_cost_per_token"]
> model_group_info.input_cost_per_token
):
model_group_info.input_cost_per_token = model_info[
"input_cost_per_token"
]
if model_info.get("output_cost_per_token", None) is not None and (
model_group_info.output_cost_per_token is None
or model_info["output_cost_per_token"]
> model_group_info.output_cost_per_token
):
model_group_info.output_cost_per_token = model_info[
"output_cost_per_token"
]
if (
model_info.get("supports_parallel_function_calling", None)
is not None
and model_info["supports_parallel_function_calling"] is True # type: ignore
):
model_group_info.supports_parallel_function_calling = True
if (
model_info.get("supports_vision", None) is not None
and model_info["supports_vision"] is True # type: ignore
):
model_group_info.supports_vision = True
if (
model_info.get("supports_function_calling", None) is not None
and model_info["supports_function_calling"] is True # type: ignore
):
model_group_info.supports_function_calling = True
if (
model_info.get("supports_web_search", None) is not None
and model_info["supports_web_search"] is True # type: ignore
):
model_group_info.supports_web_search = True
if (
model_info.get("supports_url_context", None) is not None
and model_info["supports_url_context"] is True # type: ignore
):
model_group_info.supports_url_context = True

if (
model_info.get("supports_reasoning", None) is not None
and model_info["supports_reasoning"] is True # type: ignore
):
model_group_info.supports_reasoning = True
if (
model_info.get("supported_openai_params", None) is not None
and model_info["supported_openai_params"] is not None
):
model_group_info.supported_openai_params = model_info[
"supported_openai_params"
]
if model_info.get("tpm", None) is not None and _deployment_tpm is None:
_deployment_tpm = model_info.get("tpm")
if model_info.get("rpm", None) is not None and _deployment_rpm is None:
_deployment_rpm = model_info.get("rpm")
if (
model_info.get("supports_reasoning", None) is not None
and model_info["supports_reasoning"] is True # type: ignore
):
model_group_info.supports_reasoning = True
if (
model_info.get("supported_openai_params", None) is not None
and model_info["supported_openai_params"] is not None
):
model_group_info.supported_openai_params = model_info[
"supported_openai_params"
]
if model_info.get("tpm", None) is not None and _deployment_tpm is None:
_deployment_tpm = model_info.get("tpm")
if model_info.get("rpm", None) is not None and _deployment_rpm is None:
_deployment_rpm = model_info.get("rpm")

if _deployment_tpm is not None:
if total_tpm is None:
@@ -5896,6 +5899,9 @@
total_rpm = 0
total_rpm += _deployment_rpm # type: ignore
if model_group_info is not None:
## UPDATE WITH PROVIDERS FROM SET (convert to list)
model_group_info.providers = list(providers_set)

## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
if total_tpm is not None:
model_group_info.tpm = total_tpm
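
The router.py change replaces the per-deployment membership check on a list (llm_provider not in model_group_info.providers, an O(n) scan per iteration) with a set that is filled inside the loop and assigned back as a list once after it. A minimal standalone sketch of that pattern, using simplified stand-in data (the real loop also aggregates token limits, costs, and capability flags):

from typing import Dict, List


def collect_providers(deployments: List[Dict[str, str]]) -> List[str]:
    """Deduplicate providers across deployments with a set (O(1) membership)."""
    providers_set: set = set()
    for deployment in deployments:
        provider = deployment.get("litellm_provider")
        if provider is not None:
            # O(1) add instead of scanning a list on every iteration.
            providers_set.add(provider)
    # Convert back to a list for the ModelGroupInfo-style "providers" field.
    return list(providers_set)


print(collect_providers([
    {"litellm_provider": "openai"},
    {"litellm_provider": "azure"},
    {"litellm_provider": "openai"},
]))  # e.g. ['azure', 'openai'] -- order is not guaranteed

One trade-off to be aware of: a plain set does not preserve insertion order, so list(providers_set) may list providers in a different order than the deployments appear; if a stable order matters, dict.fromkeys(...) offers ordered deduplication at the same asymptotic cost.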