Skip to content

Commit 8b83bf3

Browse files
chore: remove duplication in azure client
1 parent 93f1fbd commit 8b83bf3

File tree

22 files changed

+370
-209
lines changed

22 files changed

+370
-209
lines changed

lib/crewai/src/crewai/llm/base_llm.py

Lines changed: 6 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
import logging
1313
import os
1414
import re
15-
from typing import TYPE_CHECKING, Any, ClassVar, Final
15+
from typing import TYPE_CHECKING, Any, Final
1616

17+
from dotenv import load_dotenv
1718
import httpx
18-
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
19+
from pydantic import BaseModel, Field, field_validator
1920

2021
from crewai.events.event_bus import crewai_event_bus
2122
from crewai.events.types.llm_events import (
@@ -42,6 +43,8 @@
4243
from crewai.utilities.types import LLMMessage
4344

4445

46+
load_dotenv()
47+
4548
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
4649
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
4750
_JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)
@@ -65,10 +68,6 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
6568
stop: A list of stop sequences that the LLM should use to stop generation.
6669
"""
6770

68-
model_config: ClassVar[ConfigDict] = ConfigDict(
69-
extra="allow", populate_by_name=True
70-
)
71-
7271
# Core fields
7372
model: str = Field(..., description="The model identifier/name")
7473
temperature: float | None = Field(
@@ -100,7 +99,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
10099
"cached_prompt_tokens": 0,
101100
}
102101

103-
@field_validator("api_key", mode="before")
102+
@field_validator("api_key", mode="after")
104103
@classmethod
105104
def _validate_api_key(cls, value: str | None) -> str | None:
106105
"""Validate API key for authentication.
@@ -137,37 +136,6 @@ def _normalize_stop(cls, value: Any) -> list[str]:
137136
return value
138137
return []
139138

140-
@model_validator(mode="before")
141-
@classmethod
142-
def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]:
143-
"""Extract and normalize stop sequences before model initialization.
144-
145-
Args:
146-
values: Input values dictionary
147-
148-
Returns:
149-
Processed values dictionary
150-
"""
151-
if not values.get("model"):
152-
raise ValueError("Model name is required and cannot be empty")
153-
154-
stop = values.get("stop") or values.get("stop_sequences")
155-
if stop is None:
156-
values["stop"] = []
157-
elif isinstance(stop, str):
158-
values["stop"] = [stop]
159-
elif isinstance(stop, list):
160-
values["stop"] = stop
161-
else:
162-
values["stop"] = []
163-
164-
values.pop("stop_sequences", None)
165-
166-
if "provider" not in values or values["provider"] is None:
167-
values["provider"] = "openai"
168-
169-
return values
170-
171139
@property
172140
def additional_params(self) -> dict[str, Any]:
173141
"""Get additional parameters stored as extra fields.
@@ -190,20 +158,6 @@ def additional_params(self, value: dict[str, Any]) -> None:
190158
self.__pydantic_extra__ = {}
191159
self.__pydantic_extra__.update(value)
192160

193-
def model_post_init(self, __context: Any) -> None:
194-
"""Initialize token usage tracking after model initialization.
195-
196-
Args:
197-
__context: Pydantic context (unused)
198-
"""
199-
self._token_usage = {
200-
"total_tokens": 0,
201-
"prompt_tokens": 0,
202-
"completion_tokens": 0,
203-
"successful_requests": 0,
204-
"cached_prompt_tokens": 0,
205-
}
206-
207161
@abstractmethod
208162
def call(
209163
self,
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from crewai.llm.constants import SupportedNativeProviders
2+
3+
4+
PROVIDER_MAPPING: dict[str, SupportedNativeProviders] = {
5+
"openai": "openai",
6+
"anthropic": "anthropic",
7+
"claude": "anthropic",
8+
"azure": "azure",
9+
"azure_openai": "azure",
10+
"google": "gemini",
11+
"gemini": "gemini",
12+
"bedrock": "bedrock",
13+
"aws": "bedrock",
14+
}

lib/crewai/src/crewai/llm/internal/meta.py

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import logging
1010
from typing import Any, cast
1111

12+
from pydantic import ConfigDict
1213
from pydantic._internal._model_construction import ModelMetaclass
1314

1415
from crewai.llm.constants import (
@@ -21,6 +22,7 @@
2122
SupportedModels,
2223
SupportedNativeProviders,
2324
)
25+
from crewai.llm.internal.constants import PROVIDER_MAPPING
2426

2527

2628
class LLMMeta(ModelMetaclass):
@@ -30,6 +32,41 @@ class LLMMeta(ModelMetaclass):
3032
native provider implementation based on the model parameter.
3133
"""
3234

