diff --git a/litellm/proxy/hooks/parallel_request_limiter_v3.py b/litellm/proxy/hooks/parallel_request_limiter_v3.py
index 0a49d7f67599..eda380b51650 100644
--- a/litellm/proxy/hooks/parallel_request_limiter_v3.py
+++ b/litellm/proxy/hooks/parallel_request_limiter_v3.py
@@ -20,6 +20,8 @@
     cast,
 )
 
+from fastapi import HTTPException
+
 from litellm import DualCache
 from litellm._logging import verbose_proxy_logger
 from litellm.integrations.custom_logger import CustomLogger
diff --git a/litellm/router.py b/litellm/router.py
index 9ae873164766..74e84b7502c7 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -5694,6 +5694,8 @@ def _set_model_group_info(  # noqa: PLR0915
         total_tpm: Optional[int] = None
         total_rpm: Optional[int] = None
         configurable_clientside_auth_params: CONFIGURABLE_CLIENTSIDE_AUTH_PARAMS = None
+        # Use set for O(1) deduplication
+        providers_set: set = set()
         model_list = self.get_model_list(model_name=model_group)
         if model_list is None:
             return None
@@ -5792,99 +5794,100 @@ def _set_model_group_info(  # noqa: PLR0915
                 model_group_info = ModelGroupInfo(  # type: ignore
                     **{
                         "model_group": user_facing_model_group_name,
-                        "providers": [llm_provider],
+                        "providers": [],  # Will be set from providers_set after loop
                         **model_info,
                     }
                 )
-            else:
-                # if max_input_tokens > curr
-                # if max_output_tokens > curr
-                # if input_cost_per_token > curr
-                # if output_cost_per_token > curr
-                # supports_parallel_function_calling == True
-                # supports_vision == True
-                # supports_function_calling == True
-                if llm_provider not in model_group_info.providers:
-                    model_group_info.providers.append(llm_provider)
-                if (
-                    model_info.get("max_input_tokens", None) is not None
-                    and model_info["max_input_tokens"] is not None
-                    and (
-                        model_group_info.max_input_tokens is None
-                        or model_info["max_input_tokens"]
-                        > model_group_info.max_input_tokens
-                    )
-                ):
-                    model_group_info.max_input_tokens = model_info["max_input_tokens"]
-                if (
-                    model_info.get("max_output_tokens", None) is not None
-                    and model_info["max_output_tokens"] is not None
-                    and (
-                        model_group_info.max_output_tokens is None
-                        or model_info["max_output_tokens"]
-                        > model_group_info.max_output_tokens
-                    )
-                ):
-                    model_group_info.max_output_tokens = model_info["max_output_tokens"]
-                if model_info.get("input_cost_per_token", None) is not None and (
-                    model_group_info.input_cost_per_token is None
-                    or model_info["input_cost_per_token"]
-                    > model_group_info.input_cost_per_token
-                ):
-                    model_group_info.input_cost_per_token = model_info[
-                        "input_cost_per_token"
-                    ]
-                if model_info.get("output_cost_per_token", None) is not None and (
-                    model_group_info.output_cost_per_token is None
-                    or model_info["output_cost_per_token"]
-                    > model_group_info.output_cost_per_token
-                ):
-                    model_group_info.output_cost_per_token = model_info[
-                        "output_cost_per_token"
-                    ]
-                if (
-                    model_info.get("supports_parallel_function_calling", None)
-                    is not None
-                    and model_info["supports_parallel_function_calling"] is True  # type: ignore
-                ):
-                    model_group_info.supports_parallel_function_calling = True
-                if (
-                    model_info.get("supports_vision", None) is not None
-                    and model_info["supports_vision"] is True  # type: ignore
-                ):
-                    model_group_info.supports_vision = True
-                if (
-                    model_info.get("supports_function_calling", None) is not None
-                    and model_info["supports_function_calling"] is True  # type: ignore
-                ):
-                    model_group_info.supports_function_calling = True
-                if (
-                    model_info.get("supports_web_search", None) is not None
-                    and model_info["supports_web_search"] is True  # type: ignore
-                ):
-                    model_group_info.supports_web_search = True
-                if (
-                    model_info.get("supports_url_context", None) is not None
-                    and model_info["supports_url_context"] is True  # type: ignore
-                ):
-                    model_group_info.supports_url_context = True
+
+            # Track provider in set for O(1) deduplication
+            providers_set.add(llm_provider)
+
+            # if max_input_tokens > curr
+            # if max_output_tokens > curr
+            # if input_cost_per_token > curr
+            # if output_cost_per_token > curr
+            # supports_parallel_function_calling == True
+            # supports_vision == True
+            # supports_function_calling == True
+            if (
+                model_info.get("max_input_tokens", None) is not None
+                and model_info["max_input_tokens"] is not None
+                and (
+                    model_group_info.max_input_tokens is None
+                    or model_info["max_input_tokens"]
+                    > model_group_info.max_input_tokens
+                )
+            ):
+                model_group_info.max_input_tokens = model_info["max_input_tokens"]
+            if (
+                model_info.get("max_output_tokens", None) is not None
+                and model_info["max_output_tokens"] is not None
+                and (
+                    model_group_info.max_output_tokens is None
+                    or model_info["max_output_tokens"]
+                    > model_group_info.max_output_tokens
+                )
+            ):
+                model_group_info.max_output_tokens = model_info["max_output_tokens"]
+            if model_info.get("input_cost_per_token", None) is not None and (
+                model_group_info.input_cost_per_token is None
+                or model_info["input_cost_per_token"]
+                > model_group_info.input_cost_per_token
+            ):
+                model_group_info.input_cost_per_token = model_info[
+                    "input_cost_per_token"
+                ]
+            if model_info.get("output_cost_per_token", None) is not None and (
+                model_group_info.output_cost_per_token is None
+                or model_info["output_cost_per_token"]
+                > model_group_info.output_cost_per_token
+            ):
+                model_group_info.output_cost_per_token = model_info[
+                    "output_cost_per_token"
+                ]
+            if (
+                model_info.get("supports_parallel_function_calling", None)
+                is not None
+                and model_info["supports_parallel_function_calling"] is True  # type: ignore
+            ):
+                model_group_info.supports_parallel_function_calling = True
+            if (
+                model_info.get("supports_vision", None) is not None
+                and model_info["supports_vision"] is True  # type: ignore
+            ):
+                model_group_info.supports_vision = True
+            if (
+                model_info.get("supports_function_calling", None) is not None
+                and model_info["supports_function_calling"] is True  # type: ignore
+            ):
+                model_group_info.supports_function_calling = True
+            if (
+                model_info.get("supports_web_search", None) is not None
+                and model_info["supports_web_search"] is True  # type: ignore
+            ):
+                model_group_info.supports_web_search = True
+            if (
+                model_info.get("supports_url_context", None) is not None
+                and model_info["supports_url_context"] is True  # type: ignore
+            ):
+                model_group_info.supports_url_context = True
-                if (
-                    model_info.get("supports_reasoning", None) is not None
-                    and model_info["supports_reasoning"] is True  # type: ignore
-                ):
-                    model_group_info.supports_reasoning = True
-                if (
-                    model_info.get("supported_openai_params", None) is not None
-                    and model_info["supported_openai_params"] is not None
-                ):
-                    model_group_info.supported_openai_params = model_info[
-                        "supported_openai_params"
-                    ]
-                if model_info.get("tpm", None) is not None and _deployment_tpm is None:
-                    _deployment_tpm = model_info.get("tpm")
-                if model_info.get("rpm", None) is not None and _deployment_rpm is None:
-                    _deployment_rpm = model_info.get("rpm")
+            if (
+                model_info.get("supports_reasoning", None) is not None
+                and model_info["supports_reasoning"] is True  # type: ignore
+            ):
+                model_group_info.supports_reasoning = True
+            if (
+                model_info.get("supported_openai_params", None) is not None
+                and model_info["supported_openai_params"] is not None
+            ):
+                model_group_info.supported_openai_params = model_info[
+                    "supported_openai_params"
+                ]
+            if model_info.get("tpm", None) is not None and _deployment_tpm is None:
+                _deployment_tpm = model_info.get("tpm")
+            if model_info.get("rpm", None) is not None and _deployment_rpm is None:
+                _deployment_rpm = model_info.get("rpm")
 
             if _deployment_tpm is not None:
                 if total_tpm is None:
@@ -5896,6 +5899,9 @@ def _set_model_group_info(  # noqa: PLR0915
                     total_rpm = 0
                 total_rpm += _deployment_rpm  # type: ignore
         if model_group_info is not None:
+            ## UPDATE WITH PROVIDERS FROM SET (convert to list)
+            model_group_info.providers = list(providers_set)
+
             ## UPDATE WITH TOTAL TPM/RPM FOR MODEL GROUP
             if total_tpm is not None:
                 model_group_info.tpm = total_tpm
diff --git a/tests/router_unit_tests/test_router_helper_utils.py b/tests/router_unit_tests/test_router_helper_utils.py
index 094df944bcc7..603429468a6b 100644
--- a/tests/router_unit_tests/test_router_helper_utils.py
+++ b/tests/router_unit_tests/test_router_helper_utils.py
@@ -777,6 +777,57 @@ def test_set_model_group_info(model_list, user_facing_model_group_name):
     assert resp.model_group == user_facing_model_group_name
 
 
+def test_set_model_group_info_providers_deduplication():
+    """Test that providers_set correctly deduplicates providers in model group info
+
+    This test verifies the optimization where a set is used for O(1) deduplication
+    instead of checking if provider exists in list before appending.
+    """
+    # Create a model list with multiple deployments using the same provider (openai)
+    model_list_with_duplicates = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "gpt-3.5-turbo",
+                "api_key": "test-key-1",
+            },
+        },
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "gpt-3.5-turbo",
+                "api_key": "test-key-2",
+            },
+        },
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "azure/gpt-35-turbo",
+                "api_key": "test-azure-key",
+                "api_base": "https://test.openai.azure.com/",
+            },
+        },
+    ]
+
+    router = Router(model_list=model_list_with_duplicates)
+    resp = router._set_model_group_info(
+        model_group="gpt-3.5-turbo",
+        user_facing_model_group_name="gpt-3.5-turbo",
+    )
+
+    assert resp is not None
+    assert resp.model_group == "gpt-3.5-turbo"
+
+    # Verify providers list contains unique providers only
+    assert resp.providers is not None
+    assert len(resp.providers) == 2  # openai and azure (deduplicated)
+    assert "openai" in resp.providers
+    assert "azure" in resp.providers
+
+    # Verify no duplicates in the providers list
+    assert len(resp.providers) == len(set(resp.providers))
+
+
 @pytest.mark.asyncio
 async def test_set_response_headers(model_list):
     """Test if the 'set_response_headers' function is working correctly"""