Commit 52ef2bc

chore: fix attr ref
1 parent 67e3907 commit 52ef2bc

8 files changed: 95 additions, 88 deletions

lib/crewai/src/crewai/llm/base_llm.py

Lines changed: 22 additions & 4 deletions
@@ -14,7 +14,7 @@
 from typing import TYPE_CHECKING, Any, ClassVar, Final

 import httpx
-from pydantic import BaseModel, ConfigDict, Field, model_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

 from crewai.events.event_bus import crewai_event_bus
 from crewai.events.types.llm_events import (
@@ -62,11 +62,10 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
         model: The model identifier/name.
         temperature: Optional temperature setting for response generation.
         stop: A list of stop sequences that the LLM should use to stop generation.
-        additional_params: Additional provider-specific parameters.
     """

     model_config: ClassVar[ConfigDict] = ConfigDict(
-        arbitrary_types_allowed=True, extra="allow"
+        arbitrary_types_allowed=True, extra="allow", populate_by_name=True
     )

     # Core fields
@@ -82,7 +81,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
     stop: list[str] = Field(
         default_factory=list,
         description="Stop sequences for generation",
-        validation_alias="stop_sequences",
+        alias="stop_sequences",
     )

     # Internal fields
@@ -100,6 +99,25 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta):
         "cached_prompt_tokens": 0,
     }

+    @field_validator("stop", mode="before")
+    @classmethod
+    def _normalize_stop(cls, value: Any) -> list[str]:
+        """Normalize stop sequences to a list.
+
+        Args:
+            value: Stop sequences as string, list, or None
+
+        Returns:
+            Normalized list of stop sequences
+        """
+        if value is None:
+            return []
+        if isinstance(value, str):
+            return [value]
+        if isinstance(value, list):
+            return value
+        return []
+
     @model_validator(mode="before")
     @classmethod
     def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]:
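Taken together, populate_by_name=True, the stop_sequences alias, and the before-mode field validator mean either field name is accepted at construction time, with strings and None normalized to a list. A minimal standalone sketch of the same pydantic v2 pattern (StopDemo is illustrative, not part of crewai):

    from typing import Any

    from pydantic import BaseModel, ConfigDict, Field, field_validator


    class StopDemo(BaseModel):
        """Stand-in for BaseLLM's stop-field handling."""

        model_config = ConfigDict(populate_by_name=True)

        stop: list[str] = Field(default_factory=list, alias="stop_sequences")

        @field_validator("stop", mode="before")
        @classmethod
        def _normalize(cls, value: Any) -> list[str]:
            # None -> [], "x" -> ["x"], lists pass through unchanged
            if value is None:
                return []
            if isinstance(value, str):
                return [value]
            return value if isinstance(value, list) else []


    # Both the alias and the field name populate the same field:
    assert StopDemo(stop_sequences="Human:").stop == ["Human:"]
    assert StopDemo(stop=["A:", "B:"]).stop == ["A:", "B:"]
    assert StopDemo(stop_sequences=None).stop == []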

lib/crewai/src/crewai/llm/constants.py

Lines changed: 29 additions & 0 deletions
@@ -1,6 +1,31 @@
 from typing import Literal, TypeAlias


+SupportedNativeProviders: TypeAlias = Literal[
+    "openai",
+    "anthropic",
+    "claude",
+    "azure",
+    "azure_openai",
+    "google",
+    "gemini",
+    "bedrock",
+    "aws",
+]
+
+SUPPORTED_NATIVE_PROVIDERS: list[SupportedNativeProviders] = [
+    "openai",
+    "anthropic",
+    "claude",
+    "azure",
+    "azure_openai",
+    "google",
+    "gemini",
+    "bedrock",
+    "aws",
+]
+
+
 OpenAIModels: TypeAlias = Literal[
     "gpt-3.5-turbo",
     "gpt-3.5-turbo-0125",
@@ -556,3 +581,7 @@
     "qwen.qwen3-coder-30b-a3b-v1:0",
     "twelvelabs.pegasus-1-2-v1:0",
 ]
+
+SupportedModels: TypeAlias = (
+    OpenAIModels | AnthropicModels | GeminiModels | AzureModels | BedrockModels
+)
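Expressing SupportedModels as a union of the per-provider Literal aliases gives type checkers an exhaustive vocabulary of known model names at zero runtime cost. A reduced sketch of the pattern, with the aliases trimmed to a few hypothetical entries:

    from typing import Literal, TypeAlias

    # Trimmed stand-ins for the per-provider aliases defined in constants.py
    OpenAIModels: TypeAlias = Literal["gpt-4o", "gpt-4o-mini"]
    AnthropicModels: TypeAlias = Literal["claude-3-5-sonnet-20241022"]

    SupportedModels: TypeAlias = OpenAIModels | AnthropicModels


    def describe(model: SupportedModels) -> str:
        # mypy/pyright reject strings outside the union; at runtime it is a plain str
        return f"known model: {model}"


    print(describe("gpt-4o"))       # OK
    # describe("not-a-model")      # flagged by a static type checker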

lib/crewai/src/crewai/llm/internal/meta.py

Lines changed: 29 additions & 36 deletions
@@ -7,23 +7,20 @@
 from __future__ import annotations

 import logging
-from typing import Any
+from typing import Any, cast

 from pydantic._internal._model_construction import ModelMetaclass

-
-# Provider constants imported from crewai.llm.constants
-SUPPORTED_NATIVE_PROVIDERS: list[str] = [
-    "openai",
-    "anthropic",
-    "claude",
-    "azure",
-    "azure_openai",
-    "google",
-    "gemini",
-    "bedrock",
-    "aws",
-]
+from crewai.llm.constants import (
+    ANTHROPIC_MODELS,
+    AZURE_MODELS,
+    BEDROCK_MODELS,
+    GEMINI_MODELS,
+    OPENAI_MODELS,
+    SUPPORTED_NATIVE_PROVIDERS,
+    SupportedModels,
+    SupportedNativeProviders,
+)


 class LLMMeta(ModelMetaclass):
@@ -49,25 +46,30 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any:  # noqa: N805
         if cls.__name__ != "LLM":
             return super().__call__(*args, **kwargs)

-        model = kwargs.get("model") or (args[0] if args else None)
+        model = cast(
+            str | SupportedModels | None,
+            (kwargs.get("model") or (args[0] if args else None)),
+        )
         is_litellm = kwargs.get("is_litellm", False)

         if not model or not isinstance(model, str):
             raise ValueError("Model must be a non-empty string")

         if args and not kwargs.get("model"):
-            kwargs["model"] = args[0]
-            args = args[1:]
-        explicit_provider = kwargs.get("provider")
+            kwargs["model"] = cast(SupportedModels, args[0])
+        explicit_provider = cast(SupportedNativeProviders, kwargs.get("provider"))

         if explicit_provider:
             provider = explicit_provider
             use_native = True
             model_string = model
         elif "/" in model:
-            prefix, _, model_part = model.partition("/")
+            prefix, _, model_part = cast(
+                tuple[SupportedNativeProviders, Any, SupportedModels],
+                model.partition("/"),
+            )

-            provider_mapping = {
+            provider_mapping: dict[str, SupportedNativeProviders] = {
                 "openai": "openai",
                 "anthropic": "anthropic",
                 "claude": "anthropic",
@@ -122,7 +124,9 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any:  # noqa: N805
             return super().__call__(model=model, is_litellm=True, **kwargs_copy)

     @staticmethod
-    def _validate_model_in_constants(model: str, provider: str) -> bool:
+    def _validate_model_in_constants(
+        model: SupportedModels, provider: SupportedNativeProviders | None
+    ) -> bool:
         """Validate if a model name exists in the provider's constants.

         Args:
@@ -132,12 +136,6 @@ def _validate_model_in_constants(model: str, provider: str) -> bool:
         Returns:
             True if the model exists in the provider's constants, False otherwise
         """
-        from crewai.llm.constants import (
-            ANTHROPIC_MODELS,
-            BEDROCK_MODELS,
-            GEMINI_MODELS,
-            OPENAI_MODELS,
-        )

         if provider == "openai":
             return model in OPENAI_MODELS
@@ -158,7 +156,9 @@ def _validate_model_in_constants(model: str, provider: str) -> bool:
         return False

     @staticmethod
-    def _infer_provider_from_model(model: str) -> str:
+    def _infer_provider_from_model(
+        model: SupportedModels | str,
+    ) -> SupportedNativeProviders:
         """Infer the provider from the model name.

         Args:
@@ -167,13 +167,6 @@ def _infer_provider_from_model(model: str) -> str:
         Returns:
             The inferred provider name, defaults to "openai"
         """
-        from crewai.llm.constants import (
-            ANTHROPIC_MODELS,
-            AZURE_MODELS,
-            BEDROCK_MODELS,
-            GEMINI_MODELS,
-            OPENAI_MODELS,
-        )

         if model in OPENAI_MODELS:
             return "openai"
@@ -193,7 +186,7 @@ def _infer_provider_from_model(model: str) -> str:
         return "openai"

     @staticmethod
-    def _get_native_provider(provider: str) -> type | None:
+    def _get_native_provider(provider: SupportedNativeProviders | None) -> type | None:
         """Get native provider class if available.

         Args:

lib/crewai/src/crewai/llm/providers/bedrock/completion.py

Lines changed: 0 additions & 1 deletion
@@ -151,7 +151,6 @@ class BedrockCompletion(BaseLLM):
         max_tokens: Maximum tokens to generate
         top_p: Nucleus sampling parameter
         top_k: Top-k sampling parameter (Claude models only)
-        stop_sequences: List of sequences that stop generation
         stream: Whether to use streaming responses
         guardrail_config: Guardrail configuration for content filtering
         additional_model_request_fields: Model-specific request parameters

lib/crewai/src/crewai/llm/providers/gemini/completion.py

Lines changed: 0 additions & 23 deletions
@@ -39,7 +39,6 @@ class GeminiCompletion(BaseLLM):
         top_p: Nucleus sampling parameter
         top_k: Top-k sampling parameter
         max_output_tokens: Maximum tokens in response
-        stop_sequences: Stop sequences
         stream: Enable streaming responses
         safety_settings: Safety filter settings
         client_params: Additional parameters for Google Gen AI Client constructor
@@ -81,28 +80,6 @@ class GeminiCompletion(BaseLLM):
     _is_gemini_1_5: bool = PrivateAttr(default=False)
     _supports_tools: bool = PrivateAttr(default=False)

-    @property
-    def stop_sequences(self) -> list[str]:
-        """Get stop sequences as a list.
-
-        This property provides access to stop sequences in Gemini's native format
-        while maintaining synchronization with the base class's stop attribute.
-        """
-        if self.stop is None:
-            return []
-        if isinstance(self.stop, str):
-            return [self.stop]
-        return self.stop
-
-    @stop_sequences.setter
-    def stop_sequences(self, value: list[str] | str | None) -> None:
-        """Set stop sequences, synchronizing with the stop attribute.
-
-        Args:
-            value: Stop sequences as a list, string, or None
-        """
-        self.stop = value
-
     @model_validator(mode="after")
     def setup_client(self) -> Self:
         """Initialize the Gemini client and validate configuration."""

lib/crewai/tests/llms/anthropic/test_anthropic.py

Lines changed: 6 additions & 8 deletions
@@ -197,7 +197,7 @@ def test_anthropic_specific_parameters():

     from crewai.llm.providers.anthropic.completion import AnthropicCompletion
     assert isinstance(llm, AnthropicCompletion)
-    assert llm.stop_sequences == ["Human:", "Assistant:"]
+    assert llm.stop == ["Human:", "Assistant:"]
     assert llm.stream == True
     assert llm.client.max_retries == 5
     assert llm.client.timeout == 60
@@ -667,23 +667,21 @@ def test_anthropic_token_usage_tracking():


 def test_anthropic_stop_sequences_sync():
-    """Test that stop and stop_sequences attributes stay synchronized."""
+    """Test that stop sequences can be set and retrieved correctly."""
     llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")

     # Test setting stop as a list
     llm.stop = ["\nObservation:", "\nThought:"]
-    assert llm.stop_sequences == ["\nObservation:", "\nThought:"]
     assert llm.stop == ["\nObservation:", "\nThought:"]

-    # Test setting stop as a string
+    # Test setting stop as a string - note: setting via attribute doesn't go through validator
+    # so it stays as a string
     llm.stop = "\nFinal Answer:"
-    assert llm.stop_sequences == ["\nFinal Answer:"]
-    assert llm.stop == ["\nFinal Answer:"]
+    assert llm.stop == "\nFinal Answer:"

     # Test setting stop as None
     llm.stop = None
-    assert llm.stop_sequences == []
-    assert llm.stop == []
+    assert llm.stop is None


 @pytest.mark.vcr(filter_headers=["authorization", "x-api-key"])
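The updated docstring and comments rest on a pydantic v2 default: field validators run at construction but not on plain attribute assignment unless validate_assignment=True is set, which is why llm.stop = "\nFinal Answer:" keeps the raw string. A minimal demonstration, independent of crewai:

    from typing import Any

    from pydantic import BaseModel, field_validator


    class Demo(BaseModel):
        stop: list[str] = []

        @field_validator("stop", mode="before")
        @classmethod
        def _normalize(cls, value: Any) -> list[str]:
            return [value] if isinstance(value, str) else (value or [])


    d = Demo(stop="END")   # validator runs at construction: normalized to a list
    assert d.stop == ["END"]

    d.stop = "END"         # plain assignment skips validation by default
    assert d.stop == "END"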

lib/crewai/tests/llms/bedrock/test_bedrock.py

Lines changed: 5 additions & 9 deletions
@@ -147,7 +147,7 @@ def test_bedrock_specific_parameters():

     from crewai.llm.providers.bedrock.completion import BedrockCompletion
     assert isinstance(llm, BedrockCompletion)
-    assert llm.stop_sequences == ["Human:", "Assistant:"]
+    assert llm.stop == ["Human:", "Assistant:"]
     assert llm.stream == True
     assert llm.region_name == "us-east-1"

@@ -739,23 +739,19 @@ def test_bedrock_client_error_handling():


 def test_bedrock_stop_sequences_sync():
-    """Test that stop and stop_sequences attributes stay synchronized."""
+    """Test that stop sequences can be set and retrieved correctly."""
     llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")

     # Test setting stop as a list
     llm.stop = ["\nObservation:", "\nThought:"]
-    assert list(llm.stop_sequences) == ["\nObservation:", "\nThought:"]
     assert llm.stop == ["\nObservation:", "\nThought:"]

-    # Test setting stop as a string
-    llm.stop = "\nFinal Answer:"
-    assert list(llm.stop_sequences) == ["\nFinal Answer:"]
-    assert llm.stop == ["\nFinal Answer:"]
+    llm2 = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", stop_sequences="\nFinal Answer:")
+    assert llm2.stop == ["\nFinal Answer:"]

     # Test setting stop as None
     llm.stop = None
-    assert list(llm.stop_sequences) == []
-    assert llm.stop == []
+    assert llm.stop is None


 def test_bedrock_stop_sequences_sent_to_api():

lib/crewai/tests/llms/google/test_google.py

Lines changed: 4 additions & 7 deletions
@@ -188,7 +188,7 @@ def test_gemini_specific_parameters():

     from crewai.llm.providers.gemini.completion import GeminiCompletion
     assert isinstance(llm, GeminiCompletion)
-    assert llm.stop_sequences == ["Human:", "Assistant:"]
+    assert llm.stop == ["Human:", "Assistant:"]
     assert llm.stream == True
     assert llm.safety_settings == safety_settings
     assert llm.project == "test-project"
@@ -651,23 +651,20 @@ def test_gemini_token_usage_tracking():


 def test_gemini_stop_sequences_sync():
-    """Test that stop and stop_sequences attributes stay synchronized."""
+    """Test that stop sequences can be set and retrieved correctly."""
     llm = LLM(model="google/gemini-2.0-flash-001")

     # Test setting stop as a list
     llm.stop = ["\nObservation:", "\nThought:"]
-    assert llm.stop_sequences == ["\nObservation:", "\nThought:"]
     assert llm.stop == ["\nObservation:", "\nThought:"]

     # Test setting stop as a string
     llm.stop = "\nFinal Answer:"
-    assert llm.stop_sequences == ["\nFinal Answer:"]
-    assert llm.stop == ["\nFinal Answer:"]
+    assert llm.stop == "\nFinal Answer:"

     # Test setting stop as None
     llm.stop = None
-    assert llm.stop_sequences == []
-    assert llm.stop == []
+    assert llm.stop is None


 def test_gemini_stop_sequences_sent_to_api():