35+
def __new__(
36+
mcs,
37+
name: str,
38+
bases: tuple[type, ...],
39+
namespace: dict[str, Any],
40+
**kwargs: Any,
41+
) -> type:
42+
"""Create new LLM class with proper model_config for custom LLMs.
43+
44+
Args:
45+
name: Class name
46+
bases: Base classes
47+
namespace: Class namespace
48+
**kwargs: Additional arguments
49+
50+
Returns:
51+
New class
52+
"""
53+
if name != "BaseLLM" and any(
54+
base.__name__ in ("BaseLLM", "LLM") for base in bases
55+
):
56+
if "model_config" not in namespace:
57+
namespace["model_config"] = ConfigDict(
58+
extra="allow", populate_by_name=True
59+
)
60+
elif isinstance(namespace["model_config"], dict):
61+
config_dict = cast(
62+
ConfigDict, cast(object, dict(namespace["model_config"]))
63+
)
64+
config_dict.setdefault("extra", "allow")
65+
config_dict.setdefault("populate_by_name", True)
66+
namespace["model_config"] = ConfigDict(**config_dict)
67+
68+
return super().__new__(mcs, name, bases, namespace)
69+
3370
def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805
3471
"""Route to appropriate provider implementation at instantiation time.
3572
@@ -57,7 +94,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805
5794

5895
if args and not kwargs.get("model"):
5996
kwargs["model"] = cast(SupportedModels, args[0])
60-
args = args[1:]
97+
_ = args[1:]
6198
explicit_provider = cast(SupportedNativeProviders, kwargs.get("provider"))
6299

63100
if explicit_provider:
@@ -70,19 +107,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805
70107
model.partition("/"),
71108
)
72109

73-
provider_mapping: dict[str, SupportedNativeProviders] = {
74-
"openai": "openai",
75-
"anthropic": "anthropic",
76-
"claude": "anthropic",
77-
"azure": "azure",
78-
"azure_openai": "azure",
79-
"google": "gemini",
80-
"gemini": "gemini",
81-
"bedrock": "bedrock",
82-
"aws": "bedrock",
83-
}
84-
85-
canonical_provider = provider_mapping.get(prefix.lower())
110+
canonical_provider = PROVIDER_MAPPING.get(prefix.lower())
86111

87112
if canonical_provider and cls._validate_model_in_constants(
88113
model_part, canonical_provider

lib/crewai/src/crewai/llm/providers/anthropic/completion.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
from typing import TYPE_CHECKING, Any, cast
66

7+
from dotenv import load_dotenv
78
import httpx
89
from pydantic import BaseModel, Field, PrivateAttr, model_validator
910
from typing_extensions import Self
@@ -21,20 +22,24 @@
2122

2223

2324
if TYPE_CHECKING:
25+
from anthropic.types import Message
26+
2427
from crewai.agent.core import Agent
2528
from crewai.task import Task
2629

2730

2831
try:
2932
from anthropic import Anthropic
30-
from anthropic.types import Message
3133
from anthropic.types.tool_use_block import ToolUseBlock
3234
except ImportError:
3335
raise ImportError(
3436
'Anthropic native provider not available, to install: uv add "crewai[anthropic]"'
3537
) from None
3638

3739

40+
load_dotenv()
41+
42+
3843
class AnthropicCompletion(BaseLLM):
3944
"""Anthropic native completion implementation.
4045
@@ -66,19 +71,17 @@ class AnthropicCompletion(BaseLLM):
6671
top_p: float | None = Field(default=None, description="Nucleus sampling parameter")
6772
stream: bool = Field(default=False, description="Enable streaming responses")
6873
client_params: dict[str, Any] | None = Field(
69-
default=None, description="Additional Anthropic client parameters"
70-
)
71-
client: Anthropic = Field(
72-
default_factory=Anthropic, exclude=True, description="Anthropic client instance"
74+
default_factory=dict, description="Additional Anthropic client parameters"
7375
)
76+
_client: Anthropic = PrivateAttr(default=None) # type: ignore[assignment]
7477

7578
_is_claude_3: bool = PrivateAttr(default=False)
7679
_supports_tools: bool = PrivateAttr(default=False)
7780

7881
@model_validator(mode="after")
7982
def setup_client(self) -> Self:
8083
"""Initialize the Anthropic client and model-specific settings."""
81-
self.client = Anthropic(**self._get_client_params())
84+
self._client = Anthropic(**self._get_client_params())
8285

8386
self._is_claude_3 = "claude-3" in self.model.lower()
8487
self._supports_tools = self._is_claude_3
@@ -98,9 +101,6 @@ def supports_tools(self) -> bool:
98101
def _get_client_params(self) -> dict[str, Any]:
99102
"""Get client parameters."""
100103

101-
if self.api_key is None:
102-
raise ValueError("ANTHROPIC_API_KEY is required")
103-
104104
client_params = {
105105
"api_key": self.api_key,
106106
"base_url": self.base_url,
@@ -330,7 +330,7 @@ def _handle_completion(
330330
params["tool_choice"] = {"type": "tool", "name": "structured_output"}
331331

332332
try:
333-
response: Message = self.client.messages.create(**params)
333+
response: Message = self._client.messages.create(**params)
334334

335335
except Exception as e:
336336
if is_context_length_exceeded(e):
@@ -424,7 +424,7 @@ def _handle_streaming_completion(
424424
stream_params = {k: v for k, v in params.items() if k != "stream"}
425425

426426
# Make streaming API call
427-
with self.client.messages.stream(**stream_params) as stream:
427+
with self._client.messages.stream(**stream_params) as stream:
428428
for event in stream:
429429
if hasattr(event, "delta") and hasattr(event.delta, "text"):
430430
text_delta = event.delta.text
@@ -552,7 +552,7 @@ def _handle_tool_use_conversation(
552552

553553
try:
554554
# Send tool results back to Claude for final response
555-
final_response: Message = self.client.messages.create(**follow_up_params)
555+
final_response: Message = self._client.messages.create(**follow_up_params)
556556

557557
# Track token usage for follow-up call
558558
follow_up_usage = self._extract_anthropic_token_usage(final_response)

0 commit comments

Comments
 (0)