77from __future__ import annotations
88
99import logging
10- from typing import Any
10+ from typing import Any , cast
1111
1212from pydantic ._internal ._model_construction import ModelMetaclass
1313
14-
15- # Provider constants imported from crewai.llm.constants
16- SUPPORTED_NATIVE_PROVIDERS : list [str ] = [
17- "openai" ,
18- "anthropic" ,
19- "claude" ,
20- "azure" ,
21- "azure_openai" ,
22- "google" ,
23- "gemini" ,
24- "bedrock" ,
25- "aws" ,
26- ]
14+ from crewai .llm .constants import (
15+ ANTHROPIC_MODELS ,
16+ AZURE_MODELS ,
17+ BEDROCK_MODELS ,
18+ GEMINI_MODELS ,
19+ OPENAI_MODELS ,
20+ SUPPORTED_NATIVE_PROVIDERS ,
21+ SupportedModels ,
22+ SupportedNativeProviders ,
23+ )
2724
2825
2926class LLMMeta (ModelMetaclass ):
@@ -49,25 +46,30 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805
4946 if cls .__name__ != "LLM" :
5047 return super ().__call__ (* args , ** kwargs )
5148
52- model = kwargs .get ("model" ) or (args [0 ] if args else None )
49+ model = cast (
50+ str | SupportedModels | None ,
51+ (kwargs .get ("model" ) or (args [0 ] if args else None )),
52+ )
5353 is_litellm = kwargs .get ("is_litellm" , False )
5454
5555 if not model or not isinstance (model , str ):
5656 raise ValueError ("Model must be a non-empty string" )
5757
5858 if args and not kwargs .get ("model" ):
59- kwargs ["model" ] = args [0 ]
60- args = args [1 :]
61- explicit_provider = kwargs .get ("provider" )
59+ kwargs ["model" ] = cast (SupportedModels , args [0 ])
60+ explicit_provider = cast (SupportedNativeProviders , kwargs .get ("provider" ))
6261
6362 if explicit_provider :
6463 provider = explicit_provider
6564 use_native = True
6665 model_string = model
6766 elif "/" in model :
68- prefix , _ , model_part = model .partition ("/" )
67+ prefix , _ , model_part = cast (
68+ tuple [SupportedNativeProviders , Any , SupportedModels ],
69+ model .partition ("/" ),
70+ )
6971
70- provider_mapping = {
72+ provider_mapping : dict [ str , SupportedNativeProviders ] = {
7173 "openai" : "openai" ,
7274 "anthropic" : "anthropic" ,
7375 "claude" : "anthropic" ,
@@ -122,7 +124,9 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805
122124 return super ().__call__ (model = model , is_litellm = True , ** kwargs_copy )
123125
124126 @staticmethod
125- def _validate_model_in_constants (model : str , provider : str ) -> bool :
127+ def _validate_model_in_constants (
128+ model : SupportedModels , provider : SupportedNativeProviders | None
129+ ) -> bool :
126130 """Validate if a model name exists in the provider's constants.
127131
128132 Args:
@@ -132,12 +136,6 @@ def _validate_model_in_constants(model: str, provider: str) -> bool:
132136 Returns:
133137 True if the model exists in the provider's constants, False otherwise
134138 """
135- from crewai .llm .constants import (
136- ANTHROPIC_MODELS ,
137- BEDROCK_MODELS ,
138- GEMINI_MODELS ,
139- OPENAI_MODELS ,
140- )
141139
142140 if provider == "openai" :
143141 return model in OPENAI_MODELS
@@ -158,7 +156,9 @@ def _validate_model_in_constants(model: str, provider: str) -> bool:
158156 return False
159157
160158 @staticmethod
161- def _infer_provider_from_model (model : str ) -> str :
159+ def _infer_provider_from_model (
160+ model : SupportedModels | str ,
161+ ) -> SupportedNativeProviders :
162162 """Infer the provider from the model name.
163163
164164 Args:
@@ -167,13 +167,6 @@ def _infer_provider_from_model(model: str) -> str:
167167 Returns:
168168 The inferred provider name, defaults to "openai"
169169 """
170- from crewai .llm .constants import (
171- ANTHROPIC_MODELS ,
172- AZURE_MODELS ,
173- BEDROCK_MODELS ,
174- GEMINI_MODELS ,
175- OPENAI_MODELS ,
176- )
177170
178171 if model in OPENAI_MODELS :
179172 return "openai"
@@ -193,7 +186,7 @@ def _infer_provider_from_model(model: str) -> str:
193186 return "openai"
194187
195188 @staticmethod
196- def _get_native_provider (provider : str ) -> type | None :
189+ def _get_native_provider (provider : SupportedNativeProviders | None ) -> type | None :
197190 """Get native provider class if available.
198191
199192 Args:
0 commit comments