From 46785adf586222b279337dad5b97e96dfdb23f71 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Mon, 10 Nov 2025 14:22:09 -0500 Subject: [PATCH 1/7] chore: refactor llms to base models --- docs/en/concepts/llms.mdx | 2 +- lib/crewai/src/crewai/__init__.py | 2 +- lib/crewai/src/crewai/agent/core.py | 37 ++- .../crewai/agents/agent_builder/base_agent.py | 6 +- .../src/crewai/agents/crew_agent_executor.py | 2 +- lib/crewai/src/crewai/crew.py | 2 +- lib/crewai/src/crewai/lite_agent.py | 7 +- lib/crewai/src/crewai/llm/__init__.py | 4 + .../src/crewai/{llms => llm}/base_llm.py | 133 +++++--- .../src/crewai/{llms => llm}/constants.py | 0 lib/crewai/src/crewai/{llm.py => llm/core.py} | 294 +++--------------- .../crewai/{llms => llm}/hooks/__init__.py | 2 +- .../src/crewai/{llms => llm}/hooks/base.py | 0 .../crewai/{llms => llm}/hooks/transport.py | 2 +- .../providers => llm/internal}/__init__.py | 0 lib/crewai/src/crewai/llm/internal/meta.py | 232 ++++++++++++++ .../anthropic => llm/providers}/__init__.py | 0 .../providers/anthropic}/__init__.py | 0 .../providers/anthropic/completion.py | 64 ++-- .../providers/azure}/__init__.py | 0 .../providers/azure/completion.py | 9 +- .../providers/bedrock}/__init__.py | 0 .../providers/bedrock/completion.py | 60 ++-- .../providers/gemini}/__init__.py | 0 .../providers/gemini/completion.py | 61 ++-- .../providers/openai}/__init__.py | 0 .../providers/openai/completion.py | 136 ++++---- .../crewai/llm/providers/utils/__init__.py | 0 .../{llms => llm}/providers/utils/common.py | 0 .../src/crewai/llms/third_party/__init__.py | 1 - lib/crewai/src/crewai/tasks/llm_guardrail.py | 2 +- .../src/crewai/utilities/agent_utils.py | 2 +- lib/crewai/src/crewai/utilities/converter.py | 2 +- .../evaluators/crew_evaluator_handler.py | 2 +- .../crewai/utilities/internal_instructor.py | 2 +- lib/crewai/src/crewai/utilities/llm_utils.py | 2 +- .../src/crewai/utilities/planning_handler.py | 2 +- lib/crewai/src/crewai/utilities/tool_utils.py | 2 +- lib/crewai/tests/agents/test_agent.py | 2 +- lib/crewai/tests/agents/test_lite_agent.py | 2 +- .../test_agent_function_calling_llm.yaml | 16 +- ...rmat_after_using_tools_too_many_times.yaml | 30 +- ...respect_the_max_rpm_set_over_crew_rpm.yaml | 12 +- ...r_context_for_first_task_hierarchical.yaml | 24 +- ...est_sets_parent_flow_when_inside_flow.yaml | 16 +- ...wer_is_the_final_answer_for_the_agent.yaml | 6 +- .../tests/llms/anthropic/test_anthropic.py | 22 +- lib/crewai/tests/llms/azure/test_azure.py | 48 +-- lib/crewai/tests/llms/bedrock/test_bedrock.py | 16 +- lib/crewai/tests/llms/google/test_google.py | 22 +- .../llms/hooks/test_anthropic_interceptor.py | 2 +- .../tests/llms/hooks/test_base_interceptor.py | 2 +- .../llms/hooks/test_openai_interceptor.py | 2 +- lib/crewai/tests/llms/hooks/test_transport.py | 4 +- .../llms/hooks/test_unsupported_providers.py | 2 +- lib/crewai/tests/llms/openai/test_openai.py | 10 +- lib/crewai/tests/test_custom_llm.py | 2 +- lib/crewai/tests/utilities/test_events.py | 2 +- lib/crewai/tests/utilities/test_llm_utils.py | 2 +- logs.txt | 20 ++ 60 files changed, 715 insertions(+), 621 deletions(-) create mode 100644 lib/crewai/src/crewai/llm/__init__.py rename lib/crewai/src/crewai/{llms => llm}/base_llm.py (84%) rename lib/crewai/src/crewai/{llms => llm}/constants.py (100%) rename lib/crewai/src/crewai/{llm.py => llm/core.py} (86%) rename lib/crewai/src/crewai/{llms => llm}/hooks/__init__.py (58%) rename lib/crewai/src/crewai/{llms => llm}/hooks/base.py (100%) rename lib/crewai/src/crewai/{llms => 
llm}/hooks/transport.py (98%) rename lib/crewai/src/crewai/{llms/providers => llm/internal}/__init__.py (100%) create mode 100644 lib/crewai/src/crewai/llm/internal/meta.py rename lib/crewai/src/crewai/{llms/providers/anthropic => llm/providers}/__init__.py (100%) rename lib/crewai/src/crewai/{llms/providers/azure => llm/providers/anthropic}/__init__.py (100%) rename lib/crewai/src/crewai/{llms => llm}/providers/anthropic/completion.py (94%) rename lib/crewai/src/crewai/{llms/providers/bedrock => llm/providers/azure}/__init__.py (100%) rename lib/crewai/src/crewai/{llms => llm}/providers/azure/completion.py (98%) rename lib/crewai/src/crewai/{llms/providers/gemini => llm/providers/bedrock}/__init__.py (100%) rename lib/crewai/src/crewai/{llms => llm}/providers/bedrock/completion.py (96%) rename lib/crewai/src/crewai/{llms/providers/openai => llm/providers/gemini}/__init__.py (100%) rename lib/crewai/src/crewai/{llms => llm}/providers/gemini/completion.py (94%) rename lib/crewai/src/crewai/{llms/providers/utils => llm/providers/openai}/__init__.py (100%) rename lib/crewai/src/crewai/{llms => llm}/providers/openai/completion.py (87%) create mode 100644 lib/crewai/src/crewai/llm/providers/utils/__init__.py rename lib/crewai/src/crewai/{llms => llm}/providers/utils/common.py (100%) delete mode 100644 lib/crewai/src/crewai/llms/third_party/__init__.py create mode 100644 logs.txt diff --git a/docs/en/concepts/llms.mdx b/docs/en/concepts/llms.mdx index fabf27aaa9..fa080bc3e2 100644 --- a/docs/en/concepts/llms.mdx +++ b/docs/en/concepts/llms.mdx @@ -1212,7 +1212,7 @@ Learn how to get the most out of your LLM configuration: ```python import httpx from crewai import LLM -from crewai.llms.hooks import BaseInterceptor +from crewai.llm.hooks import BaseInterceptor class CustomInterceptor(BaseInterceptor[httpx.Request, httpx.Response]): """Custom interceptor to modify requests and responses.""" diff --git a/lib/crewai/src/crewai/__init__.py b/lib/crewai/src/crewai/__init__.py index c992f11f71..ef2bcf78d2 100644 --- a/lib/crewai/src/crewai/__init__.py +++ b/lib/crewai/src/crewai/__init__.py @@ -9,7 +9,7 @@ from crewai.flow.flow import Flow from crewai.knowledge.knowledge import Knowledge from crewai.llm import LLM -from crewai.llms.base_llm import BaseLLM +from crewai.llm.base_llm import BaseLLM from crewai.process import Process from crewai.task import Task from crewai.tasks.llm_guardrail import LLMGuardrail diff --git a/lib/crewai/src/crewai/agent/core.py b/lib/crewai/src/crewai/agent/core.py index 1d94c4d19a..7c2b96f71f 100644 --- a/lib/crewai/src/crewai/agent/core.py +++ b/lib/crewai/src/crewai/agent/core.py @@ -39,7 +39,7 @@ from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context from crewai.lite_agent import LiteAgent -from crewai.llms.base_llm import BaseLLM +from crewai.llm.base_llm import BaseLLM from crewai.mcp import ( MCPClient, MCPServerConfig, @@ -626,7 +626,7 @@ def create_agent_executor( ) self.agent_executor = CrewAgentExecutor( - llm=self.llm, + llm=self.llm, # type: ignore[arg-type] task=task, # type: ignore[arg-type] agent=self, crew=self.crew, @@ -803,6 +803,7 @@ def _get_native_mcp_tools( from crewai.tools.base_tool import BaseTool from crewai.tools.mcp_native_tool import MCPNativeTool + transport: StdioTransport | HTTPTransport | SSETransport if isinstance(mcp_config, MCPServerStdio): transport = StdioTransport( command=mcp_config.command, @@ -896,10 +897,12 @@ async def 
_setup_client_and_list_tools() -> list[dict[str, Any]]: server_name=server_name, run_context=None, ) - if mcp_config.tool_filter(context, tool): + # Try new signature first + if mcp_config.tool_filter(context, tool): # type: ignore[arg-type,call-arg] filtered_tools.append(tool) except (TypeError, AttributeError): - if mcp_config.tool_filter(tool): + # Fallback to old signature + if mcp_config.tool_filter(tool): # type: ignore[arg-type,call-arg] filtered_tools.append(tool) else: # Not callable - include tool @@ -974,7 +977,9 @@ def _extract_server_name(server_url: str) -> str: path = parsed.path.replace("/", "_").strip("_") return f"{domain}_{path}" if path else domain - def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]: + def _get_mcp_tool_schemas( + self, server_params: dict[str, Any] + ) -> dict[str, dict[str, Any]]: """Get tool schemas from MCP server for wrapper creation with caching.""" server_url = server_params["url"] @@ -988,7 +993,7 @@ def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]: self._logger.log( "debug", f"Using cached MCP tool schemas for {server_url}" ) - return cached_data + return cast(dict[str, dict[str, Any]], cached_data) try: schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params)) @@ -1006,7 +1011,7 @@ def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]: async def _get_mcp_tool_schemas_async( self, server_params: dict[str, Any] - ) -> dict[str, dict]: + ) -> dict[str, dict[str, Any]]: """Async implementation of MCP tool schema retrieval with timeouts and retries.""" server_url = server_params["url"] return await self._retry_mcp_discovery( @@ -1014,7 +1019,7 @@ async def _get_mcp_tool_schemas_async( ) async def _retry_mcp_discovery( - self, operation_func, server_url: str + self, operation_func: Any, server_url: str ) -> dict[str, dict[str, Any]]: """Retry MCP discovery operation with exponential backoff, avoiding try-except in loop.""" last_error = None @@ -1045,7 +1050,7 @@ async def _retry_mcp_discovery( @staticmethod async def _attempt_mcp_discovery( - operation_func, server_url: str + operation_func: Any, server_url: str ) -> tuple[dict[str, dict[str, Any]] | None, str, bool]: """Attempt single MCP discovery operation and return (result, error_message, should_retry).""" try: @@ -1149,13 +1154,13 @@ def _json_schema_to_pydantic( Field(..., description=field_description), ) else: - field_definitions[field_name] = ( + field_definitions[field_name] = ( # type: ignore[assignment] field_type | None, Field(default=None, description=field_description), ) model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema" - return create_model(model_name, **field_definitions) + return create_model(model_name, **field_definitions) # type: ignore[no-any-return,call-overload] def _json_type_to_python(self, field_schema: dict[str, Any]) -> type: """Convert JSON Schema type to Python type. 
@@ -1175,16 +1180,16 @@ def _json_type_to_python(self, field_schema: dict[str, Any]) -> type: if "const" in option: types.append(str) else: - types.append(self._json_type_to_python(option)) + types.append(self._json_type_to_python(option)) # type: ignore[arg-type] unique_types = list(set(types)) if len(unique_types) > 1: result = unique_types[0] for t in unique_types[1:]: - result = result | t + result = result | t # type: ignore[assignment] return result return unique_types[0] - type_mapping = { + type_mapping: dict[str, type] = { "string": str, "number": float, "integer": int, @@ -1193,10 +1198,10 @@ def _json_type_to_python(self, field_schema: dict[str, Any]) -> type: "object": dict, } - return type_mapping.get(json_type, Any) + return type_mapping.get(json_type or "", Any) @staticmethod - def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]: + def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]: """Fetch MCP server configurations from CrewAI AMP API.""" # TODO: Implement AMP API call to "integrations/mcps" endpoint # Should return list of server configs with URLs diff --git a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py index 932c98611f..33d137220c 100644 --- a/lib/crewai/src/crewai/agents/agent_builder/base_agent.py +++ b/lib/crewai/src/crewai/agents/agent_builder/base_agent.py @@ -137,7 +137,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): default=False, description="Enable agent to delegate and ask questions among each other.", ) - tools: list[BaseTool] | None = Field( + tools: list[BaseTool] = Field( default_factory=list, description="Tools at agents' disposal" ) max_iter: int = Field( @@ -161,7 +161,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta): description="An instance of the ToolsHandler class.", ) tools_results: list[dict[str, Any]] = Field( - default=[], description="Results of the tools used by the agent." + default_factory=list, description="Results of the tools used by the agent." ) max_tokens: int | None = Field( default=None, description="Maximum number of tokens for the agent's execution." 
@@ -265,7 +265,7 @@ def validate_mcps( if not mcps: return mcps - validated_mcps = [] + validated_mcps: list[str | MCPServerConfig] = [] for mcp in mcps: if isinstance(mcp, str): if mcp.startswith(("https://", "crewai-amp:")): diff --git a/lib/crewai/src/crewai/agents/crew_agent_executor.py b/lib/crewai/src/crewai/agents/crew_agent_executor.py index 5b806658ca..8cf3c1e12c 100644 --- a/lib/crewai/src/crewai/agents/crew_agent_executor.py +++ b/lib/crewai/src/crewai/agents/crew_agent_executor.py @@ -47,7 +47,7 @@ from crewai.agent import Agent from crewai.agents.tools_handler import ToolsHandler from crewai.crew import Crew - from crewai.llms.base_llm import BaseLLM + from crewai.llm.base_llm import BaseLLM from crewai.task import Task from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index f5af4a426a..31eb7466ca 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -57,7 +57,7 @@ from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.llm import LLM -from crewai.llms.base_llm import BaseLLM +from crewai.llm.base_llm import BaseLLM from crewai.memory.entity.entity_memory import EntityMemory from crewai.memory.external.external_memory import ExternalMemory from crewai.memory.long_term.long_term_memory import LongTermMemory diff --git a/lib/crewai/src/crewai/lite_agent.py b/lib/crewai/src/crewai/lite_agent.py index 4314e900e8..ef877e01b6 100644 --- a/lib/crewai/src/crewai/lite_agent.py +++ b/lib/crewai/src/crewai/lite_agent.py @@ -40,7 +40,7 @@ from crewai.flow.flow_trackable import FlowTrackable from crewai.lite_agent_output import LiteAgentOutput from crewai.llm import LLM -from crewai.llms.base_llm import BaseLLM +from crewai.llm.base_llm import BaseLLM from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool from crewai.utilities.agent_utils import ( @@ -503,7 +503,7 @@ def _invoke_loop(self) -> AgentFinish: AgentFinish: The final result of the agent execution. """ # Execute the agent loop - formatted_answer = None + formatted_answer: AgentAction | AgentFinish | None = None while not isinstance(formatted_answer, AgentFinish): try: if has_reached_max_iterations(self._iterations, self.max_iterations): @@ -551,7 +551,8 @@ def _invoke_loop(self) -> AgentFinish: show_logs=self._show_logs, ) - self._append_message(formatted_answer.text, role="assistant") + if formatted_answer is not None: + self._append_message(formatted_answer.text, role="assistant") except OutputParserError as e: # noqa: PERF203 self._printer.print( content="Failed to parse LLM output. 
Retrying...", diff --git a/lib/crewai/src/crewai/llm/__init__.py b/lib/crewai/src/crewai/llm/__init__.py new file mode 100644 index 0000000000..57cb57069c --- /dev/null +++ b/lib/crewai/src/crewai/llm/__init__.py @@ -0,0 +1,4 @@ +from crewai.llm.core import LLM + + +__all__ = ["LLM"] diff --git a/lib/crewai/src/crewai/llms/base_llm.py b/lib/crewai/src/crewai/llm/base_llm.py similarity index 84% rename from lib/crewai/src/crewai/llms/base_llm.py rename to lib/crewai/src/crewai/llm/base_llm.py index a7026c5c53..1222763457 100644 --- a/lib/crewai/src/crewai/llms/base_llm.py +++ b/lib/crewai/src/crewai/llm/base_llm.py @@ -11,9 +11,10 @@ import json import logging import re -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, ClassVar, Final -from pydantic import BaseModel +import httpx +from pydantic import BaseModel, ConfigDict, Field, model_validator from crewai.events.event_bus import crewai_event_bus from crewai.events.types.llm_events import ( @@ -28,6 +29,8 @@ ToolUsageFinishedEvent, ToolUsageStartedEvent, ) +from crewai.llm.hooks.base import BaseInterceptor +from crewai.llm.internal.meta import LLMMeta from crewai.types.usage_metrics import UsageMetrics @@ -43,7 +46,7 @@ _JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL) -class BaseLLM(ABC): +class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): """Abstract base class for LLM implementations. This class defines the interface that all LLM implementations must follow. @@ -62,46 +65,96 @@ class BaseLLM(ABC): additional_params: Additional provider-specific parameters. """ - is_litellm: bool = False - - def __init__( - self, - model: str, - temperature: float | None = None, - api_key: str | None = None, - base_url: str | None = None, - provider: str | None = None, - **kwargs: Any, - ) -> None: - """Initialize the BaseLLM with default attributes. + model_config: ClassVar[ConfigDict] = ConfigDict( + arbitrary_types_allowed=True, extra="allow", validate_assignment=True + ) + + # Core fields + model: str = Field(..., description="The model identifier/name") + temperature: float | None = Field( + None, description="Temperature setting for response generation" + ) + api_key: str | None = Field(None, description="API key for authentication") + base_url: str | None = Field(None, description="Base URL for API requests") + provider: str = Field( + default="openai", description="Provider name (openai, anthropic, etc.)" + ) + stop: list[str] = Field( + default_factory=list, description="Stop sequences for generation" + ) + + # Internal fields + is_litellm: bool = Field( + default=False, description="Whether this instance uses LiteLLM" + ) + interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = Field( + None, description="HTTP request/response interceptor" + ) + _token_usage: dict[str, int] = { + "total_tokens": 0, + "prompt_tokens": 0, + "completion_tokens": 0, + "successful_requests": 0, + "cached_prompt_tokens": 0, + } + + @model_validator(mode="before") + @classmethod + def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]: + """Extract and normalize stop sequences before model initialization. Args: - model: The model identifier/name. - temperature: Optional temperature setting for response generation. - stop: Optional list of stop sequences for generation. - **kwargs: Additional provider-specific parameters. 
+ values: Input values dictionary + + Returns: + Processed values dictionary """ - if not model: + if not values.get("model"): raise ValueError("Model name is required and cannot be empty") - self.model = model - self.temperature = temperature - self.api_key = api_key - self.base_url = base_url - # Store additional parameters for provider-specific use - self.additional_params = kwargs - self._provider = provider or "openai" - - stop = kwargs.pop("stop", None) + # Handle stop sequences + stop = values.get("stop") if stop is None: - self.stop: list[str] = [] + values["stop"] = [] elif isinstance(stop, str): - self.stop = [stop] - elif isinstance(stop, list): - self.stop = stop - else: - self.stop = [] + values["stop"] = [stop] + elif not isinstance(stop, list): + values["stop"] = [] + + # Set default provider if not specified + if "provider" not in values or values["provider"] is None: + values["provider"] = "openai" + + return values + + @property + def additional_params(self) -> dict[str, Any]: + """Get additional parameters stored as extra fields. + + Returns: + Dictionary of additional parameters + """ + return self.__pydantic_extra__ or {} + + @additional_params.setter + def additional_params(self, value: dict[str, Any]) -> None: + """Set additional parameters as extra fields. + + Args: + value: Dictionary of additional parameters to set + """ + if not isinstance(value, dict): + raise ValueError("additional_params must be a dictionary") + if self.__pydantic_extra__ is None: + self.__pydantic_extra__ = {} + self.__pydantic_extra__.update(value) + def model_post_init(self, __context: Any) -> None: + """Initialize token usage tracking after model initialization. + + Args: + __context: Pydantic context (unused) + """ self._token_usage = { "total_tokens": 0, "prompt_tokens": 0, @@ -110,16 +163,6 @@ def __init__( "cached_prompt_tokens": 0, } - @property - def provider(self) -> str: - """Get the provider of the LLM.""" - return self._provider - - @provider.setter - def provider(self, value: str) -> None: - """Set the provider of the LLM.""" - self._provider = value - @abstractmethod def call( self, diff --git a/lib/crewai/src/crewai/llms/constants.py b/lib/crewai/src/crewai/llm/constants.py similarity index 100% rename from lib/crewai/src/crewai/llms/constants.py rename to lib/crewai/src/crewai/llm/constants.py diff --git a/lib/crewai/src/crewai/llm.py b/lib/crewai/src/crewai/llm/core.py similarity index 86% rename from lib/crewai/src/crewai/llm.py rename to lib/crewai/src/crewai/llm/core.py index b0cf420917..a386e2bbc4 100644 --- a/lib/crewai/src/crewai/llm.py +++ b/lib/crewai/src/crewai/llm/core.py @@ -20,9 +20,7 @@ ) from dotenv import load_dotenv -import httpx from pydantic import BaseModel, Field -from typing_extensions import Self from crewai.events.event_bus import crewai_event_bus from crewai.events.types.llm_events import ( @@ -37,14 +35,7 @@ ToolUsageFinishedEvent, ToolUsageStartedEvent, ) -from crewai.llms.base_llm import BaseLLM -from crewai.llms.constants import ( - ANTHROPIC_MODELS, - AZURE_MODELS, - BEDROCK_MODELS, - GEMINI_MODELS, - OPENAI_MODELS, -) +from crewai.llm.base_llm import BaseLLM from crewai.utilities import InternalInstructor from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, @@ -61,7 +52,6 @@ from litellm.utils import supports_response_schema from crewai.agent.core import Agent - from crewai.llms.hooks.base import BaseInterceptor from crewai.task import Task from crewai.tools.base_tool import BaseTool from 
crewai.utilities.types import LLMMessage @@ -327,249 +317,57 @@ class AccumulatedToolArgs(BaseModel): class LLM(BaseLLM): - completion_cost: float | None = None - - def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM: - """Factory method that routes to native SDK or falls back to LiteLLM. - - Routing priority: - 1. If 'provider' kwarg is present, use that provider with constants - 2. If only 'model' kwarg, use constants to infer provider - 3. If "/" in model name: - - Check if prefix is a native provider (openai/anthropic/azure/bedrock/gemini) - - If yes, validate model against constants - - If valid, route to native SDK; otherwise route to LiteLLM - """ - if not model or not isinstance(model, str): - raise ValueError("Model must be a non-empty string") - - explicit_provider = kwargs.get("provider") - - if explicit_provider: - provider = explicit_provider - use_native = True - model_string = model - elif "/" in model: - prefix, _, model_part = model.partition("/") - - provider_mapping = { - "openai": "openai", - "anthropic": "anthropic", - "claude": "anthropic", - "azure": "azure", - "azure_openai": "azure", - "google": "gemini", - "gemini": "gemini", - "bedrock": "bedrock", - "aws": "bedrock", - } - - canonical_provider = provider_mapping.get(prefix.lower()) - - if canonical_provider and cls._validate_model_in_constants( - model_part, canonical_provider - ): - provider = canonical_provider - use_native = True - model_string = model_part - else: - provider = prefix - use_native = False - model_string = model_part - else: - provider = cls._infer_provider_from_model(model) - use_native = True - model_string = model - - native_class = cls._get_native_provider(provider) if use_native else None - if native_class and not is_litellm and provider in SUPPORTED_NATIVE_PROVIDERS: - try: - # Remove 'provider' from kwargs if it exists to avoid duplicate keyword argument - kwargs_copy = {k: v for k, v in kwargs.items() if k != 'provider'} - return cast( - Self, native_class(model=model_string, provider=provider, **kwargs_copy) - ) - except NotImplementedError: - raise - except Exception as e: - raise ImportError(f"Error importing native provider: {e}") from e - - # FALLBACK to LiteLLM - if not LITELLM_AVAILABLE: - logger.error("LiteLLM is not available, falling back to LiteLLM") - raise ImportError("Fallback to LiteLLM is not available") from None - - instance = object.__new__(cls) - super(LLM, instance).__init__(model=model, is_litellm=True, **kwargs) - instance.is_litellm = True - return instance - - @classmethod - def _validate_model_in_constants(cls, model: str, provider: str) -> bool: - """Validate if a model name exists in the provider's constants. - - Args: - model: The model name to validate - provider: The provider to check against (canonical name) - - Returns: - True if the model exists in the provider's constants, False otherwise - """ - if provider == "openai": - return model in OPENAI_MODELS - - if provider == "anthropic" or provider == "claude": - return model in ANTHROPIC_MODELS - - if provider == "gemini": - return model in GEMINI_MODELS - - if provider == "bedrock": - return model in BEDROCK_MODELS - - if provider == "azure": - # azure does not provide a list of available models, determine a better way to handle this - return True - - return False + """LiteLLM-based LLM implementation for CrewAI. + + This class provides LiteLLM integration for models not covered by native providers. 
+ The metaclass (LLMMeta) automatically routes to native providers when appropriate. + """ + + # LiteLLM-specific fields + completion_cost: float | None = Field(None, description="Cost of completion") + timeout: float | int | None = Field(None, description="Request timeout") + top_p: float | None = Field(None, description="Top-p sampling parameter") + n: int | None = Field(None, description="Number of completions to generate") + max_completion_tokens: int | None = Field( + None, description="Maximum completion tokens" + ) + max_tokens: int | float | None = Field(None, description="Maximum total tokens") + presence_penalty: float | None = Field(None, description="Presence penalty") + frequency_penalty: float | None = Field(None, description="Frequency penalty") + logit_bias: dict[int, float] | None = Field(None, description="Logit bias") + response_format: type[BaseModel] | None = Field( + None, description="Response format model" + ) + seed: int | None = Field(None, description="Random seed for reproducibility") + logprobs: int | None = Field(None, description="Log probabilities to return") + top_logprobs: int | None = Field(None, description="Top log probabilities") + api_base: str | None = Field(None, description="API base URL (alias for base_url)") + api_version: str | None = Field(None, description="API version") + callbacks: list[Any] | None = Field(None, description="Callback functions") + context_window_size: int = Field(0, description="Context window size in tokens") + reasoning_effort: Literal["none", "low", "medium", "high"] | None = Field( + None, description="Reasoning effort level" + ) + is_anthropic: bool = Field(False, description="Whether model is from Anthropic") + stream: bool = Field(False, description="Whether to stream responses") - @classmethod - def _infer_provider_from_model(cls, model: str) -> str: - """Infer the provider from the model name. + def model_post_init(self, __context: Any) -> None: + """Initialize LiteLLM-specific settings after model initialization. 
Args: - model: The model name without provider prefix - - Returns: - The inferred provider name, defaults to "openai" + __context: Pydantic context """ + super().model_post_init(__context) - if model in OPENAI_MODELS: - return "openai" - - if model in ANTHROPIC_MODELS: - return "anthropic" - - if model in GEMINI_MODELS: - return "gemini" - - if model in BEDROCK_MODELS: - return "bedrock" - - if model in AZURE_MODELS: - return "azure" - - return "openai" - - @classmethod - def _get_native_provider(cls, provider: str) -> type | None: - """Get native provider class if available.""" - if provider == "openai": - from crewai.llms.providers.openai.completion import OpenAICompletion - - return OpenAICompletion - - if provider == "anthropic" or provider == "claude": - from crewai.llms.providers.anthropic.completion import ( - AnthropicCompletion, - ) - - return AnthropicCompletion - - if provider == "azure" or provider == "azure_openai": - from crewai.llms.providers.azure.completion import AzureCompletion + # Configure LiteLLM + if LITELLM_AVAILABLE: + litellm.drop_params = True - return AzureCompletion + # Determine if this is an Anthropic model + self.is_anthropic = self._is_anthropic_model(self.model) - if provider == "google" or provider == "gemini": - from crewai.llms.providers.gemini.completion import GeminiCompletion - - return GeminiCompletion - - if provider == "bedrock": - from crewai.llms.providers.bedrock.completion import BedrockCompletion - - return BedrockCompletion - - return None - - def __init__( - self, - model: str, - timeout: float | int | None = None, - temperature: float | None = None, - top_p: float | None = None, - n: int | None = None, - stop: str | list[str] | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | float | None = None, - presence_penalty: float | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[int, float] | None = None, - response_format: type[BaseModel] | None = None, - seed: int | None = None, - logprobs: int | None = None, - top_logprobs: int | None = None, - base_url: str | None = None, - api_base: str | None = None, - api_version: str | None = None, - api_key: str | None = None, - callbacks: list[Any] | None = None, - reasoning_effort: Literal["none", "low", "medium", "high"] | None = None, - stream: bool = False, - interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None, - **kwargs: Any, - ) -> None: - """Initialize LLM instance. - - Note: This __init__ method is only called for fallback instances. - Native provider instances handle their own initialization in their respective classes. 
- """ - super().__init__( - model=model, - temperature=temperature, - api_key=api_key, - base_url=base_url, - timeout=timeout, - **kwargs, - ) - self.model = model - self.timeout = timeout - self.temperature = temperature - self.top_p = top_p - self.n = n - self.max_completion_tokens = max_completion_tokens - self.max_tokens = max_tokens - self.presence_penalty = presence_penalty - self.frequency_penalty = frequency_penalty - self.logit_bias = logit_bias - self.response_format = response_format - self.seed = seed - self.logprobs = logprobs - self.top_logprobs = top_logprobs - self.base_url = base_url - self.api_base = api_base - self.api_version = api_version - self.api_key = api_key - self.callbacks = callbacks - self.context_window_size = 0 - self.reasoning_effort = reasoning_effort - self.additional_params = kwargs - self.is_anthropic = self._is_anthropic_model(model) - self.stream = stream - self.interceptor = interceptor - - litellm.drop_params = True - - # Normalize self.stop to always be a list[str] - if stop is None: - self.stop: list[str] = [] - elif isinstance(stop, str): - self.stop = [stop] - else: - self.stop = stop - - self.set_callbacks(callbacks or []) + # Set up callbacks + self.set_callbacks(self.callbacks or []) self.set_env_callbacks() @staticmethod @@ -1649,7 +1447,7 @@ def __copy__(self) -> LLM: **filtered_params, ) - def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM: + def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM: # type: ignore[override] """Create a deep copy of the LLM instance.""" import copy diff --git a/lib/crewai/src/crewai/llms/hooks/__init__.py b/lib/crewai/src/crewai/llm/hooks/__init__.py similarity index 58% rename from lib/crewai/src/crewai/llms/hooks/__init__.py rename to lib/crewai/src/crewai/llm/hooks/__init__.py index 2bbad217d6..5c949294eb 100644 --- a/lib/crewai/src/crewai/llms/hooks/__init__.py +++ b/lib/crewai/src/crewai/llm/hooks/__init__.py @@ -1,6 +1,6 @@ """Interceptor contracts for crewai""" -from crewai.llms.hooks.base import BaseInterceptor +from crewai.llm.hooks.base import BaseInterceptor __all__ = ["BaseInterceptor"] diff --git a/lib/crewai/src/crewai/llms/hooks/base.py b/lib/crewai/src/crewai/llm/hooks/base.py similarity index 100% rename from lib/crewai/src/crewai/llms/hooks/base.py rename to lib/crewai/src/crewai/llm/hooks/base.py diff --git a/lib/crewai/src/crewai/llms/hooks/transport.py b/lib/crewai/src/crewai/llm/hooks/transport.py similarity index 98% rename from lib/crewai/src/crewai/llms/hooks/transport.py rename to lib/crewai/src/crewai/llm/hooks/transport.py index 27a0972aba..db99a0d597 100644 --- a/lib/crewai/src/crewai/llms/hooks/transport.py +++ b/lib/crewai/src/crewai/llm/hooks/transport.py @@ -22,7 +22,7 @@ from httpx import Limits, Request, Response from httpx._types import CertTypes, ProxyTypes - from crewai.llms.hooks.base import BaseInterceptor + from crewai.llm.hooks.base import BaseInterceptor class HTTPTransportKwargs(TypedDict, total=False): diff --git a/lib/crewai/src/crewai/llms/providers/__init__.py b/lib/crewai/src/crewai/llm/internal/__init__.py similarity index 100% rename from lib/crewai/src/crewai/llms/providers/__init__.py rename to lib/crewai/src/crewai/llm/internal/__init__.py diff --git a/lib/crewai/src/crewai/llm/internal/meta.py b/lib/crewai/src/crewai/llm/internal/meta.py new file mode 100644 index 0000000000..4bf83b655a --- /dev/null +++ b/lib/crewai/src/crewai/llm/internal/meta.py @@ -0,0 +1,232 @@ +"""Metaclass for LLM provider routing. 
+ +This metaclass enables automatic routing to native provider implementations +based on the model parameter at instantiation time. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from pydantic._internal._model_construction import ModelMetaclass + + +# Provider constants imported from crewai.llm.constants +SUPPORTED_NATIVE_PROVIDERS: list[str] = [ + "openai", + "anthropic", + "claude", + "azure", + "azure_openai", + "google", + "gemini", + "bedrock", + "aws", +] + + +class LLMMeta(ModelMetaclass): + """Metaclass for LLM that handles provider routing. + + This metaclass intercepts LLM instantiation and routes to the appropriate + native provider implementation based on the model parameter. + """ + + def __call__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> Any: # noqa: N805 + """Route to appropriate provider implementation at instantiation time. + + Args: + model: The model identifier (e.g., "gpt-4", "claude-3-opus") + is_litellm: Force use of LiteLLM instead of native provider + **kwargs: Additional parameters for the LLM + + Returns: + Instance of the appropriate provider class or LLM class + + Raises: + ValueError: If model is not a valid string + """ + if not model or not isinstance(model, str): + raise ValueError("Model must be a non-empty string") + + # Only perform routing if called on the base LLM class + # Subclasses (OpenAICompletion, etc.) should create normally + from crewai.llm import LLM + + if cls is not LLM: + # Direct instantiation of provider class, skip routing + return super().__call__(model=model, **kwargs) + + # Extract provider information + explicit_provider = kwargs.get("provider") + + if explicit_provider: + provider = explicit_provider + use_native = True + model_string = model + elif "/" in model: + prefix, _, model_part = model.partition("/") + + provider_mapping = { + "openai": "openai", + "anthropic": "anthropic", + "claude": "anthropic", + "azure": "azure", + "azure_openai": "azure", + "google": "gemini", + "gemini": "gemini", + "bedrock": "bedrock", + "aws": "bedrock", + } + + canonical_provider = provider_mapping.get(prefix.lower()) + + if canonical_provider and cls._validate_model_in_constants( + model_part, canonical_provider + ): + provider = canonical_provider + use_native = True + model_string = model_part + else: + provider = prefix + use_native = False + model_string = model_part + else: + provider = cls._infer_provider_from_model(model) + use_native = True + model_string = model + + # Route to native provider if available + native_class = cls._get_native_provider(provider) if use_native else None + if native_class and not is_litellm and provider in SUPPORTED_NATIVE_PROVIDERS: + try: + # Remove 'provider' from kwargs to avoid duplicate keyword argument + kwargs_copy = {k: v for k, v in kwargs.items() if k != "provider"} + return native_class( + model=model_string, provider=provider, **kwargs_copy + ) + except NotImplementedError: + raise + except Exception as e: + raise ImportError(f"Error importing native provider: {e}") from e + + # Fallback to LiteLLM + try: + import litellm # noqa: F401 + except ImportError: + logging.error("LiteLLM is not available, falling back to LiteLLM") + raise ImportError("Fallback to LiteLLM is not available") from None + + # Create actual LLM instance with is_litellm=True + return super().__call__(model=model, is_litellm=True, **kwargs) + + @staticmethod + def _validate_model_in_constants(model: str, provider: str) -> bool: + """Validate if a model name exists in 
the provider's constants. + + Args: + model: The model name to validate + provider: The provider to check against (canonical name) + + Returns: + True if the model exists in the provider's constants, False otherwise + """ + from crewai.llm.constants import ( + ANTHROPIC_MODELS, + BEDROCK_MODELS, + GEMINI_MODELS, + OPENAI_MODELS, + ) + + if provider == "openai": + return model in OPENAI_MODELS + + if provider == "anthropic" or provider == "claude": + return model in ANTHROPIC_MODELS + + if provider == "gemini": + return model in GEMINI_MODELS + + if provider == "bedrock": + return model in BEDROCK_MODELS + + if provider == "azure": + # azure does not provide a list of available models + return True + + return False + + @staticmethod + def _infer_provider_from_model(model: str) -> str: + """Infer the provider from the model name. + + Args: + model: The model name without provider prefix + + Returns: + The inferred provider name, defaults to "openai" + """ + from crewai.llm.constants import ( + ANTHROPIC_MODELS, + AZURE_MODELS, + BEDROCK_MODELS, + GEMINI_MODELS, + OPENAI_MODELS, + ) + + if model in OPENAI_MODELS: + return "openai" + + if model in ANTHROPIC_MODELS: + return "anthropic" + + if model in GEMINI_MODELS: + return "gemini" + + if model in BEDROCK_MODELS: + return "bedrock" + + if model in AZURE_MODELS: + return "azure" + + return "openai" + + @staticmethod + def _get_native_provider(provider: str) -> type | None: + """Get native provider class if available. + + Args: + provider: The provider name + + Returns: + The provider class or None if not available + """ + if provider == "openai": + from crewai.llm.providers.openai.completion import OpenAICompletion + + return OpenAICompletion + + if provider == "anthropic" or provider == "claude": + from crewai.llm.providers.anthropic.completion import ( + AnthropicCompletion, + ) + + return AnthropicCompletion + + if provider == "azure" or provider == "azure_openai": + from crewai.llm.providers.azure.completion import AzureCompletion + + return AzureCompletion + + if provider == "google" or provider == "gemini": + from crewai.llm.providers.gemini.completion import GeminiCompletion + + return GeminiCompletion + + if provider == "bedrock": + from crewai.llm.providers.bedrock.completion import BedrockCompletion + + return BedrockCompletion + + return None diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/__init__.py b/lib/crewai/src/crewai/llm/providers/__init__.py similarity index 100% rename from lib/crewai/src/crewai/llms/providers/anthropic/__init__.py rename to lib/crewai/src/crewai/llm/providers/__init__.py diff --git a/lib/crewai/src/crewai/llms/providers/azure/__init__.py b/lib/crewai/src/crewai/llm/providers/anthropic/__init__.py similarity index 100% rename from lib/crewai/src/crewai/llms/providers/azure/__init__.py rename to lib/crewai/src/crewai/llm/providers/anthropic/__init__.py diff --git a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py similarity index 94% rename from lib/crewai/src/crewai/llms/providers/anthropic/completion.py rename to lib/crewai/src/crewai/llm/providers/anthropic/completion.py index ea161fc635..8ebba16737 100644 --- a/lib/crewai/src/crewai/llms/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py @@ -3,13 +3,15 @@ import json import logging import os -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast -from pydantic import 
BaseModel +from pydantic import BaseModel, ConfigDict from crewai.events.types.llm_events import LLMCallType -from crewai.llms.base_llm import BaseLLM -from crewai.llms.hooks.transport import HTTPTransport +from crewai.llm.base_llm import BaseLLM +from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO +from crewai.llm.hooks.transport import HTTPTransport +from crewai.llm.providers.utils.common import safe_tool_conversion from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, @@ -18,7 +20,7 @@ if TYPE_CHECKING: - from crewai.llms.hooks.base import BaseInterceptor + from crewai.llm.hooks.base import BaseInterceptor try: from anthropic import Anthropic @@ -38,6 +40,8 @@ class AnthropicCompletion(BaseLLM): offering native tool use, streaming support, and proper message formatting. """ + model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,)) + def __init__( self, model: str = "claude-3-5-sonnet-20241022", @@ -94,29 +98,30 @@ def __init__( self.is_claude_3 = "claude-3" in model.lower() self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use - @property - def stop(self) -> list[str]: - """Get stop sequences sent to the API.""" - return self.stop_sequences - - @stop.setter - def stop(self, value: list[str] | str | None) -> None: - """Set stop sequences. - - Synchronizes stop_sequences to ensure values set by CrewAgentExecutor - are properly sent to the Anthropic API. - - Args: - value: Stop sequences as a list, single string, or None - """ - if value is None: - self.stop_sequences = [] - elif isinstance(value, str): - self.stop_sequences = [value] - elif isinstance(value, list): - self.stop_sequences = value - else: - self.stop_sequences = [] + # + # @property + # def stop(self) -> list[str]: # type: ignore[misc] + # """Get stop sequences sent to the API.""" + # return self.stop_sequences + + # @stop.setter + # def stop(self, value: list[str] | str | None) -> None: + # """Set stop sequences. + # + # Synchronizes stop_sequences to ensure values set by CrewAgentExecutor + # are properly sent to the Anthropic API. 
+ # + # Args: + # value: Stop sequences as a list, single string, or None + # """ + # if value is None: + # self.stop_sequences = [] + # elif isinstance(value, str): + # self.stop_sequences = [value] + # elif isinstance(value, list): + # self.stop_sequences = value + # else: + # self.stop_sequences = [] def _get_client_params(self) -> dict[str, Any]: """Get client parameters.""" @@ -266,8 +271,6 @@ def _convert_tools_for_interference( continue try: - from crewai.llms.providers.utils.common import safe_tool_conversion - name, description, parameters = safe_tool_conversion(tool, "Anthropic") except (ImportError, KeyError, ValueError) as e: logging.error(f"Error converting tool to Anthropic format: {e}") @@ -636,7 +639,6 @@ def supports_stop_words(self) -> bool: def get_context_window_size(self) -> int: """Get the context window size for the model.""" - from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO # Context window sizes for Anthropic models context_windows = { diff --git a/lib/crewai/src/crewai/llms/providers/bedrock/__init__.py b/lib/crewai/src/crewai/llm/providers/azure/__init__.py similarity index 100% rename from lib/crewai/src/crewai/llms/providers/bedrock/__init__.py rename to lib/crewai/src/crewai/llm/providers/azure/__init__.py diff --git a/lib/crewai/src/crewai/llms/providers/azure/completion.py b/lib/crewai/src/crewai/llm/providers/azure/completion.py similarity index 98% rename from lib/crewai/src/crewai/llms/providers/azure/completion.py rename to lib/crewai/src/crewai/llm/providers/azure/completion.py index 17306d8a28..a389c18255 100644 --- a/lib/crewai/src/crewai/llms/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llm/providers/azure/completion.py @@ -7,6 +7,8 @@ from pydantic import BaseModel +from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES +from crewai.llm.providers.utils.common import safe_tool_conversion from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, @@ -15,7 +17,7 @@ if TYPE_CHECKING: - from crewai.llms.hooks.base import BaseInterceptor + from crewai.llm.hooks.base import BaseInterceptor from crewai.tools.base_tool import BaseTool @@ -36,7 +38,7 @@ ) from crewai.events.types.llm_events import LLMCallType - from crewai.llms.base_llm import BaseLLM + from crewai.llm.base_llm import BaseLLM except ImportError: raise ImportError( @@ -317,8 +319,6 @@ def _convert_tools_for_interference( ) -> list[dict[str, Any]]: """Convert CrewAI tool format to Azure OpenAI function calling format.""" - from crewai.llms.providers.utils.common import safe_tool_conversion - azure_tools = [] for tool in tools: @@ -554,7 +554,6 @@ def supports_stop_words(self) -> bool: def get_context_window_size(self) -> int: """Get the context window size for the model.""" - from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES min_context = 1024 max_context = 2097152 diff --git a/lib/crewai/src/crewai/llms/providers/gemini/__init__.py b/lib/crewai/src/crewai/llm/providers/bedrock/__init__.py similarity index 100% rename from lib/crewai/src/crewai/llms/providers/gemini/__init__.py rename to lib/crewai/src/crewai/llm/providers/bedrock/__init__.py diff --git a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py similarity index 96% rename from lib/crewai/src/crewai/llms/providers/bedrock/completion.py rename to 
lib/crewai/src/crewai/llm/providers/bedrock/completion.py index 20eabf7630..f67414c639 100644 --- a/lib/crewai/src/crewai/llms/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py @@ -3,13 +3,15 @@ from collections.abc import Mapping, Sequence import logging import os -from typing import TYPE_CHECKING, Any, TypedDict, cast +from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from typing_extensions import Required from crewai.events.types.llm_events import LLMCallType -from crewai.llms.base_llm import BaseLLM +from crewai.llm.base_llm import BaseLLM +from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO +from crewai.llm.providers.utils.common import safe_tool_conversion from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, @@ -30,7 +32,7 @@ ToolTypeDef, ) - from crewai.llms.hooks.base import BaseInterceptor + from crewai.llm.hooks.base import BaseInterceptor try: @@ -143,6 +145,8 @@ class BedrockCompletion(BaseLLM): - Model-specific conversation format handling (e.g., Cohere requirements) """ + model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,)) + def __init__( self, model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0", @@ -243,29 +247,29 @@ def __init__( # Handle inference profiles for newer models self.model_id = model - @property - def stop(self) -> list[str]: - """Get stop sequences sent to the API.""" - return list(self.stop_sequences) - - @stop.setter - def stop(self, value: Sequence[str] | str | None) -> None: - """Set stop sequences. - - Synchronizes stop_sequences to ensure values set by CrewAgentExecutor - are properly sent to the Bedrock API. - - Args: - value: Stop sequences as a Sequence, single string, or None - """ - if value is None: - self.stop_sequences = [] - elif isinstance(value, str): - self.stop_sequences = [value] - elif isinstance(value, Sequence): - self.stop_sequences = list(value) - else: - self.stop_sequences = [] + # @property + # def stop(self) -> list[str]: # type: ignore[misc] + # """Get stop sequences sent to the API.""" + # return list(self.stop_sequences) + + # @stop.setter + # def stop(self, value: Sequence[str] | str | None) -> None: + # """Set stop sequences. + # + # Synchronizes stop_sequences to ensure values set by CrewAgentExecutor + # are properly sent to the Bedrock API. 
+ # + # Args: + # value: Stop sequences as a Sequence, single string, or None + # """ + # if value is None: + # self.stop_sequences = [] + # elif isinstance(value, str): + # self.stop_sequences = [value] + # elif isinstance(value, Sequence): + # self.stop_sequences = list(value) + # else: + # self.stop_sequences = [] def call( self, @@ -778,7 +782,6 @@ def _format_tools_for_converse( tools: list[dict[str, Any]], ) -> list[ConverseToolTypeDef]: """Convert CrewAI tools to Converse API format following AWS specification.""" - from crewai.llms.providers.utils.common import safe_tool_conversion converse_tools: list[ConverseToolTypeDef] = [] @@ -871,7 +874,6 @@ def supports_stop_words(self) -> bool: def get_context_window_size(self) -> int: """Get the context window size for the model.""" - from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO # Context window sizes for common Bedrock models context_windows = { diff --git a/lib/crewai/src/crewai/llms/providers/openai/__init__.py b/lib/crewai/src/crewai/llm/providers/gemini/__init__.py similarity index 100% rename from lib/crewai/src/crewai/llms/providers/openai/__init__.py rename to lib/crewai/src/crewai/llm/providers/gemini/__init__.py diff --git a/lib/crewai/src/crewai/llms/providers/gemini/completion.py b/lib/crewai/src/crewai/llm/providers/gemini/completion.py similarity index 94% rename from lib/crewai/src/crewai/llms/providers/gemini/completion.py rename to lib/crewai/src/crewai/llm/providers/gemini/completion.py index 8668a8f580..263309910f 100644 --- a/lib/crewai/src/crewai/llms/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llm/providers/gemini/completion.py @@ -1,12 +1,14 @@ import logging import os -from typing import Any, cast +from typing import Any, ClassVar, cast -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from crewai.events.types.llm_events import LLMCallType -from crewai.llms.base_llm import BaseLLM -from crewai.llms.hooks.base import BaseInterceptor +from crewai.llm.base_llm import BaseLLM +from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES +from crewai.llm.hooks.base import BaseInterceptor +from crewai.llm.providers.utils.common import safe_tool_conversion from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, @@ -31,6 +33,8 @@ class GeminiCompletion(BaseLLM): offering native function calling, streaming support, and proper Gemini formatting. """ + model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,)) + def __init__( self, model: str = "gemini-2.0-flash-001", @@ -104,29 +108,29 @@ def __init__( self.is_gemini_1_5 = "gemini-1.5" in model.lower() self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2 - @property - def stop(self) -> list[str]: - """Get stop sequences sent to the API.""" - return self.stop_sequences - - @stop.setter - def stop(self, value: list[str] | str | None) -> None: - """Set stop sequences. - - Synchronizes stop_sequences to ensure values set by CrewAgentExecutor - are properly sent to the Gemini API. 
- - Args: - value: Stop sequences as a list, single string, or None - """ - if value is None: - self.stop_sequences = [] - elif isinstance(value, str): - self.stop_sequences = [value] - elif isinstance(value, list): - self.stop_sequences = value - else: - self.stop_sequences = [] + # @property + # def stop(self) -> list[str]: # type: ignore[misc] + # """Get stop sequences sent to the API.""" + # return self.stop_sequences + + # @stop.setter + # def stop(self, value: list[str] | str | None) -> None: + # """Set stop sequences. + # + # Synchronizes stop_sequences to ensure values set by CrewAgentExecutor + # are properly sent to the Gemini API. + # + # Args: + # value: Stop sequences as a list, single string, or None + # """ + # if value is None: + # self.stop_sequences = [] + # elif isinstance(value, str): + # self.stop_sequences = [value] + # elif isinstance(value, list): + # self.stop_sequences = value + # else: + # self.stop_sequences = [] def _initialize_client(self, use_vertexai: bool = False) -> genai.Client: # type: ignore[no-any-unimported] """Initialize the Google Gen AI client with proper parameter handling. @@ -335,8 +339,6 @@ def _convert_tools_for_interference( # type: ignore[no-any-unimported] """Convert CrewAI tool format to Gemini function declaration format.""" gemini_tools = [] - from crewai.llms.providers.utils.common import safe_tool_conversion - for tool in tools: name, description, parameters = safe_tool_conversion(tool, "Gemini") @@ -547,7 +549,6 @@ def supports_stop_words(self) -> bool: def get_context_window_size(self) -> int: """Get the context window size for the model.""" - from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES min_context = 1024 max_context = 2097152 diff --git a/lib/crewai/src/crewai/llms/providers/utils/__init__.py b/lib/crewai/src/crewai/llm/providers/openai/__init__.py similarity index 100% rename from lib/crewai/src/crewai/llms/providers/utils/__init__.py rename to lib/crewai/src/crewai/llm/providers/openai/__init__.py diff --git a/lib/crewai/src/crewai/llms/providers/openai/completion.py b/lib/crewai/src/crewai/llm/providers/openai/completion.py similarity index 87% rename from lib/crewai/src/crewai/llms/providers/openai/completion.py rename to lib/crewai/src/crewai/llm/providers/openai/completion.py index fdf7b03c79..f9f65c8b13 100644 --- a/lib/crewai/src/crewai/llms/providers/openai/completion.py +++ b/lib/crewai/src/crewai/llm/providers/openai/completion.py @@ -11,11 +11,13 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import ChoiceDelta -from pydantic import BaseModel +from pydantic import BaseModel, Field from crewai.events.types.llm_events import LLMCallType -from crewai.llms.base_llm import BaseLLM -from crewai.llms.hooks.transport import HTTPTransport +from crewai.llm.base_llm import BaseLLM +from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES +from crewai.llm.hooks.transport import HTTPTransport +from crewai.llm.providers.utils.common import safe_tool_conversion from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededError, @@ -25,7 +27,6 @@ if TYPE_CHECKING: from crewai.agent.core import Agent - from crewai.llms.hooks.base import BaseInterceptor from crewai.task import Task from crewai.tools.base_tool import BaseTool @@ -37,61 +38,61 @@ class 
OpenAICompletion(BaseLLM): offering native structured outputs, function calling, and streaming support. """ - def __init__( - self, - model: str = "gpt-4o", - api_key: str | None = None, - base_url: str | None = None, - organization: str | None = None, - project: str | None = None, - timeout: float | None = None, - max_retries: int = 2, - default_headers: dict[str, str] | None = None, - default_query: dict[str, Any] | None = None, - client_params: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - frequency_penalty: float | None = None, - presence_penalty: float | None = None, - max_tokens: int | None = None, - max_completion_tokens: int | None = None, - seed: int | None = None, - stream: bool = False, - response_format: dict[str, Any] | type[BaseModel] | None = None, - logprobs: bool | None = None, - top_logprobs: int | None = None, - reasoning_effort: str | None = None, - provider: str | None = None, - interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None, - **kwargs: Any, - ) -> None: - """Initialize OpenAI chat completion client.""" - - if provider is None: - provider = kwargs.pop("provider", "openai") - - self.interceptor = interceptor - # Client configuration attributes - self.organization = organization - self.project = project - self.max_retries = max_retries - self.default_headers = default_headers - self.default_query = default_query - self.client_params = client_params - self.timeout = timeout - self.base_url = base_url - self.api_base = kwargs.pop("api_base", None) - - super().__init__( - model=model, - temperature=temperature, - api_key=api_key or os.getenv("OPENAI_API_KEY"), - base_url=base_url, - timeout=timeout, - provider=provider, - **kwargs, - ) + # Client configuration fields + organization: str | None = Field(None, description="OpenAI organization ID") + project: str | None = Field(None, description="OpenAI project ID") + max_retries: int = Field(2, description="Maximum number of retries") + default_headers: dict[str, str] | None = Field( + None, description="Default headers for requests" + ) + default_query: dict[str, Any] | None = Field( + None, description="Default query parameters" + ) + client_params: dict[str, Any] | None = Field( + None, description="Additional client parameters" + ) + timeout: float | None = Field(None, description="Request timeout") + api_base: str | None = Field(None, description="API base URL (deprecated)") + + # Completion parameters + top_p: float | None = Field(None, description="Top-p sampling parameter") + frequency_penalty: float | None = Field(None, description="Frequency penalty") + presence_penalty: float | None = Field(None, description="Presence penalty") + max_tokens: int | None = Field(None, description="Maximum tokens") + max_completion_tokens: int | None = Field( + None, description="Maximum completion tokens" + ) + seed: int | None = Field(None, description="Random seed") + stream: bool = Field(False, description="Enable streaming") + response_format: dict[str, Any] | type[BaseModel] | None = Field( + None, description="Response format" + ) + logprobs: bool | None = Field(None, description="Return log probabilities") + top_logprobs: int | None = Field( + None, description="Number of top log probabilities" + ) + reasoning_effort: str | None = Field(None, description="Reasoning effort level") + + # Internal state + client: OpenAI = Field( + default_factory=OpenAI, exclude=True, description="OpenAI client instance" + ) + is_o1_model: bool = Field(False, 
description="Whether this is an O1 model") + is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model") + + def model_post_init(self, __context: Any) -> None: + """Initialize OpenAI client after model initialization. + + Args: + __context: Pydantic context + """ + super().model_post_init(__context) + + # Set API key from environment if not provided + if self.api_key is None: + self.api_key = os.getenv("OPENAI_API_KEY") + # Initialize client client_config = self._get_client_params() if self.interceptor: transport = HTTPTransport(interceptor=self.interceptor) @@ -100,20 +101,9 @@ def __init__( self.client = OpenAI(**client_config) - # Completion parameters - self.top_p = top_p - self.frequency_penalty = frequency_penalty - self.presence_penalty = presence_penalty - self.max_tokens = max_tokens - self.max_completion_tokens = max_completion_tokens - self.seed = seed - self.stream = stream - self.response_format = response_format - self.logprobs = logprobs - self.top_logprobs = top_logprobs - self.reasoning_effort = reasoning_effort - self.is_o1_model = "o1" in model.lower() - self.is_gpt4_model = "gpt-4" in model.lower() + # Set model flags + self.is_o1_model = "o1" in self.model.lower() + self.is_gpt4_model = "gpt-4" in self.model.lower() def _get_client_params(self) -> dict[str, Any]: """Get OpenAI client parameters.""" @@ -268,7 +258,6 @@ def _convert_tools_for_interference( self, tools: list[dict[str, BaseTool]] ) -> list[dict[str, Any]]: """Convert CrewAI tool format to OpenAI function calling format.""" - from crewai.llms.providers.utils.common import safe_tool_conversion openai_tools = [] @@ -560,7 +549,6 @@ def supports_stop_words(self) -> bool: def get_context_window_size(self) -> int: """Get the context window size for the model.""" - from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES min_context = 1024 max_context = 2097152 diff --git a/lib/crewai/src/crewai/llm/providers/utils/__init__.py b/lib/crewai/src/crewai/llm/providers/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lib/crewai/src/crewai/llms/providers/utils/common.py b/lib/crewai/src/crewai/llm/providers/utils/common.py similarity index 100% rename from lib/crewai/src/crewai/llms/providers/utils/common.py rename to lib/crewai/src/crewai/llm/providers/utils/common.py diff --git a/lib/crewai/src/crewai/llms/third_party/__init__.py b/lib/crewai/src/crewai/llms/third_party/__init__.py deleted file mode 100644 index 947a62fa46..0000000000 --- a/lib/crewai/src/crewai/llms/third_party/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Third-party LLM implementations for crewAI.""" diff --git a/lib/crewai/src/crewai/tasks/llm_guardrail.py b/lib/crewai/src/crewai/tasks/llm_guardrail.py index 803b2d7492..b540feebf8 100644 --- a/lib/crewai/src/crewai/tasks/llm_guardrail.py +++ b/lib/crewai/src/crewai/tasks/llm_guardrail.py @@ -4,7 +4,7 @@ from crewai.agent import Agent from crewai.lite_agent_output import LiteAgentOutput -from crewai.llms.base_llm import BaseLLM +from crewai.llm.base_llm import BaseLLM from crewai.tasks.task_output import TaskOutput diff --git a/lib/crewai/src/crewai/utilities/agent_utils.py b/lib/crewai/src/crewai/utilities/agent_utils.py index a6403c3157..6fc8bd863f 100644 --- a/lib/crewai/src/crewai/utilities/agent_utils.py +++ b/lib/crewai/src/crewai/utilities/agent_utils.py @@ -16,7 +16,7 @@ parse, ) from crewai.cli.config import Settings -from crewai.llms.base_llm import BaseLLM +from crewai.llm.base_llm import BaseLLM from crewai.tools 
import BaseTool as CrewAITool from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool diff --git a/lib/crewai/src/crewai/utilities/converter.py b/lib/crewai/src/crewai/utilities/converter.py index 0a42a467e4..ce827ce824 100644 --- a/lib/crewai/src/crewai/utilities/converter.py +++ b/lib/crewai/src/crewai/utilities/converter.py @@ -19,7 +19,7 @@ from crewai.agent import Agent from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.llm import LLM - from crewai.llms.base_llm import BaseLLM + from crewai.llm.base_llm import BaseLLM _JSON_PATTERN: Final[re.Pattern[str]] = re.compile(r"({.*})", re.DOTALL) _I18N = get_i18n() diff --git a/lib/crewai/src/crewai/utilities/evaluators/crew_evaluator_handler.py b/lib/crewai/src/crewai/utilities/evaluators/crew_evaluator_handler.py index 9c9cac0c69..5420cb4d4a 100644 --- a/lib/crewai/src/crewai/utilities/evaluators/crew_evaluator_handler.py +++ b/lib/crewai/src/crewai/utilities/evaluators/crew_evaluator_handler.py @@ -11,7 +11,7 @@ from crewai.agent import Agent from crewai.events.event_bus import crewai_event_bus from crewai.events.types.crew_events import CrewTestResultEvent -from crewai.llms.base_llm import BaseLLM +from crewai.llm.base_llm import BaseLLM from crewai.task import Task from crewai.tasks.task_output import TaskOutput diff --git a/lib/crewai/src/crewai/utilities/internal_instructor.py b/lib/crewai/src/crewai/utilities/internal_instructor.py index 06a95d2344..72d0bf75b0 100644 --- a/lib/crewai/src/crewai/utilities/internal_instructor.py +++ b/lib/crewai/src/crewai/utilities/internal_instructor.py @@ -10,7 +10,7 @@ if TYPE_CHECKING: from crewai.agent import Agent from crewai.llm import LLM - from crewai.llms.base_llm import BaseLLM + from crewai.llm.base_llm import BaseLLM from crewai.utilities.types import LLMMessage diff --git a/lib/crewai/src/crewai/utilities/llm_utils.py b/lib/crewai/src/crewai/utilities/llm_utils.py index 129f064d56..506425534e 100644 --- a/lib/crewai/src/crewai/utilities/llm_utils.py +++ b/lib/crewai/src/crewai/utilities/llm_utils.py @@ -4,7 +4,7 @@ from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS from crewai.llm import LLM -from crewai.llms.base_llm import BaseLLM +from crewai.llm.base_llm import BaseLLM logger = logging.getLogger(__name__) diff --git a/lib/crewai/src/crewai/utilities/planning_handler.py b/lib/crewai/src/crewai/utilities/planning_handler.py index c76153020d..338bf11ee9 100644 --- a/lib/crewai/src/crewai/utilities/planning_handler.py +++ b/lib/crewai/src/crewai/utilities/planning_handler.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, Field from crewai.agent import Agent -from crewai.llms.base_llm import BaseLLM +from crewai.llm.base_llm import BaseLLM from crewai.task import Task diff --git a/lib/crewai/src/crewai/utilities/tool_utils.py b/lib/crewai/src/crewai/utilities/tool_utils.py index eb433c02c9..186db1adf9 100644 --- a/lib/crewai/src/crewai/utilities/tool_utils.py +++ b/lib/crewai/src/crewai/utilities/tool_utils.py @@ -15,7 +15,7 @@ from crewai.agent import Agent from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.llm import LLM - from crewai.llms.base_llm import BaseLLM + from crewai.llm.base_llm import BaseLLM from crewai.task import Task diff --git a/lib/crewai/tests/agents/test_agent.py b/lib/crewai/tests/agents/test_agent.py index 4fd1f3b5b3..94e716b97c 100644 --- a/lib/crewai/tests/agents/test_agent.py +++ b/lib/crewai/tests/agents/test_agent.py @@ -14,7 +14,7 @@ from 
crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource from crewai.llm import LLM -from crewai.llms.base_llm import BaseLLM +from crewai.llm.base_llm import BaseLLM from crewai.process import Process from crewai.tools.tool_calling import InstructorToolCalling from crewai.tools.tool_usage import ToolUsage diff --git a/lib/crewai/tests/agents/test_lite_agent.py b/lib/crewai/tests/agents/test_lite_agent.py index 0c6b00c23e..1215c7804d 100644 --- a/lib/crewai/tests/agents/test_lite_agent.py +++ b/lib/crewai/tests/agents/test_lite_agent.py @@ -9,7 +9,7 @@ from crewai.events.types.tool_usage_events import ToolUsageStartedEvent from crewai.lite_agent import LiteAgent from crewai.lite_agent_output import LiteAgentOutput -from crewai.llms.base_llm import BaseLLM +from crewai.llm.base_llm import BaseLLM from pydantic import BaseModel, Field import pytest diff --git a/lib/crewai/tests/cassettes/test_agent_function_calling_llm.yaml b/lib/crewai/tests/cassettes/test_agent_function_calling_llm.yaml index 0136b60c6d..eeb43b9c2e 100644 --- a/lib/crewai/tests/cassettes/test_agent_function_calling_llm.yaml +++ b/lib/crewai/tests/cassettes/test_agent_function_calling_llm.yaml @@ -590,7 +590,7 @@ interactions: " at 0x107389260>", "result_as_answer": "False", "max_usage_count": "None", "current_usage_count": "0"}], "max_iter": 2, "agent_executor": "", - "llm": "", "crew": {"parent_flow": null, "name": "crew", "cache": true, "tasks": ["{''used_tools'': 0, ''tools_errors'': 0, ''delegations'': 0, ''i18n'': {''prompt_file'': None}, ''name'': None, ''prompt_context'': '''', ''description'': @@ -605,7 +605,7 @@ interactions: ''description_updated'': False, ''cache_function'': at 0x107389260>, ''result_as_answer'': False, ''max_usage_count'': None, ''current_usage_count'': 0}], ''max_iter'': 2, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=991ac83f-9a29-411f-b0a0-0a335c7a2d0e, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''description_updated'': False, ''cache_function'': at 0x107389260>, ''result_as_answer'': False, ''max_usage_count'': None, ''current_usage_count'': 0}], ''max_iter'': 2, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=991ac83f-9a29-411f-b0a0-0a335c7a2d0e, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': ", "system_template": null, "prompt_template": null, "response_template": null, "allow_code_execution": false, "respect_context_window": true, "max_retry_limit": 2, "multimodal": false, "inject_date": false, "date_format": "%Y-%m-%d", "code_execution_mode": @@ -1068,7 +1068,7 @@ interactions: " at 0x107e394e0>", "result_as_answer": "False", "max_usage_count": "None", "current_usage_count": "0"}], "max_iter": 2, "agent_executor": "", - "llm": "", "crew": {"parent_flow": null, "name": "crew", "cache": true, "tasks": ["{''used_tools'': 0, ''tools_errors'': 0, ''delegations'': 0, ''i18n'': {''prompt_file'': None}, ''name'': None, ''prompt_context'': '''', ''description'': @@ -1083,7 +1083,7 @@ interactions: ''description_updated'': False, ''cache_function'': at 0x107e394e0>, ''result_as_answer'': False, ''max_usage_count'': None, ''current_usage_count'': 0}], ''max_iter'': 2, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=f38365e9-3206-45b6-8754-950cb03fe57e, 
process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''description_updated'': False, ''cache_function'': at 0x107e394e0>, ''result_as_answer'': False, ''max_usage_count'': None, ''current_usage_count'': 0}], ''max_iter'': 2, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=f38365e9-3206-45b6-8754-950cb03fe57e, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': ", "system_template": null, "prompt_template": null, "response_template": null, "allow_code_execution": false, "respect_context_window": true, "max_retry_limit": 2, "multimodal": false, "inject_date": false, "date_format": "%Y-%m-%d", "code_execution_mode": diff --git a/lib/crewai/tests/cassettes/test_agent_remembers_output_format_after_using_tools_too_many_times.yaml b/lib/crewai/tests/cassettes/test_agent_remembers_output_format_after_using_tools_too_many_times.yaml index a0c8a3e402..e82b2e5d45 100644 --- a/lib/crewai/tests/cassettes/test_agent_remembers_output_format_after_using_tools_too_many_times.yaml +++ b/lib/crewai/tests/cassettes/test_agent_remembers_output_format_after_using_tools_too_many_times.yaml @@ -1274,7 +1274,7 @@ interactions: "b6cf723e-04c8-40c5-a927-e2078cfbae59", "role": "test role", "goal": "test goal", "backstory": "test backstory", "cache": true, "verbose": true, "max_rpm": null, "allow_delegation": false, "tools": [], "max_iter": 6, "agent_executor": "", "llm": "", "llm": "", "crew": {"parent_flow": null, "name": "crew", "cache": true, "tasks": ["{''used_tools'': 0, ''tools_errors'': 0, ''delegations'': 0, ''i18n'': {''prompt_file'': None}, ''name'': None, ''prompt_context'': '''', @@ -1285,7 +1285,7 @@ interactions: ''test role'', ''goal'': ''test goal'', ''backstory'': ''test backstory'', ''cache'': True, ''verbose'': True, ''max_rpm'': None, ''allow_delegation'': False, ''tools'': [], ''max_iter'': 6, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=004dd8a0-dd87-43fa-bdc8-07f449808028, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''llm'': , ''llm'': , ''crew'': Crew(id=004dd8a0-dd87-43fa-bdc8-07f449808028, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': ", "llm": "", "llm": "", "crew": {"parent_flow": null, "name": "crew", "cache": true, "tasks": ["{''used_tools'': 2, ''tools_errors'': 0, ''delegations'': 0, ''i18n'': {''prompt_file'': None}, ''name'': None, ''prompt_context'': '''', @@ -1502,7 +1502,7 @@ interactions: ''test role'', ''goal'': ''test goal'', ''backstory'': ''test backstory'', ''cache'': True, ''verbose'': True, ''max_rpm'': None, ''allow_delegation'': False, ''tools'': [], ''max_iter'': 6, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=004dd8a0-dd87-43fa-bdc8-07f449808028, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''llm'': , ''llm'': , ''crew'': Crew(id=004dd8a0-dd87-43fa-bdc8-07f449808028, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': ", "llm": "", "llm": "", "crew": {"parent_flow": null, "name": "crew", "cache": true, "tasks": ["{''used_tools'': 3, 
''tools_errors'': 0, ''delegations'': 0, ''i18n'': {''prompt_file'': None}, ''name'': None, ''prompt_context'': '''', @@ -1671,7 +1671,7 @@ interactions: ''test role'', ''goal'': ''test goal'', ''backstory'': ''test backstory'', ''cache'': True, ''verbose'': True, ''max_rpm'': None, ''allow_delegation'': False, ''tools'': [], ''max_iter'': 6, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=004dd8a0-dd87-43fa-bdc8-07f449808028, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''llm'': , ''llm'': , ''crew'': Crew(id=004dd8a0-dd87-43fa-bdc8-07f449808028, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': ", "llm": "", "llm": "", "crew": {"parent_flow": null, "name": "crew", "cache": true, "tasks": ["{''used_tools'': 4, ''tools_errors'': 0, ''delegations'': 0, ''i18n'': {''prompt_file'': None}, ''name'': None, ''prompt_context'': '''', @@ -1850,7 +1850,7 @@ interactions: ''test role'', ''goal'': ''test goal'', ''backstory'': ''test backstory'', ''cache'': True, ''verbose'': True, ''max_rpm'': None, ''allow_delegation'': False, ''tools'': [], ''max_iter'': 6, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=004dd8a0-dd87-43fa-bdc8-07f449808028, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''llm'': , ''llm'': , ''crew'': Crew(id=004dd8a0-dd87-43fa-bdc8-07f449808028, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': ", "llm": "", "llm": "", "crew": {"parent_flow": null, "name": "crew", "cache": true, "tasks": ["{''used_tools'': 5, ''tools_errors'': 0, ''delegations'': 0, ''i18n'': {''prompt_file'': None}, ''name'': None, ''prompt_context'': '''', @@ -2040,7 +2040,7 @@ interactions: ''test role'', ''goal'': ''test goal'', ''backstory'': ''test backstory'', ''cache'': True, ''verbose'': True, ''max_rpm'': None, ''allow_delegation'': False, ''tools'': [], ''max_iter'': 6, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=004dd8a0-dd87-43fa-bdc8-07f449808028, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''llm'': , ''llm'': , ''crew'': Crew(id=004dd8a0-dd87-43fa-bdc8-07f449808028, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': ", "llm": "", "llm": "", "crew": {"parent_flow": null, "name": "crew", "cache": true, "tasks": ["{''used_tools'': 0, ''tools_errors'': 0, ''delegations'': 0, ''i18n'': {''prompt_file'': None}, ''name'': None, ''prompt_context'': '''', @@ -1093,7 +1093,7 @@ interactions: ''test role'', ''goal'': ''test goal'', ''backstory'': ''test backstory'', ''cache'': True, ''verbose'': True, ''max_rpm'': 10, ''allow_delegation'': False, ''tools'': [], ''max_iter'': 4, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=4c6d502e-f6ec-446a-8f76-644563c4aa94, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''llm'': , ''llm'': , ''crew'': Crew(id=4c6d502e-f6ec-446a-8f76-644563c4aa94, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': 
{''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': ", "llm": "", "llm": "", "crew": {"parent_flow": null, "name": "crew", "cache": true, "tasks": ["{''used_tools'': 0, ''tools_errors'': 0, ''delegations'': 0, ''i18n'': {''prompt_file'': None}, ''name'': None, ''prompt_context'': '''', @@ -1921,7 +1921,7 @@ interactions: ''test role'', ''goal'': ''test goal'', ''backstory'': ''test backstory'', ''cache'': True, ''verbose'': True, ''max_rpm'': 10, ''allow_delegation'': False, ''tools'': [], ''max_iter'': 4, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=1a07d718-fed5-49fa-bee2-de2db91c9f33, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''llm'': , ''llm'': , ''crew'': Crew(id=1a07d718-fed5-49fa-bee2-de2db91c9f33, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': at 0x10614d3a0>", "result_as_answer": "False", "max_usage_count": "None", "current_usage_count": "0"}], "max_iter": 25, "agent_executor": "", "llm": "", "llm": "", "crew": {"parent_flow": null, "name": "crew", "cache": true, "tasks": ["{''used_tools'': 0, ''tools_errors'': 0, ''delegations'': 0, ''i18n'': {''prompt_file'': None}, ''name'': None, ''prompt_context'': '''', @@ -977,7 +977,7 @@ interactions: ''description_updated'': False, ''cache_function'': at 0x10614d3a0>, ''result_as_answer'': False, ''max_usage_count'': None, ''current_usage_count'': 0}], ''max_iter'': 25, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=49cbb747-f055-4636-bbca-9e8a450c05f6, process=Process.hierarchical, number_of_agents=2, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''llm'': , ''llm'': , ''crew'': Crew(id=49cbb747-f055-4636-bbca-9e8a450c05f6, process=Process.hierarchical, number_of_agents=2, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''llm'': , ''llm'': , ''crew'': Crew(id=49cbb747-f055-4636-bbca-9e8a450c05f6, process=Process.hierarchical, number_of_agents=2, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': ", "manager_agent": {"id": "UUID(''b0898472-5e3b-45bb-bd90-05bad0b5a8ce'')", "role": "''Crew Manager''", "goal": "''Manage the team to complete the task in the best way possible.''", "backstory": "\"You are a seasoned manager with @@ -1053,7 +1053,7 @@ interactions: ''description_updated'': False, ''cache_function'': at 0x10614d3a0>, ''result_as_answer'': False, ''max_usage_count'': None, ''current_usage_count'': 0}]", "max_iter": "25", "agent_executor": "", "llm": "", "llm": "", "crew": "Crew(id=49cbb747-f055-4636-bbca-9e8a450c05f6, process=Process.hierarchical, number_of_agents=2, number_of_tasks=1)", "i18n": "{''prompt_file'': None}", "cache_handler": "{}", "tools_handler": " at 0x107e394e0>", "result_as_answer": "False", "max_usage_count": "None", "current_usage_count": "0"}], "max_iter": 25, "agent_executor": "", "llm": "", "llm": "", "crew": {"parent_flow": null, "name": "crew", "cache": true, "tasks": ["{''used_tools'': 0, ''tools_errors'': 0, ''delegations'': 0, ''i18n'': {''prompt_file'': None}, ''name'': None, ''prompt_context'': '''', @@ -1845,7 +1845,7 @@ interactions: ''description_updated'': False, ''cache_function'': at 0x107e394e0>, ''result_as_answer'': False, ''max_usage_count'': None, ''current_usage_count'': 0}], ''max_iter'': 25, 
''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=4d744f3e-0589-4d1d-b1c1-6aa8b52478ac, process=Process.hierarchical, number_of_agents=2, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''llm'': , ''llm'': , ''crew'': Crew(id=4d744f3e-0589-4d1d-b1c1-6aa8b52478ac, process=Process.hierarchical, number_of_agents=2, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''llm'': , ''llm'': , ''crew'': Crew(id=4d744f3e-0589-4d1d-b1c1-6aa8b52478ac, process=Process.hierarchical, number_of_agents=2, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': ", "manager_agent": {"id": "UUID(''09794b42-447f-4b7a-b634-3a861f457357'')", "role": "''Crew Manager''", "goal": "''Manage the team to complete the task in the best way possible.''", "backstory": "\"You are a seasoned manager with @@ -1921,7 +1921,7 @@ interactions: ''description_updated'': False, ''cache_function'': at 0x107e394e0>, ''result_as_answer'': False, ''max_usage_count'': None, ''current_usage_count'': 0}]", "max_iter": "25", "agent_executor": "", "llm": "", "llm": "", "crew": "Crew(id=4d744f3e-0589-4d1d-b1c1-6aa8b52478ac, process=Process.hierarchical, number_of_agents=2, number_of_tasks=1)", "i18n": "{''prompt_file'': None}", "cache_handler": "{}", "tools_handler": ", ''llm'': , ''llm'': , ''crew'': None, ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''tools_results'': [], ''max_tokens'': None, ''knowledge'': None, ''knowledge_sources'': @@ -149,7 +149,7 @@ interactions: writing content for a new customer.\", ''cache'': True, ''verbose'': False, ''max_rpm'': None, ''allow_delegation'': False, ''tools'': [], ''max_iter'': 25, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': None, ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''tools_results'': [], ''max_tokens'': None, ''knowledge'': None, ''knowledge_sources'': @@ -169,7 +169,7 @@ interactions: a freelancer and is now working on doing research and analysis for a new customer.\"", "cache": "True", "verbose": "False", "max_rpm": "None", "allow_delegation": "False", "tools": "[]", "max_iter": "25", "agent_executor": "", "llm": "", "llm": "", "crew": "None", "i18n": "{''prompt_file'': None}", "cache_handler": "{}", "tools_handler": "", "tools_results": "[]", "max_tokens": "None", "knowledge": @@ -182,7 +182,7 @@ interactions: You work as a freelancer and are now working on writing content for a new customer.\"", "cache": "True", "verbose": "False", "max_rpm": "None", "allow_delegation": "False", "tools": "[]", "max_iter": "25", "agent_executor": "", "llm": "", "llm": "", "crew": "None", "i18n": "{''prompt_file'': None}", "cache_handler": "{}", "tools_handler": "", "tools_results": "[]", "max_tokens": "None", "knowledge": @@ -214,7 +214,7 @@ interactions: a freelancer and is now working on doing research and analysis for a new customer.\", ''cache'': True, ''verbose'': False, ''max_rpm'': None, ''allow_delegation'': False, ''tools'': [], ''max_iter'': 25, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': None, ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''tools_results'': [], ''max_tokens'': None, ''knowledge'': None, ''knowledge_sources'': @@ -237,7 +237,7 @@ interactions: writing content for a new customer.\", ''cache'': True, ''verbose'': False, ''max_rpm'': None, ''allow_delegation'': False, ''tools'': [], ''max_iter'': 25, 
''agent_executor'': , ''llm'': , ''llm'': , ''crew'': None, ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': , ''tools_results'': [], ''max_tokens'': None, ''knowledge'': None, ''knowledge_sources'': @@ -257,7 +257,7 @@ interactions: a freelancer and is now working on doing research and analysis for a new customer.\"", "cache": "True", "verbose": "False", "max_rpm": "None", "allow_delegation": "False", "tools": "[]", "max_iter": "25", "agent_executor": "", "llm": "", "llm": "", "crew": "None", "i18n": "{''prompt_file'': None}", "cache_handler": "{}", "tools_handler": "", "tools_results": "[]", "max_tokens": "None", "knowledge": @@ -270,7 +270,7 @@ interactions: You work as a freelancer and are now working on writing content for a new customer.\"", "cache": "True", "verbose": "False", "max_rpm": "None", "allow_delegation": "False", "tools": "[]", "max_iter": "25", "agent_executor": "", "llm": "", "llm": "", "crew": "None", "i18n": "{''prompt_file'': None}", "cache_handler": "{}", "tools_handler": "", "tools_results": "[]", "max_tokens": "None", "knowledge": diff --git a/lib/crewai/tests/cassettes/test_tool_result_as_answer_is_the_final_answer_for_the_agent.yaml b/lib/crewai/tests/cassettes/test_tool_result_as_answer_is_the_final_answer_for_the_agent.yaml index d3784b9e75..a0941cf37b 100644 --- a/lib/crewai/tests/cassettes/test_tool_result_as_answer_is_the_final_answer_for_the_agent.yaml +++ b/lib/crewai/tests/cassettes/test_tool_result_as_answer_is_the_final_answer_for_the_agent.yaml @@ -468,7 +468,7 @@ interactions: "description_updated": "False", "cache_function": " at 0x107ff9440>", "result_as_answer": "True", "max_usage_count": "None", "current_usage_count": "0"}], "max_iter": 25, "agent_executor": "", "llm": "", "llm": "", "crew": {"parent_flow": null, "name": "crew", "cache": true, "tasks": ["{''used_tools'': 0, ''tools_errors'': 0, ''delegations'': 0, ''i18n'': {''prompt_file'': None}, ''name'': None, ''prompt_context'': '''', @@ -484,7 +484,7 @@ interactions: , ''description_updated'': False, ''cache_function'': at 0x107ff9440>, ''result_as_answer'': True, ''max_usage_count'': None, ''current_usage_count'': 0}], ''max_iter'': 25, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=f74956dd-60d0-402a-a703-2cc3d767397f, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': at 0x107ff9440>, ''result_as_answer'': True, ''max_usage_count'': None, ''current_usage_count'': 0}], ''max_iter'': 25, ''agent_executor'': , ''llm'': , ''llm'': , ''crew'': Crew(id=f74956dd-60d0-402a-a703-2cc3d767397f, process=Process.sequential, number_of_agents=1, number_of_tasks=1), ''i18n'': {''prompt_file'': None}, ''cache_handler'': {}, ''tools_handler'': str: def test_azure_raises_error_when_endpoint_missing(): """Test that AzureCompletion raises ValueError when endpoint is missing""" - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion # Clear environment variables with patch.dict(os.environ, {}, clear=True): @@ -383,7 +383,7 @@ def test_azure_raises_error_when_endpoint_missing(): def test_azure_raises_error_when_api_key_missing(): """Test that AzureCompletion raises ValueError when API key is missing""" - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion # Clear environment variables with patch.dict(os.environ, 
{}, clear=True): @@ -400,7 +400,7 @@ def test_azure_endpoint_configuration(): }): llm = LLM(model="azure/gpt-4") - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion assert isinstance(llm, AzureCompletion) assert llm.endpoint == "https://test1.openai.azure.com/openai/deployments/gpt-4" @@ -426,7 +426,7 @@ def test_azure_api_key_configuration(): }): llm = LLM(model="azure/gpt-4") - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion assert isinstance(llm, AzureCompletion) assert llm.api_key == "test-azure-key" @@ -437,7 +437,7 @@ def test_azure_model_capabilities(): """ # Test GPT-4 model (supports function calling) llm_gpt4 = LLM(model="azure/gpt-4") - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion assert isinstance(llm_gpt4, AzureCompletion) assert llm_gpt4.is_openai_model == True assert llm_gpt4.supports_function_calling() == True @@ -466,7 +466,7 @@ def test_azure_completion_params_preparation(): max_tokens=1000 ) - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion assert isinstance(llm, AzureCompletion) messages = [{"role": "user", "content": "Hello"}] @@ -494,7 +494,7 @@ def test_azure_model_detection(): for model_name in azure_test_cases: llm = LLM(model=model_name) - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion assert isinstance(llm, AzureCompletion), f"Failed for model: {model_name}" @@ -662,7 +662,7 @@ def test_azure_streaming_completion(): """ Test that streaming completions work properly """ - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion from azure.ai.inference.models import StreamingChatCompletionsUpdate llm = LLM(model="azure/gpt-4", stream=True) @@ -698,7 +698,7 @@ def test_azure_api_version_default(): """ llm = LLM(model="azure/gpt-4") - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion assert isinstance(llm, AzureCompletion) # Should use default or environment variable assert llm.api_version is not None @@ -721,7 +721,7 @@ def test_azure_openai_endpoint_url_construction(): """ Test that Azure OpenAI endpoint URLs are automatically constructed correctly """ - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion with patch.dict(os.environ, { "AZURE_API_KEY": "test-key", @@ -738,7 +738,7 @@ def test_azure_openai_endpoint_url_with_trailing_slash(): """ Test that trailing slashes are handled correctly in endpoint URLs """ - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion with patch.dict(os.environ, { "AZURE_API_KEY": "test-key", @@ -804,7 +804,7 @@ def test_non_azure_openai_model_parameter_included(): """ Test that model parameter IS included for non-Azure OpenAI endpoints """ - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion with patch.dict(os.environ, { "AZURE_API_KEY": "test-key", @@ -824,7 +824,7 @@ def 
test_azure_message_formatting_with_role(): """ Test that messages are formatted with both 'role' and 'content' fields """ - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion llm = LLM(model="azure/gpt-4") @@ -886,7 +886,7 @@ def test_azure_improved_error_messages(): """ Test that improved error messages are provided for common HTTP errors """ - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion from azure.core.exceptions import HttpResponseError llm = LLM(model="azure/gpt-4") @@ -918,7 +918,7 @@ def test_azure_api_version_properly_passed(): """ Test that api_version is properly passed to the client """ - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion with patch.dict(os.environ, { "AZURE_API_KEY": "test-key", @@ -940,7 +940,7 @@ def test_azure_timeout_and_max_retries_stored(): """ Test that timeout and max_retries parameters are stored """ - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion with patch.dict(os.environ, { "AZURE_API_KEY": "test-key", @@ -960,7 +960,7 @@ def test_azure_complete_params_include_optional_params(): """ Test that optional parameters are included in completion params when set """ - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion with patch.dict(os.environ, { "AZURE_API_KEY": "test-key", @@ -992,7 +992,7 @@ def test_azure_endpoint_validation_with_azure_prefix(): """ Test that 'azure/' prefix is properly stripped when constructing endpoint """ - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion with patch.dict(os.environ, { "AZURE_API_KEY": "test-key", @@ -1009,7 +1009,7 @@ def test_azure_message_formatting_preserves_all_roles(): """ Test that all message roles (system, user, assistant) are preserved correctly """ - from crewai.llms.providers.azure.completion import AzureCompletion + from crewai.llm.providers.azure.completion import AzureCompletion llm = LLM(model="azure/gpt-4") diff --git a/lib/crewai/tests/llms/bedrock/test_bedrock.py b/lib/crewai/tests/llms/bedrock/test_bedrock.py index aecbdde0ee..7ad7c20800 100644 --- a/lib/crewai/tests/llms/bedrock/test_bedrock.py +++ b/lib/crewai/tests/llms/bedrock/test_bedrock.py @@ -19,7 +19,7 @@ def mock_aws_credentials(): "AWS_DEFAULT_REGION": "us-east-1" }): # Mock boto3 Session to prevent actual AWS connections - with patch('crewai.llms.providers.bedrock.completion.Session') as mock_session_class: + with patch('crewai.llm.providers.bedrock.completion.Session') as mock_session_class: # Create mock session instance mock_session_instance = MagicMock() mock_client = MagicMock() @@ -67,7 +67,7 @@ def test_bedrock_completion_module_is_imported(): """ Test that the completion module is properly imported when using Bedrock provider """ - module_name = "crewai.llms.providers.bedrock.completion" + module_name = "crewai.llm.providers.bedrock.completion" # Remove module from cache if it exists if module_name in sys.modules: @@ -124,7 +124,7 @@ def test_bedrock_completion_initialization_parameters(): region_name="us-west-2" ) - from crewai.llms.providers.bedrock.completion import BedrockCompletion + from 
crewai.llm.providers.bedrock.completion import BedrockCompletion assert isinstance(llm, BedrockCompletion) assert llm.model == "anthropic.claude-3-5-sonnet-20241022-v2:0" assert llm.temperature == 0.7 @@ -145,7 +145,7 @@ def test_bedrock_specific_parameters(): region_name="us-east-1" ) - from crewai.llms.providers.bedrock.completion import BedrockCompletion + from crewai.llm.providers.bedrock.completion import BedrockCompletion assert isinstance(llm, BedrockCompletion) assert llm.stop_sequences == ["Human:", "Assistant:"] assert llm.stream == True @@ -369,7 +369,7 @@ def test_bedrock_aws_credentials_configuration(): }): llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") - from crewai.llms.providers.bedrock.completion import BedrockCompletion + from crewai.llm.providers.bedrock.completion import BedrockCompletion assert isinstance(llm, BedrockCompletion) assert llm.region_name == "us-east-1" @@ -390,7 +390,7 @@ def test_bedrock_model_capabilities(): """ # Test Claude model llm_claude = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") - from crewai.llms.providers.bedrock.completion import BedrockCompletion + from crewai.llm.providers.bedrock.completion import BedrockCompletion assert isinstance(llm_claude, BedrockCompletion) assert llm_claude.is_claude_model == True assert llm_claude.supports_tools == True @@ -413,7 +413,7 @@ def test_bedrock_inference_config(): max_tokens=1000 ) - from crewai.llms.providers.bedrock.completion import BedrockCompletion + from crewai.llm.providers.bedrock.completion import BedrockCompletion assert isinstance(llm, BedrockCompletion) # Test config preparation @@ -444,7 +444,7 @@ def test_bedrock_model_detection(): for model_name in bedrock_test_cases: llm = LLM(model=model_name) - from crewai.llms.providers.bedrock.completion import BedrockCompletion + from crewai.llm.providers.bedrock.completion import BedrockCompletion assert isinstance(llm, BedrockCompletion), f"Failed for model: {model_name}" diff --git a/lib/crewai/tests/llms/google/test_google.py b/lib/crewai/tests/llms/google/test_google.py index c6f271b0a5..ffb070b5e5 100644 --- a/lib/crewai/tests/llms/google/test_google.py +++ b/lib/crewai/tests/llms/google/test_google.py @@ -34,7 +34,7 @@ def test_gemini_completion_is_used_when_gemini_provider(): """ llm = LLM(model="gemini/gemini-2.0-flash-001") - from crewai.llms.providers.gemini.completion import GeminiCompletion + from crewai.llm.providers.gemini.completion import GeminiCompletion assert isinstance(llm, GeminiCompletion) assert llm.provider == "gemini" assert llm.model == "gemini-2.0-flash-001" @@ -47,7 +47,7 @@ def test_gemini_tool_use_conversation_flow(): Test that the Gemini completion properly handles tool use conversation flow """ from unittest.mock import Mock, patch - from crewai.llms.providers.gemini.completion import GeminiCompletion + from crewai.llm.providers.gemini.completion import GeminiCompletion # Create GeminiCompletion instance completion = GeminiCompletion(model="gemini-2.0-flash-001") @@ -102,7 +102,7 @@ def test_gemini_completion_module_is_imported(): """ Test that the completion module is properly imported when using Google provider """ - module_name = "crewai.llms.providers.gemini.completion" + module_name = "crewai.llm.providers.gemini.completion" # Remove module from cache if it exists if module_name in sys.modules: @@ -159,7 +159,7 @@ def test_gemini_completion_initialization_parameters(): api_key="test-key" ) - from crewai.llms.providers.gemini.completion import GeminiCompletion + from 
crewai.llm.providers.gemini.completion import GeminiCompletion assert isinstance(llm, GeminiCompletion) assert llm.model == "gemini-2.0-flash-001" assert llm.temperature == 0.7 @@ -186,7 +186,7 @@ def test_gemini_specific_parameters(): location="us-central1" ) - from crewai.llms.providers.gemini.completion import GeminiCompletion + from crewai.llm.providers.gemini.completion import GeminiCompletion assert isinstance(llm, GeminiCompletion) assert llm.stop_sequences == ["Human:", "Assistant:"] assert llm.stream == True @@ -382,7 +382,7 @@ def test_gemini_raises_error_when_model_not_supported(): """Test that GeminiCompletion raises ValueError when model not supported""" # Mock the Google client to raise an error - with patch('crewai.llms.providers.gemini.completion.genai') as mock_genai: + with patch('crewai.llm.providers.gemini.completion.genai') as mock_genai: mock_client = MagicMock() mock_genai.Client.return_value = mock_client @@ -420,7 +420,7 @@ def test_gemini_vertex_ai_setup(): location="us-west1" ) - from crewai.llms.providers.gemini.completion import GeminiCompletion + from crewai.llm.providers.gemini.completion import GeminiCompletion assert isinstance(llm, GeminiCompletion) assert llm.project == "test-project" @@ -435,7 +435,7 @@ def test_gemini_api_key_configuration(): with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}): llm = LLM(model="google/gemini-2.0-flash-001") - from crewai.llms.providers.gemini.completion import GeminiCompletion + from crewai.llm.providers.gemini.completion import GeminiCompletion assert isinstance(llm, GeminiCompletion) assert llm.api_key == "test-google-key" @@ -453,7 +453,7 @@ def test_gemini_model_capabilities(): """ # Test Gemini 2.0 model llm_2_0 = LLM(model="google/gemini-2.0-flash-001") - from crewai.llms.providers.gemini.completion import GeminiCompletion + from crewai.llm.providers.gemini.completion import GeminiCompletion assert isinstance(llm_2_0, GeminiCompletion) assert llm_2_0.is_gemini_2 == True assert llm_2_0.supports_tools == True @@ -477,7 +477,7 @@ def test_gemini_generation_config(): max_output_tokens=1000 ) - from crewai.llms.providers.gemini.completion import GeminiCompletion + from crewai.llm.providers.gemini.completion import GeminiCompletion assert isinstance(llm, GeminiCompletion) # Test config preparation @@ -504,7 +504,7 @@ def test_gemini_model_detection(): for model_name in gemini_test_cases: llm = LLM(model=model_name) - from crewai.llms.providers.gemini.completion import GeminiCompletion + from crewai.llm.providers.gemini.completion import GeminiCompletion assert isinstance(llm, GeminiCompletion), f"Failed for model: {model_name}" diff --git a/lib/crewai/tests/llms/hooks/test_anthropic_interceptor.py b/lib/crewai/tests/llms/hooks/test_anthropic_interceptor.py index 4002d7b974..f606952044 100644 --- a/lib/crewai/tests/llms/hooks/test_anthropic_interceptor.py +++ b/lib/crewai/tests/llms/hooks/test_anthropic_interceptor.py @@ -6,7 +6,7 @@ import pytest from crewai.llm import LLM -from crewai.llms.hooks.base import BaseInterceptor +from crewai.llm.hooks.base import BaseInterceptor @pytest.fixture(autouse=True) diff --git a/lib/crewai/tests/llms/hooks/test_base_interceptor.py b/lib/crewai/tests/llms/hooks/test_base_interceptor.py index a216ddfcf0..b6be5c331a 100644 --- a/lib/crewai/tests/llms/hooks/test_base_interceptor.py +++ b/lib/crewai/tests/llms/hooks/test_base_interceptor.py @@ -3,7 +3,7 @@ import httpx import pytest -from crewai.llms.hooks.base import BaseInterceptor +from crewai.llm.hooks.base import 
BaseInterceptor class SimpleInterceptor(BaseInterceptor[httpx.Request, httpx.Response]): diff --git a/lib/crewai/tests/llms/hooks/test_openai_interceptor.py b/lib/crewai/tests/llms/hooks/test_openai_interceptor.py index 9c3b8537ef..7d94afbf8f 100644 --- a/lib/crewai/tests/llms/hooks/test_openai_interceptor.py +++ b/lib/crewai/tests/llms/hooks/test_openai_interceptor.py @@ -4,7 +4,7 @@ import pytest from crewai.llm import LLM -from crewai.llms.hooks.base import BaseInterceptor +from crewai.llm.hooks.base import BaseInterceptor class OpenAITestInterceptor(BaseInterceptor[httpx.Request, httpx.Response]): diff --git a/lib/crewai/tests/llms/hooks/test_transport.py b/lib/crewai/tests/llms/hooks/test_transport.py index 5ff5162bd6..785e690429 100644 --- a/lib/crewai/tests/llms/hooks/test_transport.py +++ b/lib/crewai/tests/llms/hooks/test_transport.py @@ -5,8 +5,8 @@ import httpx import pytest -from crewai.llms.hooks.base import BaseInterceptor -from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport +from crewai.llm.hooks.base import BaseInterceptor +from crewai.llm.hooks.transport import AsyncHTTPTransport, HTTPTransport class TrackingInterceptor(BaseInterceptor[httpx.Request, httpx.Response]): diff --git a/lib/crewai/tests/llms/hooks/test_unsupported_providers.py b/lib/crewai/tests/llms/hooks/test_unsupported_providers.py index bacc9878f2..be682e850d 100644 --- a/lib/crewai/tests/llms/hooks/test_unsupported_providers.py +++ b/lib/crewai/tests/llms/hooks/test_unsupported_providers.py @@ -6,7 +6,7 @@ import pytest from crewai.llm import LLM -from crewai.llms.hooks.base import BaseInterceptor +from crewai.llm.hooks.base import BaseInterceptor @pytest.fixture(autouse=True) diff --git a/lib/crewai/tests/llms/openai/test_openai.py b/lib/crewai/tests/llms/openai/test_openai.py index ee393ba9bd..aee167ab57 100644 --- a/lib/crewai/tests/llms/openai/test_openai.py +++ b/lib/crewai/tests/llms/openai/test_openai.py @@ -6,7 +6,7 @@ import pytest from crewai.llm import LLM -from crewai.llms.providers.openai.completion import OpenAICompletion +from crewai.llm.providers.openai.completion import OpenAICompletion from crewai.crew import Crew from crewai.agent import Agent from crewai.task import Task @@ -29,7 +29,7 @@ def test_openai_completion_is_used_when_no_provider_prefix(): """ llm = LLM(model="gpt-4o") - from crewai.llms.providers.openai.completion import OpenAICompletion + from crewai.llm.providers.openai.completion import OpenAICompletion assert isinstance(llm, OpenAICompletion) assert llm.provider == "openai" assert llm.model == "gpt-4o" @@ -63,7 +63,7 @@ def test_openai_completion_module_is_imported(): """ Test that the completion module is properly imported when using OpenAI provider """ - module_name = "crewai.llms.providers.openai.completion" + module_name = "crewai.llm.providers.openai.completion" # Remove module from cache if it exists if module_name in sys.modules: @@ -114,7 +114,7 @@ def test_openai_completion_initialization_parameters(): api_key="test-key" ) - from crewai.llms.providers.openai.completion import OpenAICompletion + from crewai.llm.providers.openai.completion import OpenAICompletion assert isinstance(llm, OpenAICompletion) assert llm.model == "gpt-4o" assert llm.temperature == 0.7 @@ -335,7 +335,7 @@ def test_openai_completion_call_returns_usage_metrics(): def test_openai_raises_error_when_model_not_supported(): """Test that OpenAICompletion raises ValueError when model not supported""" - with patch('crewai.llms.providers.openai.completion.OpenAI') as 
mock_openai_class: + with patch('crewai.llm.providers.openai.completion.OpenAI') as mock_openai_class: mock_client = MagicMock() mock_openai_class.return_value = mock_client diff --git a/lib/crewai/tests/test_custom_llm.py b/lib/crewai/tests/test_custom_llm.py index fef1bb5b53..770d5607b3 100644 --- a/lib/crewai/tests/test_custom_llm.py +++ b/lib/crewai/tests/test_custom_llm.py @@ -2,7 +2,7 @@ import pytest from crewai import Agent, Crew, Process, Task -from crewai.llms.base_llm import BaseLLM +from crewai.llm.base_llm import BaseLLM from crewai.utilities.llm_utils import create_llm diff --git a/lib/crewai/tests/utilities/test_events.py b/lib/crewai/tests/utilities/test_events.py index 1eeba199a5..7c317a64f9 100644 --- a/lib/crewai/tests/utilities/test_events.py +++ b/lib/crewai/tests/utilities/test_events.py @@ -743,7 +743,7 @@ def handle_llm_call_failed(source, event): error_message = "OpenAI API call failed: Simulated API failure" with patch( - "crewai.llms.providers.openai.completion.OpenAICompletion._handle_completion" + "crewai.llm.providers.openai.completion.OpenAICompletion._handle_completion" ) as mock_handle_completion: mock_handle_completion.side_effect = Exception("Simulated API failure") diff --git a/lib/crewai/tests/utilities/test_llm_utils.py b/lib/crewai/tests/utilities/test_llm_utils.py index e02173f8d9..b29bc36de2 100644 --- a/lib/crewai/tests/utilities/test_llm_utils.py +++ b/lib/crewai/tests/utilities/test_llm_utils.py @@ -4,7 +4,7 @@ from crewai.cli.constants import DEFAULT_LLM_MODEL from crewai.llm import LLM -from crewai.llms.base_llm import BaseLLM +from crewai.llm.base_llm import BaseLLM from crewai.utilities.llm_utils import create_llm import pytest diff --git a/logs.txt b/logs.txt new file mode 100644 index 0000000000..6aaa37c3a2 --- /dev/null +++ b/logs.txt @@ -0,0 +1,20 @@ +lib/crewai/src/crewai/agent/core.py:901: error: Argument 1 has incompatible type "ToolFilterContext"; expected "dict[str, Any]" [arg-type] +lib/crewai/src/crewai/agent/core.py:901: note: Error code "arg-type" not covered by "type: ignore" comment +lib/crewai/src/crewai/agent/core.py:905: error: Argument 1 has incompatible type "dict[str, Any]"; expected "ToolFilterContext" [arg-type] +lib/crewai/src/crewai/agent/core.py:905: note: Error code "arg-type" not covered by "type: ignore" comment +lib/crewai/src/crewai/agent/core.py:996: error: Returning Any from function declared to return "dict[str, dict[str, Any]]" [no-any-return] +lib/crewai/src/crewai/agent/core.py:1157: error: Incompatible types in assignment (expression has type "tuple[UnionType, None]", target has type "tuple[type, Any]") [assignment] +lib/crewai/src/crewai/agent/core.py:1183: error: Argument 1 to "append" of "list" has incompatible type "type"; expected "type[str]" [arg-type] +lib/crewai/src/crewai/agent/core.py:1188: error: Incompatible types in assignment (expression has type "UnionType", variable has type "type[str]") [assignment] +lib/crewai/src/crewai/agent/core.py:1201: error: Argument 1 to "get" of "dict" has incompatible type "Any | None"; expected "str" [arg-type] +Found 7 errors in 1 file (checked 4 source files) +Success: no issues found in 4 source files +lib/crewai/src/crewai/llm/providers/gemini/completion.py:111: error: BaseModel field may only be overridden by another field [misc] +Found 1 error in 1 file (checked 4 source files) +Success: no issues found in 4 source files +lib/crewai/src/crewai/llm/providers/anthropic/completion.py:101: error: BaseModel field may only be overridden by another field [misc] 
+Found 1 error in 1 file (checked 4 source files) +lib/crewai/src/crewai/llm/providers/bedrock/completion.py:250: error: BaseModel field may only be overridden by another field [misc] +Found 1 error in 1 file (checked 4 source files) + +uv-lock..............................................(no files to check)Skipped From d8fe83f76c4bf48b18be1f3013084f18c0671ff2 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Mon, 10 Nov 2025 16:05:23 -0500 Subject: [PATCH 2/7] chore: continue refactoring llms to base models --- lib/crewai/src/crewai/__init__.py | 2 +- lib/crewai/src/crewai/cli/crew_chat.py | 3 +- lib/crewai/src/crewai/crew.py | 2 +- .../src/crewai/events/event_listener.py | 2 +- .../experimental/evaluation/base_evaluator.py | 2 +- lib/crewai/src/crewai/lite_agent.py | 2 +- lib/crewai/src/crewai/llm/base_llm.py | 16 +- lib/crewai/src/crewai/llm/internal/meta.py | 31 ++- .../llm/providers/anthropic/completion.py | 145 +++++------- .../crewai/llm/providers/azure/completion.py | 152 +++++++------ .../llm/providers/bedrock/completion.py | 210 ++++++++--------- .../crewai/llm/providers/gemini/completion.py | 211 +++++++++--------- .../crewai/llm/providers/openai/completion.py | 20 +- .../crewai/tasks/hallucination_guardrail.py | 2 +- lib/crewai/src/crewai/tools/tool_usage.py | 2 +- lib/crewai/tests/test_llm.py | 14 +- logs.txt | 20 -- 17 files changed, 393 insertions(+), 443 deletions(-) delete mode 100644 logs.txt diff --git a/lib/crewai/src/crewai/__init__.py b/lib/crewai/src/crewai/__init__.py index ef2bcf78d2..4e2365a2ff 100644 --- a/lib/crewai/src/crewai/__init__.py +++ b/lib/crewai/src/crewai/__init__.py @@ -8,8 +8,8 @@ from crewai.crews.crew_output import CrewOutput from crewai.flow.flow import Flow from crewai.knowledge.knowledge import Knowledge -from crewai.llm import LLM from crewai.llm.base_llm import BaseLLM +from crewai.llm.core import LLM from crewai.process import Process from crewai.task import Task from crewai.tasks.llm_guardrail import LLMGuardrail diff --git a/lib/crewai/src/crewai/cli/crew_chat.py b/lib/crewai/src/crewai/cli/crew_chat.py index feca9e4ca3..f593180864 100644 --- a/lib/crewai/src/crewai/cli/crew_chat.py +++ b/lib/crewai/src/crewai/cli/crew_chat.py @@ -14,7 +14,8 @@ from crewai.cli.utils import read_toml from crewai.cli.version import get_crewai_version from crewai.crew import Crew -from crewai.llm import LLM, BaseLLM +from crewai.llm import LLM +from crewai.llm.base_llm import BaseLLM from crewai.types.crew_chat import ChatInputField, ChatInputs from crewai.utilities.llm_utils import create_llm from crewai.utilities.printer import Printer diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 31eb7466ca..b258f3eaa6 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -56,8 +56,8 @@ from crewai.flow.flow_trackable import FlowTrackable from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource -from crewai.llm import LLM from crewai.llm.base_llm import BaseLLM +from crewai.llm.core import LLM from crewai.memory.entity.entity_memory import EntityMemory from crewai.memory.external.external_memory import ExternalMemory from crewai.memory.long_term.long_term_memory import LongTermMemory diff --git a/lib/crewai/src/crewai/events/event_listener.py b/lib/crewai/src/crewai/events/event_listener.py index e07ee193ce..0a88e85f03 100644 --- a/lib/crewai/src/crewai/events/event_listener.py +++ b/lib/crewai/src/crewai/events/event_listener.py @@ -89,7 
+89,7 @@ ToolUsageStartedEvent, ) from crewai.events.utils.console_formatter import ConsoleFormatter -from crewai.llm import LLM +from crewai.llm.core import LLM from crewai.task import Task from crewai.telemetry.telemetry import Telemetry from crewai.utilities import Logger diff --git a/lib/crewai/src/crewai/experimental/evaluation/base_evaluator.py b/lib/crewai/src/crewai/experimental/evaluation/base_evaluator.py index 69d1bb5c3b..8001074d33 100644 --- a/lib/crewai/src/crewai/experimental/evaluation/base_evaluator.py +++ b/lib/crewai/src/crewai/experimental/evaluation/base_evaluator.py @@ -7,7 +7,7 @@ from crewai.agent import Agent from crewai.agents.agent_builder.base_agent import BaseAgent -from crewai.llm import BaseLLM +from crewai.llm.base_llm import BaseLLM from crewai.task import Task from crewai.utilities.llm_utils import create_llm diff --git a/lib/crewai/src/crewai/lite_agent.py b/lib/crewai/src/crewai/lite_agent.py index ef877e01b6..e91e6b98f4 100644 --- a/lib/crewai/src/crewai/lite_agent.py +++ b/lib/crewai/src/crewai/lite_agent.py @@ -39,8 +39,8 @@ from crewai.events.types.logging_events import AgentLogsExecutionEvent from crewai.flow.flow_trackable import FlowTrackable from crewai.lite_agent_output import LiteAgentOutput -from crewai.llm import LLM from crewai.llm.base_llm import BaseLLM +from crewai.llm.core import LLM from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool from crewai.utilities.agent_utils import ( diff --git a/lib/crewai/src/crewai/llm/base_llm.py b/lib/crewai/src/crewai/llm/base_llm.py index 1222763457..f60ce500e6 100644 --- a/lib/crewai/src/crewai/llm/base_llm.py +++ b/lib/crewai/src/crewai/llm/base_llm.py @@ -66,7 +66,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): """ model_config: ClassVar[ConfigDict] = ConfigDict( - arbitrary_types_allowed=True, extra="allow", validate_assignment=True + arbitrary_types_allowed=True, extra="allow" ) # Core fields @@ -80,7 +80,9 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): default="openai", description="Provider name (openai, anthropic, etc.)" ) stop: list[str] = Field( - default_factory=list, description="Stop sequences for generation" + default_factory=list, + description="Stop sequences for generation", + validation_alias="stop_sequences", ) # Internal fields @@ -112,16 +114,18 @@ def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]: if not values.get("model"): raise ValueError("Model name is required and cannot be empty") - # Handle stop sequences - stop = values.get("stop") + stop = values.get("stop") or values.get("stop_sequences") if stop is None: values["stop"] = [] elif isinstance(stop, str): values["stop"] = [stop] - elif not isinstance(stop, list): + elif isinstance(stop, list): + values["stop"] = stop + else: values["stop"] = [] - # Set default provider if not specified + values.pop("stop_sequences", None) + if "provider" not in values or values["provider"] is None: values["provider"] = "openai" diff --git a/lib/crewai/src/crewai/llm/internal/meta.py b/lib/crewai/src/crewai/llm/internal/meta.py index 4bf83b655a..c12b1e9cc3 100644 --- a/lib/crewai/src/crewai/llm/internal/meta.py +++ b/lib/crewai/src/crewai/llm/internal/meta.py @@ -33,13 +33,12 @@ class LLMMeta(ModelMetaclass): native provider implementation based on the model parameter. 
""" - def __call__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> Any: # noqa: N805 + def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805 """Route to appropriate provider implementation at instantiation time. Args: - model: The model identifier (e.g., "gpt-4", "claude-3-opus") - is_litellm: Force use of LiteLLM instead of native provider - **kwargs: Additional parameters for the LLM + *args: Positional arguments (model should be first for LLM class) + **kwargs: Keyword arguments including model, is_litellm, etc. Returns: Instance of the appropriate provider class or LLM class @@ -47,18 +46,18 @@ def __call__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> Any: Raises: ValueError: If model is not a valid string """ - if not model or not isinstance(model, str): - raise ValueError("Model must be a non-empty string") + if cls.__name__ != "LLM": + return super().__call__(*args, **kwargs) - # Only perform routing if called on the base LLM class - # Subclasses (OpenAICompletion, etc.) should create normally - from crewai.llm import LLM + model = kwargs.get("model") or (args[0] if args else None) + is_litellm = kwargs.get("is_litellm", False) - if cls is not LLM: - # Direct instantiation of provider class, skip routing - return super().__call__(model=model, **kwargs) + if not model or not isinstance(model, str): + raise ValueError("Model must be a non-empty string") - # Extract provider information + if args and not kwargs.get("model"): + kwargs["model"] = args[0] + args = args[1:] explicit_provider = kwargs.get("provider") if explicit_provider: @@ -97,12 +96,10 @@ def __call__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> Any: use_native = True model_string = model - # Route to native provider if available native_class = cls._get_native_provider(provider) if use_native else None if native_class and not is_litellm and provider in SUPPORTED_NATIVE_PROVIDERS: try: - # Remove 'provider' from kwargs to avoid duplicate keyword argument - kwargs_copy = {k: v for k, v in kwargs.items() if k != "provider"} + kwargs_copy = {k: v for k, v in kwargs.items() if k not in ("provider", "model")} return native_class( model=model_string, provider=provider, **kwargs_copy ) @@ -111,14 +108,12 @@ def __call__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> Any: except Exception as e: raise ImportError(f"Error importing native provider: {e}") from e - # Fallback to LiteLLM try: import litellm # noqa: F401 except ImportError: logging.error("LiteLLM is not available, falling back to LiteLLM") raise ImportError("Fallback to LiteLLM is not available") from None - # Create actual LLM instance with is_litellm=True return super().__call__(model=model, is_litellm=True, **kwargs) @staticmethod diff --git a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py index 8ebba16737..dd13c0f5e1 100644 --- a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py @@ -3,9 +3,10 @@ import json import logging import os -from typing import TYPE_CHECKING, Any, ClassVar, cast +from typing import Any, ClassVar, cast -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from typing_extensions import Self from crewai.events.types.llm_events import LLMCallType from crewai.llm.base_llm import BaseLLM @@ -19,9 +20,6 @@ from crewai.utilities.types import LLMMessage -if 
TYPE_CHECKING: - from crewai.llm.hooks.base import BaseInterceptor - try: from anthropic import Anthropic from anthropic.types import Message @@ -38,90 +36,67 @@ class AnthropicCompletion(BaseLLM): This class provides direct integration with the Anthropic Python SDK, offering native tool use, streaming support, and proper message formatting. + + Attributes: + model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022') + base_url: Custom base URL for Anthropic API + timeout: Request timeout in seconds + max_retries: Maximum number of retries + max_tokens: Maximum tokens in response (required for Anthropic) + top_p: Nucleus sampling parameter + stream: Enable streaming responses + client_params: Additional parameters for the Anthropic client + interceptor: HTTP interceptor for modifying requests/responses at transport level """ - model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,)) + model_config: ClassVar[ConfigDict] = ConfigDict( + ignored_types=(property,), arbitrary_types_allowed=True + ) + + base_url: str | None = Field( + default=None, description="Custom base URL for Anthropic API" + ) + timeout: float | None = Field( + default=None, description="Request timeout in seconds" + ) + max_retries: int = Field(default=2, description="Maximum number of retries") + max_tokens: int = Field( + default=4096, description="Maximum tokens in response (required for Anthropic)" + ) + top_p: float | None = Field(default=None, description="Nucleus sampling parameter") + stream: bool = Field(default=False, description="Enable streaming responses") + client_params: dict[str, Any] | None = Field( + default=None, description="Additional Anthropic client parameters" + ) + interceptor: Any = Field( + default=None, description="HTTP interceptor for request/response modification" + ) + client: Any = Field( + default=None, exclude=True, description="Anthropic client instance" + ) + + _is_claude_3: bool = PrivateAttr(default=False) + _supports_tools: bool = PrivateAttr(default=False) + + @model_validator(mode="after") + def setup_client(self) -> Self: + """Initialize the Anthropic client and model-specific settings.""" + self.client = Anthropic(**self._get_client_params()) - def __init__( - self, - model: str = "claude-3-5-sonnet-20241022", - api_key: str | None = None, - base_url: str | None = None, - timeout: float | None = None, - max_retries: int = 2, - temperature: float | None = None, - max_tokens: int = 4096, # Required for Anthropic - top_p: float | None = None, - stop_sequences: list[str] | None = None, - stream: bool = False, - client_params: dict[str, Any] | None = None, - interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None, - **kwargs: Any, - ): - """Initialize Anthropic chat completion client. + self._is_claude_3 = "claude-3" in self.model.lower() + self._supports_tools = self._is_claude_3 - Args: - model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022') - api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var) - base_url: Custom base URL for Anthropic API - timeout: Request timeout in seconds - max_retries: Maximum number of retries - temperature: Sampling temperature (0-1) - max_tokens: Maximum tokens in response (required for Anthropic) - top_p: Nucleus sampling parameter - stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop) - stream: Enable streaming responses - client_params: Additional parameters for the Anthropic client - interceptor: HTTP interceptor for modifying requests/responses at transport level. 
- **kwargs: Additional parameters - """ - super().__init__( - model=model, temperature=temperature, stop=stop_sequences or [], **kwargs - ) + return self - # Client params - self.interceptor = interceptor - self.client_params = client_params - self.base_url = base_url - self.timeout = timeout - self.max_retries = max_retries - - self.client = Anthropic(**self._get_client_params()) + @property + def is_claude_3(self) -> bool: + """Check if model is Claude 3.""" + return self._is_claude_3 - # Store completion parameters - self.max_tokens = max_tokens - self.top_p = top_p - self.stream = stream - self.stop_sequences = stop_sequences or [] - - # Model-specific settings - self.is_claude_3 = "claude-3" in model.lower() - self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use - - # - # @property - # def stop(self) -> list[str]: # type: ignore[misc] - # """Get stop sequences sent to the API.""" - # return self.stop_sequences - - # @stop.setter - # def stop(self, value: list[str] | str | None) -> None: - # """Set stop sequences. - # - # Synchronizes stop_sequences to ensure values set by CrewAgentExecutor - # are properly sent to the Anthropic API. - # - # Args: - # value: Stop sequences as a list, single string, or None - # """ - # if value is None: - # self.stop_sequences = [] - # elif isinstance(value, str): - # self.stop_sequences = [value] - # elif isinstance(value, list): - # self.stop_sequences = value - # else: - # self.stop_sequences = [] + @property + def supports_tools(self) -> bool: + """Check if model supports tools.""" + return self._supports_tools def _get_client_params(self) -> dict[str, Any]: """Get client parameters.""" @@ -250,8 +225,8 @@ def _prepare_completion_params( params["temperature"] = self.temperature if self.top_p is not None: params["top_p"] = self.top_p - if self.stop_sequences: - params["stop_sequences"] = self.stop_sequences + if self.stop: + params["stop_sequences"] = self.stop # Handle tools for Claude 3+ if tools and self.supports_tools: diff --git a/lib/crewai/src/crewai/llm/providers/azure/completion.py b/lib/crewai/src/crewai/llm/providers/azure/completion.py index a389c18255..6076b9f4d1 100644 --- a/lib/crewai/src/crewai/llm/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llm/providers/azure/completion.py @@ -3,9 +3,10 @@ import json import logging import os -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from typing_extensions import Self from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES from crewai.llm.providers.utils.common import safe_tool_conversion @@ -17,7 +18,6 @@ if TYPE_CHECKING: - from crewai.llm.hooks.base import BaseInterceptor from crewai.tools.base_tool import BaseTool @@ -51,65 +51,77 @@ class AzureCompletion(BaseLLM): This class provides direct integration with the Azure AI Inference Python SDK, offering native function calling, streaming support, and proper Azure authentication. 
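The provider hunks in this patch (Anthropic above; Azure, Bedrock, and Gemini below) share one refactor: constructor keyword arguments become declared pydantic fields, and SDK client construction moves into an @model_validator(mode="after") hook, with model capability flags held in PrivateAttr slots behind read-only properties. A minimal, self-contained sketch of that pattern, assuming pydantic v2; ExampleCompletion and FakeClient are illustrative stand-ins, not crewai classes:

from typing import Any

from pydantic import BaseModel, Field, PrivateAttr, model_validator
from typing_extensions import Self


class FakeClient:
    """Stand-in for an SDK client such as Anthropic() or ChatCompletionsClient()."""

    def __init__(self, **params: Any) -> None:
        self.params = params


class ExampleCompletion(BaseModel):
    """Illustrative provider config: declared fields instead of a hand-written __init__."""

    model: str
    base_url: str | None = Field(default=None, description="Custom base URL")
    timeout: float | None = Field(default=None, description="Request timeout in seconds")
    max_retries: int = Field(default=2, description="Maximum number of retries")
    client: Any = Field(default=None, exclude=True, description="SDK client instance")

    _is_claude_3: bool = PrivateAttr(default=False)

    @model_validator(mode="after")
    def setup_client(self) -> Self:
        # Runs after field validation, so every declared field is already populated.
        self.client = FakeClient(
            base_url=self.base_url, timeout=self.timeout, max_retries=self.max_retries
        )
        self._is_claude_3 = "claude-3" in self.model.lower()
        return self

    @property
    def is_claude_3(self) -> bool:
        return self._is_claude_3


llm = ExampleCompletion(model="claude-3-5-sonnet-20241022", timeout=30.0)
assert llm.is_claude_3
assert llm.client.params["max_retries"] == 2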
+ + Attributes: + model: Azure deployment name or model name + endpoint: Azure endpoint URL + api_version: Azure API version + timeout: Request timeout in seconds + max_retries: Maximum number of retries + top_p: Nucleus sampling parameter + frequency_penalty: Frequency penalty (-2 to 2) + presence_penalty: Presence penalty (-2 to 2) + max_tokens: Maximum tokens in response + stream: Enable streaming responses + interceptor: HTTP interceptor (not yet supported for Azure) """ - def __init__( - self, - model: str, - api_key: str | None = None, - endpoint: str | None = None, - api_version: str | None = None, - timeout: float | None = None, - max_retries: int = 2, - temperature: float | None = None, - top_p: float | None = None, - frequency_penalty: float | None = None, - presence_penalty: float | None = None, - max_tokens: int | None = None, - stop: list[str] | None = None, - stream: bool = False, - interceptor: BaseInterceptor[Any, Any] | None = None, - **kwargs: Any, - ): - """Initialize Azure AI Inference chat completion client. + model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) - Args: - model: Azure deployment name or model name - api_key: Azure API key (defaults to AZURE_API_KEY env var) - endpoint: Azure endpoint URL (defaults to AZURE_ENDPOINT env var) - api_version: Azure API version (defaults to AZURE_API_VERSION env var) - timeout: Request timeout in seconds - max_retries: Maximum number of retries - temperature: Sampling temperature (0-2) - top_p: Nucleus sampling parameter - frequency_penalty: Frequency penalty (-2 to 2) - presence_penalty: Presence penalty (-2 to 2) - max_tokens: Maximum tokens in response - stop: Stop sequences - stream: Enable streaming responses - interceptor: HTTP interceptor (not yet supported for Azure). - **kwargs: Additional parameters - """ - if interceptor is not None: + endpoint: str | None = Field( + default=None, + description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)", + ) + api_version: str = Field( + default="2024-06-01", + description="Azure API version (defaults to AZURE_API_VERSION env var or 2024-06-01)", + ) + timeout: float | None = Field( + default=None, description="Request timeout in seconds" + ) + max_retries: int = Field(default=2, description="Maximum number of retries") + top_p: float | None = Field(default=None, description="Nucleus sampling parameter") + frequency_penalty: float | None = Field( + default=None, description="Frequency penalty (-2 to 2)" + ) + presence_penalty: float | None = Field( + default=None, description="Presence penalty (-2 to 2)" + ) + max_tokens: int | None = Field( + default=None, description="Maximum tokens in response" + ) + stream: bool = Field(default=False, description="Enable streaming responses") + interceptor: Any = Field( + default=None, description="HTTP interceptor (not yet supported for Azure)" + ) + client: Any = Field(default=None, exclude=True, description="Azure client instance") + + _is_openai_model: bool = PrivateAttr(default=False) + _is_azure_openai_endpoint: bool = PrivateAttr(default=False) + + @model_validator(mode="after") + def setup_client(self) -> Self: + """Initialize the Azure client and validate configuration.""" + if self.interceptor is not None: raise NotImplementedError( "HTTP interceptors are not yet supported for Azure AI Inference provider. " "Interceptors are currently supported for OpenAI and Anthropic providers only." 
) - super().__init__( - model=model, temperature=temperature, stop=stop or [], **kwargs - ) + if self.api_key is None: + self.api_key = os.getenv("AZURE_API_KEY") - self.api_key = api_key or os.getenv("AZURE_API_KEY") - self.endpoint = ( - endpoint - or os.getenv("AZURE_ENDPOINT") - or os.getenv("AZURE_OPENAI_ENDPOINT") - or os.getenv("AZURE_API_BASE") - ) - self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01" - self.timeout = timeout - self.max_retries = max_retries + if self.endpoint is None: + self.endpoint = ( + os.getenv("AZURE_ENDPOINT") + or os.getenv("AZURE_OPENAI_ENDPOINT") + or os.getenv("AZURE_API_BASE") + ) + + if self.api_version == "2024-06-01": + env_version = os.getenv("AZURE_API_VERSION") + if env_version: + self.api_version = env_version if not self.api_key: raise ValueError( @@ -120,36 +132,38 @@ def __init__( "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter." ) - # Validate and potentially fix Azure OpenAI endpoint URL - self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model) + self.endpoint = self._validate_and_fix_endpoint(self.endpoint, self.model) - # Build client kwargs - client_kwargs = { + client_kwargs: dict[str, Any] = { "endpoint": self.endpoint, "credential": AzureKeyCredential(self.api_key), } - # Add api_version if specified (primarily for Azure OpenAI endpoints) if self.api_version: client_kwargs["api_version"] = self.api_version - self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type] + self.client = ChatCompletionsClient(**client_kwargs) - self.top_p = top_p - self.frequency_penalty = frequency_penalty - self.presence_penalty = presence_penalty - self.max_tokens = max_tokens - self.stream = stream - - self.is_openai_model = any( - prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"] + self._is_openai_model = any( + prefix in self.model.lower() for prefix in ["gpt-", "o1-", "text-"] ) - - self.is_azure_openai_endpoint = ( + self._is_azure_openai_endpoint = ( "openai.azure.com" in self.endpoint and "/openai/deployments/" in self.endpoint ) + return self + + @property + def is_openai_model(self) -> bool: + """Check if model is an OpenAI model.""" + return self._is_openai_model + + @property + def is_azure_openai_endpoint(self) -> bool: + """Check if endpoint is an Azure OpenAI endpoint.""" + return self._is_azure_openai_endpoint + def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str: """Validate and fix Azure endpoint URL format. diff --git a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py index f67414c639..ed6738c4b1 100644 --- a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py @@ -5,8 +5,8 @@ import os from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast -from pydantic import BaseModel, ConfigDict -from typing_extensions import Required +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from typing_extensions import Required, Self from crewai.events.types.llm_events import LLMCallType from crewai.llm.base_llm import BaseLLM @@ -32,8 +32,6 @@ ToolTypeDef, ) - from crewai.llm.hooks.base import BaseInterceptor - try: from boto3.session import Session @@ -143,76 +141,86 @@ class BedrockCompletion(BaseLLM): - Complete streaming event handling (messageStart, contentBlockStart, etc.) 
- Response metadata and trace information capture - Model-specific conversation format handling (e.g., Cohere requirements) + + Attributes: + model: The Bedrock model ID to use + aws_access_key_id: AWS access key (defaults to environment variable) + aws_secret_access_key: AWS secret key (defaults to environment variable) + aws_session_token: AWS session token for temporary credentials + region_name: AWS region name + max_tokens: Maximum tokens to generate + top_p: Nucleus sampling parameter + top_k: Top-k sampling parameter (Claude models only) + stop_sequences: List of sequences that stop generation + stream: Whether to use streaming responses + guardrail_config: Guardrail configuration for content filtering + additional_model_request_fields: Model-specific request parameters + additional_model_response_field_paths: Custom response field paths + interceptor: HTTP interceptor (not yet supported for Bedrock) """ - model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,)) + model_config: ClassVar[ConfigDict] = ConfigDict( + ignored_types=(property,), arbitrary_types_allowed=True + ) - def __init__( - self, - model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0", - aws_access_key_id: str | None = None, - aws_secret_access_key: str | None = None, - aws_session_token: str | None = None, - region_name: str = "us-east-1", - temperature: float | None = None, - max_tokens: int | None = None, - top_p: float | None = None, - top_k: int | None = None, - stop_sequences: Sequence[str] | None = None, - stream: bool = False, - guardrail_config: dict[str, Any] | None = None, - additional_model_request_fields: dict[str, Any] | None = None, - additional_model_response_field_paths: list[str] | None = None, - interceptor: BaseInterceptor[Any, Any] | None = None, - **kwargs: Any, - ) -> None: - """Initialize AWS Bedrock completion client. - - Args: - model: The Bedrock model ID to use - aws_access_key_id: AWS access key (defaults to environment variable) - aws_secret_access_key: AWS secret key (defaults to environment variable) - aws_session_token: AWS session token for temporary credentials - region_name: AWS region name - temperature: Sampling temperature for response generation - max_tokens: Maximum tokens to generate - top_p: Nucleus sampling parameter - top_k: Top-k sampling parameter (Claude models only) - stop_sequences: List of sequences that stop generation - stream: Whether to use streaming responses - guardrail_config: Guardrail configuration for content filtering - additional_model_request_fields: Model-specific request parameters - additional_model_response_field_paths: Custom response field paths - interceptor: HTTP interceptor (not yet supported for Bedrock). 
- **kwargs: Additional parameters - """ - if interceptor is not None: + aws_access_key_id: str | None = Field( + default=None, description="AWS access key (defaults to environment variable)" + ) + aws_secret_access_key: str | None = Field( + default=None, description="AWS secret key (defaults to environment variable)" + ) + aws_session_token: str | None = Field( + default=None, description="AWS session token for temporary credentials" + ) + region_name: str = Field(default="us-east-1", description="AWS region name") + max_tokens: int | None = Field( + default=None, description="Maximum tokens to generate" + ) + top_p: float | None = Field(default=None, description="Nucleus sampling parameter") + top_k: int | None = Field( + default=None, description="Top-k sampling parameter (Claude models only)" + ) + stream: bool = Field( + default=False, description="Whether to use streaming responses" + ) + guardrail_config: dict[str, Any] | None = Field( + default=None, description="Guardrail configuration for content filtering" + ) + additional_model_request_fields: dict[str, Any] | None = Field( + default=None, description="Model-specific request parameters" + ) + additional_model_response_field_paths: list[str] | None = Field( + default=None, description="Custom response field paths" + ) + interceptor: Any = Field( + default=None, description="HTTP interceptor (not yet supported for Bedrock)" + ) + client: Any = Field( + default=None, exclude=True, description="Bedrock client instance" + ) + + _is_claude_model: bool = PrivateAttr(default=False) + _supports_tools: bool = PrivateAttr(default=True) + _supports_streaming: bool = PrivateAttr(default=True) + _model_id: str = PrivateAttr() + + @model_validator(mode="after") + def setup_client(self) -> Self: + """Initialize the Bedrock client and validate configuration.""" + if self.interceptor is not None: raise NotImplementedError( "HTTP interceptors are not yet supported for AWS Bedrock provider. " "Interceptors are currently supported for OpenAI and Anthropic providers only." 
) - # Extract provider from kwargs to avoid duplicate argument - kwargs.pop("provider", None) - - super().__init__( - model=model, - temperature=temperature, - stop=stop_sequences or [], - provider="bedrock", - **kwargs, - ) - - # Initialize Bedrock client with proper configuration session = Session( - aws_access_key_id=aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"), - aws_secret_access_key=aws_secret_access_key + aws_access_key_id=self.aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"), + aws_secret_access_key=self.aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY"), - aws_session_token=aws_session_token or os.getenv("AWS_SESSION_TOKEN"), - region_name=region_name, + aws_session_token=self.aws_session_token or os.getenv("AWS_SESSION_TOKEN"), + region_name=self.region_name, ) - # Configure client with timeouts and retries following AWS best practices config = Config( read_timeout=300, retries={ @@ -223,53 +231,33 @@ def __init__( ) self.client = session.client("bedrock-runtime", config=config) - self.region_name = region_name - - # Store completion parameters - self.max_tokens = max_tokens - self.top_p = top_p - self.top_k = top_k - self.stream = stream - self.stop_sequences = stop_sequences or [] - - # Store advanced features (optional) - self.guardrail_config = guardrail_config - self.additional_model_request_fields = additional_model_request_fields - self.additional_model_response_field_paths = ( - additional_model_response_field_paths - ) - # Model-specific settings - self.is_claude_model = "claude" in model.lower() - self.supports_tools = True # Converse API supports tools for most models - self.supports_streaming = True - - # Handle inference profiles for newer models - self.model_id = model - - # @property - # def stop(self) -> list[str]: # type: ignore[misc] - # """Get stop sequences sent to the API.""" - # return list(self.stop_sequences) - - # @stop.setter - # def stop(self, value: Sequence[str] | str | None) -> None: - # """Set stop sequences. - # - # Synchronizes stop_sequences to ensure values set by CrewAgentExecutor - # are properly sent to the Bedrock API. 
- # - # Args: - # value: Stop sequences as a Sequence, single string, or None - # """ - # if value is None: - # self.stop_sequences = [] - # elif isinstance(value, str): - # self.stop_sequences = [value] - # elif isinstance(value, Sequence): - # self.stop_sequences = list(value) - # else: - # self.stop_sequences = [] + self._is_claude_model = "claude" in self.model.lower() + self._supports_tools = True + self._supports_streaming = True + self._model_id = self.model + + return self + + @property + def is_claude_model(self) -> bool: + """Check if model is a Claude model.""" + return self._is_claude_model + + @property + def supports_tools(self) -> bool: + """Check if model supports tools.""" + return self._supports_tools + + @property + def supports_streaming(self) -> bool: + """Check if model supports streaming.""" + return self._supports_streaming + + @property + def model_id(self) -> str: + """Get the model ID.""" + return self._model_id def call( self, @@ -559,7 +547,7 @@ def _handle_streaming_converse( "Sequence[MessageTypeDef | MessageOutputTypeDef]", cast(object, messages), ), - **body, # type: ignore[arg-type] + **body, ) stream = response.get("stream") @@ -821,8 +809,8 @@ def _get_inference_config(self) -> EnhancedInferenceConfigurationTypeDef: config["temperature"] = float(self.temperature) if self.top_p is not None: config["topP"] = float(self.top_p) - if self.stop_sequences: - config["stopSequences"] = self.stop_sequences + if self.stop: + config["stopSequences"] = self.stop if self.is_claude_model and self.top_k is not None: # top_k is supported by Claude models diff --git a/lib/crewai/src/crewai/llm/providers/gemini/completion.py b/lib/crewai/src/crewai/llm/providers/gemini/completion.py index 263309910f..34bff25083 100644 --- a/lib/crewai/src/crewai/llm/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llm/providers/gemini/completion.py @@ -2,12 +2,12 @@ import os from typing import Any, ClassVar, cast -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from typing_extensions import Self from crewai.events.types.llm_events import LLMCallType from crewai.llm.base_llm import BaseLLM from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES -from crewai.llm.hooks.base import BaseInterceptor from crewai.llm.providers.utils.common import safe_tool_conversion from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( @@ -31,108 +31,124 @@ class GeminiCompletion(BaseLLM): This class provides direct integration with the Google Gen AI Python SDK, offering native function calling, streaming support, and proper Gemini formatting. 
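Stop-sequence handling is also unified on the base class in this patch: BaseLLM's stop field accepts a stop_sequences alias and normalizes both keys in its mode="before" validator, while the Gemini hunk below keeps a stop_sequences property that mirrors stop for its native API. A small sketch of the underlying pydantic alias mechanism, assuming pydantic v2; StopExample is hypothetical and omits the extra normalization BaseLLM performs:

from pydantic import BaseModel, ConfigDict, Field


class StopExample(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    # With populate_by_name=True, either the field name or its validation alias
    # may be used at construction time.
    stop: list[str] = Field(default_factory=list, validation_alias="stop_sequences")


a = StopExample(stop_sequences=["\nObservation:"])  # provider-style keyword
b = StopExample(stop=["\nObservation:"])            # executor-style keyword
assert a.stop == b.stop == ["\nObservation:"]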
+ + Attributes: + model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro') + project: Google Cloud project ID (for Vertex AI) + location: Google Cloud location (for Vertex AI, defaults to 'us-central1') + top_p: Nucleus sampling parameter + top_k: Top-k sampling parameter + max_output_tokens: Maximum tokens in response + stop_sequences: Stop sequences + stream: Enable streaming responses + safety_settings: Safety filter settings + client_params: Additional parameters for Google Gen AI Client constructor + interceptor: HTTP interceptor (not yet supported for Gemini) """ - model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,)) + model_config: ClassVar[ConfigDict] = ConfigDict( + ignored_types=(property,), arbitrary_types_allowed=True + ) + + project: str | None = Field( + default=None, description="Google Cloud project ID (for Vertex AI)" + ) + location: str = Field( + default="us-central1", + description="Google Cloud location (for Vertex AI, defaults to 'us-central1')", + ) + top_p: float | None = Field(default=None, description="Nucleus sampling parameter") + top_k: int | None = Field(default=None, description="Top-k sampling parameter") + max_output_tokens: int | None = Field( + default=None, description="Maximum tokens in response" + ) + stream: bool = Field(default=False, description="Enable streaming responses") + safety_settings: dict[str, Any] = Field( + default_factory=dict, description="Safety filter settings" + ) + client_params: dict[str, Any] = Field( + default_factory=dict, + description="Additional parameters for Google Gen AI Client constructor", + ) + interceptor: Any = Field( + default=None, description="HTTP interceptor (not yet supported for Gemini)" + ) + client: Any = Field( + default=None, exclude=True, description="Gemini client instance" + ) + + _is_gemini_2: bool = PrivateAttr(default=False) + _is_gemini_1_5: bool = PrivateAttr(default=False) + _supports_tools: bool = PrivateAttr(default=False) + + @property + def stop_sequences(self) -> list[str]: + """Get stop sequences as a list. + + This property provides access to stop sequences in Gemini's native format + while maintaining synchronization with the base class's stop attribute. + """ + if self.stop is None: + return [] + if isinstance(self.stop, str): + return [self.stop] + return self.stop - def __init__( - self, - model: str = "gemini-2.0-flash-001", - api_key: str | None = None, - project: str | None = None, - location: str | None = None, - temperature: float | None = None, - top_p: float | None = None, - top_k: int | None = None, - max_output_tokens: int | None = None, - stop_sequences: list[str] | None = None, - stream: bool = False, - safety_settings: dict[str, Any] | None = None, - client_params: dict[str, Any] | None = None, - interceptor: BaseInterceptor[Any, Any] | None = None, - **kwargs: Any, - ): - """Initialize Google Gemini chat completion client. + @stop_sequences.setter + def stop_sequences(self, value: list[str] | str | None) -> None: + """Set stop sequences, synchronizing with the stop attribute. 
Args: - model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro') - api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var) - project: Google Cloud project ID (for Vertex AI) - location: Google Cloud location (for Vertex AI, defaults to 'us-central1') - temperature: Sampling temperature (0-2) - top_p: Nucleus sampling parameter - top_k: Top-k sampling parameter - max_output_tokens: Maximum tokens in response - stop_sequences: Stop sequences - stream: Enable streaming responses - safety_settings: Safety filter settings - client_params: Additional parameters to pass to the Google Gen AI Client constructor. - Supports parameters like http_options, credentials, debug_config, etc. - interceptor: HTTP interceptor (not yet supported for Gemini). - **kwargs: Additional parameters + value: Stop sequences as a list, string, or None """ - if interceptor is not None: + self.stop = value + + @model_validator(mode="after") + def setup_client(self) -> Self: + """Initialize the Gemini client and validate configuration.""" + if self.interceptor is not None: raise NotImplementedError( "HTTP interceptors are not yet supported for Google Gemini provider. " "Interceptors are currently supported for OpenAI and Anthropic providers only." ) - super().__init__( - model=model, temperature=temperature, stop=stop_sequences or [], **kwargs - ) + if self.api_key is None: + self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") - # Store client params for later use - self.client_params = client_params or {} + if self.project is None: + self.project = os.getenv("GOOGLE_CLOUD_PROJECT") - # Get API configuration with environment variable fallbacks - self.api_key = ( - api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") - ) - self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT") - self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1" + if self.location == "us-central1": + env_location = os.getenv("GOOGLE_CLOUD_LOCATION") + if env_location: + self.location = env_location use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true" self.client = self._initialize_client(use_vertexai) - # Store completion parameters - self.top_p = top_p - self.top_k = top_k - self.max_output_tokens = max_output_tokens - self.stream = stream - self.safety_settings = safety_settings or {} - self.stop_sequences = stop_sequences or [] - - # Model-specific settings - self.is_gemini_2 = "gemini-2" in model.lower() - self.is_gemini_1_5 = "gemini-1.5" in model.lower() - self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2 - - # @property - # def stop(self) -> list[str]: # type: ignore[misc] - # """Get stop sequences sent to the API.""" - # return self.stop_sequences - - # @stop.setter - # def stop(self, value: list[str] | str | None) -> None: - # """Set stop sequences. - # - # Synchronizes stop_sequences to ensure values set by CrewAgentExecutor - # are properly sent to the Gemini API. 
- # - # Args: - # value: Stop sequences as a list, single string, or None - # """ - # if value is None: - # self.stop_sequences = [] - # elif isinstance(value, str): - # self.stop_sequences = [value] - # elif isinstance(value, list): - # self.stop_sequences = value - # else: - # self.stop_sequences = [] - - def _initialize_client(self, use_vertexai: bool = False) -> genai.Client: # type: ignore[no-any-unimported] + self._is_gemini_2 = "gemini-2" in self.model.lower() + self._is_gemini_1_5 = "gemini-1.5" in self.model.lower() + self._supports_tools = self._is_gemini_1_5 or self._is_gemini_2 + + return self + + @property + def is_gemini_2(self) -> bool: + """Check if model is Gemini 2.""" + return self._is_gemini_2 + + @property + def is_gemini_1_5(self) -> bool: + """Check if model is Gemini 1.5.""" + return self._is_gemini_1_5 + + @property + def supports_tools(self) -> bool: + """Check if model supports tools.""" + return self._supports_tools + + def _initialize_client(self, use_vertexai: bool = False) -> Any: """Initialize the Google Gen AI client with proper parameter handling. Args: @@ -154,12 +170,9 @@ def _initialize_client(self, use_vertexai: bool = False) -> genai.Client: # typ "location": self.location, } ) - client_params.pop("api_key", None) - elif self.api_key: client_params["api_key"] = self.api_key - client_params.pop("vertexai", None) client_params.pop("project", None) client_params.pop("location", None) @@ -188,7 +201,6 @@ def _get_client_params(self) -> dict[str, Any]: and hasattr(self.client, "vertexai") and self.client.vertexai ): - # Vertex AI configuration params.update( { "vertexai": True, @@ -300,15 +312,12 @@ def _prepare_generation_config( # type: ignore[no-any-unimported] self.tools = tools config_params = {} - # Add system instruction if present if system_instruction: - # Convert system instruction to Content format system_content = types.Content( role="user", parts=[types.Part.from_text(text=system_instruction)] ) config_params["system_instruction"] = system_content - # Add generation config parameters if self.temperature is not None: config_params["temperature"] = self.temperature if self.top_p is not None: @@ -317,14 +326,13 @@ def _prepare_generation_config( # type: ignore[no-any-unimported] config_params["top_k"] = self.top_k if self.max_output_tokens is not None: config_params["max_output_tokens"] = self.max_output_tokens - if self.stop_sequences: - config_params["stop_sequences"] = self.stop_sequences + if self.stop: + config_params["stop_sequences"] = self.stop if response_model: config_params["response_mime_type"] = "application/json" config_params["response_schema"] = response_model.model_json_schema() - # Handle tools for supported models if tools and self.supports_tools: config_params["tools"] = self._convert_tools_for_interference(tools) @@ -347,7 +355,6 @@ def _convert_tools_for_interference( # type: ignore[no-any-unimported] description=description, ) - # Add parameters if present - ensure parameters is a dict if parameters and isinstance(parameters, dict): function_declaration.parameters = parameters @@ -383,16 +390,12 @@ def _format_messages_for_gemini( # type: ignore[no-any-unimported] content = message.get("content", "") if role == "system": - # Extract system instruction - Gemini handles it separately if system_instruction: system_instruction += f"\n\n{content}" else: system_instruction = cast(str, content) else: - # Convert role for Gemini (assistant -> model) gemini_role = "model" if role == "assistant" else "user" - - # Create Content 
object gemini_content = types.Content( role=gemini_role, parts=[types.Part.from_text(text=content)] ) @@ -509,13 +512,11 @@ def _handle_streaming_completion( # type: ignore[no-any-unimported] else {}, } - # Handle completed function calls if function_calls and available_functions: for call_data in function_calls.values(): function_name = call_data["name"] function_args = call_data["args"] - # Execute tool result = self._handle_tool_execution( function_name=function_name, function_args=function_args, @@ -575,13 +576,11 @@ def get_context_window_size(self) -> int: "gemma-3-27b": 128000, } - # Find the best match for the model name for model_prefix, size in context_windows.items(): if self.model.startswith(model_prefix): return int(size * CONTEXT_WINDOW_USAGE_RATIO) - # Default context window size for Gemini models - return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens + return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]: """Extract token usage from Gemini response.""" diff --git a/lib/crewai/src/crewai/llm/providers/openai/completion.py b/lib/crewai/src/crewai/llm/providers/openai/completion.py index f9f65c8b13..b9fcc99c7c 100644 --- a/lib/crewai/src/crewai/llm/providers/openai/completion.py +++ b/lib/crewai/src/crewai/llm/providers/openai/completion.py @@ -11,7 +11,8 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import ChoiceDelta -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self from crewai.events.types.llm_events import LLMCallType from crewai.llm.base_llm import BaseLLM @@ -73,26 +74,18 @@ class OpenAICompletion(BaseLLM): ) reasoning_effort: str | None = Field(None, description="Reasoning effort level") - # Internal state client: OpenAI = Field( default_factory=OpenAI, exclude=True, description="OpenAI client instance" ) is_o1_model: bool = Field(False, description="Whether this is an O1 model") is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model") - def model_post_init(self, __context: Any) -> None: - """Initialize OpenAI client after model initialization. 
- - Args: - __context: Pydantic context - """ - super().model_post_init(__context) - - # Set API key from environment if not provided + @model_validator(mode="after") + def setup_client(self) -> Self: + """Initialize OpenAI client after model validation.""" if self.api_key is None: self.api_key = os.getenv("OPENAI_API_KEY") - # Initialize client client_config = self._get_client_params() if self.interceptor: transport = HTTPTransport(interceptor=self.interceptor) @@ -101,10 +94,11 @@ def model_post_init(self, __context: Any) -> None: self.client = OpenAI(**client_config) - # Set model flags self.is_o1_model = "o1" in self.model.lower() self.is_gpt4_model = "gpt-4" in self.model.lower() + return self + def _get_client_params(self) -> dict[str, Any]: """Get OpenAI client parameters.""" diff --git a/lib/crewai/src/crewai/tasks/hallucination_guardrail.py b/lib/crewai/src/crewai/tasks/hallucination_guardrail.py index dd000a83cd..07b4653663 100644 --- a/lib/crewai/src/crewai/tasks/hallucination_guardrail.py +++ b/lib/crewai/src/crewai/tasks/hallucination_guardrail.py @@ -8,7 +8,7 @@ from typing import Any -from crewai.llm import LLM +from crewai.llm.core import LLM from crewai.tasks.task_output import TaskOutput from crewai.utilities.logger import Logger diff --git a/lib/crewai/src/crewai/tools/tool_usage.py b/lib/crewai/src/crewai/tools/tool_usage.py index 6f0e92cb8d..69791291a1 100644 --- a/lib/crewai/src/crewai/tools/tool_usage.py +++ b/lib/crewai/src/crewai/tools/tool_usage.py @@ -36,7 +36,7 @@ from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.tools_handler import ToolsHandler from crewai.lite_agent import LiteAgent - from crewai.llm import LLM + from crewai.llm.core import LLM from crewai.task import Task diff --git a/lib/crewai/tests/test_llm.py b/lib/crewai/tests/test_llm.py index 3d8a1282e2..16175bf0fc 100644 --- a/lib/crewai/tests/test_llm.py +++ b/lib/crewai/tests/test_llm.py @@ -11,7 +11,7 @@ ToolUsageFinishedEvent, ToolUsageStartedEvent, ) -from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM +from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM from crewai.utilities.token_counter_callback import TokenCalcHandler from pydantic import BaseModel import pytest @@ -229,7 +229,7 @@ class DummyResponse(BaseModel): a: int # Patch supports_response_schema to simulate a supported model. - with patch("crewai.llm.supports_response_schema", return_value=True): + with patch("crewai.llm.core.supports_response_schema", return_value=True): llm = LLM( model="openrouter/deepseek/deepseek-chat", response_format=DummyResponse ) @@ -242,7 +242,7 @@ class DummyResponse(BaseModel): a: int # Patch supports_response_schema to simulate an unsupported model. 
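The test updates in this hunk follow the module move: patch targets such as crewai.llm.supports_response_schema become crewai.llm.core.supports_response_schema because unittest.mock.patch must name the module where the attribute is looked up, not where it was defined. A tiny, self-contained illustration of that rule; dumps_config and the json stand-in are hypothetical, not crewai code:

from unittest.mock import patch

import json  # stands in for any module-level import a function resolves at call time


def dumps_config(cfg: dict) -> str:
    # Looks up `json` in this module's namespace each time it is called.
    return json.dumps(cfg)


# patch() names the module whose attribute is replaced; after a file move,
# that dotted path changes even though the function body is untouched.
with patch(f"{__name__}.json") as fake_json:
    fake_json.dumps.return_value = "patched"
    assert dumps_config({"model": "gpt-4"}) == "patched"

assert dumps_config({"model": "gpt-4"}) == '{"model": "gpt-4"}'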
- with patch("crewai.llm.supports_response_schema", return_value=False): + with patch("crewai.llm.core.supports_response_schema", return_value=False): llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True) with pytest.raises(ValueError) as excinfo: llm._validate_call_params() @@ -342,7 +342,7 @@ def test_context_window_validation(): # Test invalid window size with pytest.raises(ValueError) as excinfo: with patch.dict( - "crewai.llm.LLM_CONTEXT_WINDOW_SIZES", + "crewai.llm.core.LLM_CONTEXT_WINDOW_SIZES", {"test-model": 500}, # Below minimum clear=True, ): @@ -702,8 +702,8 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm): def test_native_provider_raises_error_when_supported_but_fails(): """Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error.""" - with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]): - with patch("crewai.llm.LLM._get_native_provider") as mock_get_native: + with patch("crewai.llm.internal.meta.SUPPORTED_NATIVE_PROVIDERS", ["openai"]): + with patch("crewai.llm.internal.meta.LLMMeta._get_native_provider") as mock_get_native: # Mock that provider exists but throws an error when instantiated mock_provider = MagicMock() mock_provider.side_effect = ValueError("Native provider initialization failed") @@ -718,7 +718,7 @@ def test_native_provider_raises_error_when_supported_but_fails(): def test_native_provider_falls_back_to_litellm_when_not_in_supported_list(): """Test that when a provider is not in SUPPORTED_NATIVE_PROVIDERS, we fall back to LiteLLM.""" - with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai", "anthropic"]): + with patch("crewai.llm.internal.meta.SUPPORTED_NATIVE_PROVIDERS", ["openai", "anthropic"]): # Using a provider not in the supported list llm = LLM(model="groq/llama-3.1-70b-versatile", is_litellm=False) diff --git a/logs.txt b/logs.txt deleted file mode 100644 index 6aaa37c3a2..0000000000 --- a/logs.txt +++ /dev/null @@ -1,20 +0,0 @@ -lib/crewai/src/crewai/agent/core.py:901: error: Argument 1 has incompatible type "ToolFilterContext"; expected "dict[str, Any]" [arg-type] -lib/crewai/src/crewai/agent/core.py:901: note: Error code "arg-type" not covered by "type: ignore" comment -lib/crewai/src/crewai/agent/core.py:905: error: Argument 1 has incompatible type "dict[str, Any]"; expected "ToolFilterContext" [arg-type] -lib/crewai/src/crewai/agent/core.py:905: note: Error code "arg-type" not covered by "type: ignore" comment -lib/crewai/src/crewai/agent/core.py:996: error: Returning Any from function declared to return "dict[str, dict[str, Any]]" [no-any-return] -lib/crewai/src/crewai/agent/core.py:1157: error: Incompatible types in assignment (expression has type "tuple[UnionType, None]", target has type "tuple[type, Any]") [assignment] -lib/crewai/src/crewai/agent/core.py:1183: error: Argument 1 to "append" of "list" has incompatible type "type"; expected "type[str]" [arg-type] -lib/crewai/src/crewai/agent/core.py:1188: error: Incompatible types in assignment (expression has type "UnionType", variable has type "type[str]") [assignment] -lib/crewai/src/crewai/agent/core.py:1201: error: Argument 1 to "get" of "dict" has incompatible type "Any | None"; expected "str" [arg-type] -Found 7 errors in 1 file (checked 4 source files) -Success: no issues found in 4 source files -lib/crewai/src/crewai/llm/providers/gemini/completion.py:111: error: BaseModel field may only be overridden by another field [misc] -Found 1 error in 1 file (checked 4 
source files) -Success: no issues found in 4 source files -lib/crewai/src/crewai/llm/providers/anthropic/completion.py:101: error: BaseModel field may only be overridden by another field [misc] -Found 1 error in 1 file (checked 4 source files) -lib/crewai/src/crewai/llm/providers/bedrock/completion.py:250: error: BaseModel field may only be overridden by another field [misc] -Found 1 error in 1 file (checked 4 source files) - -uv-lock..............................................(no files to check)Skipped From 722d316824987bcc05908d54be6ec862c009c47d Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Mon, 10 Nov 2025 16:05:23 -0500 Subject: [PATCH 3/7] chore: continue refactoring llms to base models --- lib/crewai/src/crewai/__init__.py | 2 +- lib/crewai/src/crewai/cli/crew_chat.py | 3 +- lib/crewai/src/crewai/crew.py | 2 +- .../src/crewai/events/event_listener.py | 2 +- .../experimental/evaluation/base_evaluator.py | 2 +- lib/crewai/src/crewai/lite_agent.py | 2 +- lib/crewai/src/crewai/llm/base_llm.py | 16 +- lib/crewai/src/crewai/llm/internal/meta.py | 34 ++- .../llm/providers/anthropic/completion.py | 145 +++++------- .../crewai/llm/providers/azure/completion.py | 152 +++++++------ .../llm/providers/bedrock/completion.py | 210 ++++++++--------- .../crewai/llm/providers/gemini/completion.py | 211 +++++++++--------- .../crewai/llm/providers/openai/completion.py | 20 +- .../crewai/tasks/hallucination_guardrail.py | 2 +- lib/crewai/src/crewai/tools/tool_usage.py | 2 +- lib/crewai/tests/test_llm.py | 14 +- logs.txt | 20 -- 17 files changed, 395 insertions(+), 444 deletions(-) delete mode 100644 logs.txt diff --git a/lib/crewai/src/crewai/__init__.py b/lib/crewai/src/crewai/__init__.py index ef2bcf78d2..4e2365a2ff 100644 --- a/lib/crewai/src/crewai/__init__.py +++ b/lib/crewai/src/crewai/__init__.py @@ -8,8 +8,8 @@ from crewai.crews.crew_output import CrewOutput from crewai.flow.flow import Flow from crewai.knowledge.knowledge import Knowledge -from crewai.llm import LLM from crewai.llm.base_llm import BaseLLM +from crewai.llm.core import LLM from crewai.process import Process from crewai.task import Task from crewai.tasks.llm_guardrail import LLMGuardrail diff --git a/lib/crewai/src/crewai/cli/crew_chat.py b/lib/crewai/src/crewai/cli/crew_chat.py index feca9e4ca3..f593180864 100644 --- a/lib/crewai/src/crewai/cli/crew_chat.py +++ b/lib/crewai/src/crewai/cli/crew_chat.py @@ -14,7 +14,8 @@ from crewai.cli.utils import read_toml from crewai.cli.version import get_crewai_version from crewai.crew import Crew -from crewai.llm import LLM, BaseLLM +from crewai.llm import LLM +from crewai.llm.base_llm import BaseLLM from crewai.types.crew_chat import ChatInputField, ChatInputs from crewai.utilities.llm_utils import create_llm from crewai.utilities.printer import Printer diff --git a/lib/crewai/src/crewai/crew.py b/lib/crewai/src/crewai/crew.py index 31eb7466ca..b258f3eaa6 100644 --- a/lib/crewai/src/crewai/crew.py +++ b/lib/crewai/src/crewai/crew.py @@ -56,8 +56,8 @@ from crewai.flow.flow_trackable import FlowTrackable from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource -from crewai.llm import LLM from crewai.llm.base_llm import BaseLLM +from crewai.llm.core import LLM from crewai.memory.entity.entity_memory import EntityMemory from crewai.memory.external.external_memory import ExternalMemory from crewai.memory.long_term.long_term_memory import LongTermMemory diff --git a/lib/crewai/src/crewai/events/event_listener.py 
b/lib/crewai/src/crewai/events/event_listener.py index e07ee193ce..0a88e85f03 100644 --- a/lib/crewai/src/crewai/events/event_listener.py +++ b/lib/crewai/src/crewai/events/event_listener.py @@ -89,7 +89,7 @@ ToolUsageStartedEvent, ) from crewai.events.utils.console_formatter import ConsoleFormatter -from crewai.llm import LLM +from crewai.llm.core import LLM from crewai.task import Task from crewai.telemetry.telemetry import Telemetry from crewai.utilities import Logger diff --git a/lib/crewai/src/crewai/experimental/evaluation/base_evaluator.py b/lib/crewai/src/crewai/experimental/evaluation/base_evaluator.py index 69d1bb5c3b..8001074d33 100644 --- a/lib/crewai/src/crewai/experimental/evaluation/base_evaluator.py +++ b/lib/crewai/src/crewai/experimental/evaluation/base_evaluator.py @@ -7,7 +7,7 @@ from crewai.agent import Agent from crewai.agents.agent_builder.base_agent import BaseAgent -from crewai.llm import BaseLLM +from crewai.llm.base_llm import BaseLLM from crewai.task import Task from crewai.utilities.llm_utils import create_llm diff --git a/lib/crewai/src/crewai/lite_agent.py b/lib/crewai/src/crewai/lite_agent.py index ef877e01b6..e91e6b98f4 100644 --- a/lib/crewai/src/crewai/lite_agent.py +++ b/lib/crewai/src/crewai/lite_agent.py @@ -39,8 +39,8 @@ from crewai.events.types.logging_events import AgentLogsExecutionEvent from crewai.flow.flow_trackable import FlowTrackable from crewai.lite_agent_output import LiteAgentOutput -from crewai.llm import LLM from crewai.llm.base_llm import BaseLLM +from crewai.llm.core import LLM from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool from crewai.utilities.agent_utils import ( diff --git a/lib/crewai/src/crewai/llm/base_llm.py b/lib/crewai/src/crewai/llm/base_llm.py index 1222763457..f60ce500e6 100644 --- a/lib/crewai/src/crewai/llm/base_llm.py +++ b/lib/crewai/src/crewai/llm/base_llm.py @@ -66,7 +66,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): """ model_config: ClassVar[ConfigDict] = ConfigDict( - arbitrary_types_allowed=True, extra="allow", validate_assignment=True + arbitrary_types_allowed=True, extra="allow" ) # Core fields @@ -80,7 +80,9 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): default="openai", description="Provider name (openai, anthropic, etc.)" ) stop: list[str] = Field( - default_factory=list, description="Stop sequences for generation" + default_factory=list, + description="Stop sequences for generation", + validation_alias="stop_sequences", ) # Internal fields @@ -112,16 +114,18 @@ def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]: if not values.get("model"): raise ValueError("Model name is required and cannot be empty") - # Handle stop sequences - stop = values.get("stop") + stop = values.get("stop") or values.get("stop_sequences") if stop is None: values["stop"] = [] elif isinstance(stop, str): values["stop"] = [stop] - elif not isinstance(stop, list): + elif isinstance(stop, list): + values["stop"] = stop + else: values["stop"] = [] - # Set default provider if not specified + values.pop("stop_sequences", None) + if "provider" not in values or values["provider"] is None: values["provider"] = "openai" diff --git a/lib/crewai/src/crewai/llm/internal/meta.py b/lib/crewai/src/crewai/llm/internal/meta.py index 4bf83b655a..97977ad558 100644 --- a/lib/crewai/src/crewai/llm/internal/meta.py +++ b/lib/crewai/src/crewai/llm/internal/meta.py @@ -33,13 +33,12 @@ class LLMMeta(ModelMetaclass): native provider implementation based on the 
model parameter. """ - def __call__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> Any: # noqa: N805 + def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805 """Route to appropriate provider implementation at instantiation time. Args: - model: The model identifier (e.g., "gpt-4", "claude-3-opus") - is_litellm: Force use of LiteLLM instead of native provider - **kwargs: Additional parameters for the LLM + *args: Positional arguments (model should be first for LLM class) + **kwargs: Keyword arguments including model, is_litellm, etc. Returns: Instance of the appropriate provider class or LLM class @@ -47,18 +46,18 @@ def __call__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> Any: Raises: ValueError: If model is not a valid string """ - if not model or not isinstance(model, str): - raise ValueError("Model must be a non-empty string") + if cls.__name__ != "LLM": + return super().__call__(*args, **kwargs) - # Only perform routing if called on the base LLM class - # Subclasses (OpenAICompletion, etc.) should create normally - from crewai.llm import LLM + model = kwargs.get("model") or (args[0] if args else None) + is_litellm = kwargs.get("is_litellm", False) - if cls is not LLM: - # Direct instantiation of provider class, skip routing - return super().__call__(model=model, **kwargs) + if not model or not isinstance(model, str): + raise ValueError("Model must be a non-empty string") - # Extract provider information + if args and not kwargs.get("model"): + kwargs["model"] = args[0] + args = args[1:] explicit_provider = kwargs.get("provider") if explicit_provider: @@ -97,12 +96,10 @@ def __call__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> Any: use_native = True model_string = model - # Route to native provider if available native_class = cls._get_native_provider(provider) if use_native else None if native_class and not is_litellm and provider in SUPPORTED_NATIVE_PROVIDERS: try: - # Remove 'provider' from kwargs to avoid duplicate keyword argument - kwargs_copy = {k: v for k, v in kwargs.items() if k != "provider"} + kwargs_copy = {k: v for k, v in kwargs.items() if k not in ("provider", "model")} return native_class( model=model_string, provider=provider, **kwargs_copy ) @@ -111,15 +108,14 @@ def __call__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> Any: except Exception as e: raise ImportError(f"Error importing native provider: {e}") from e - # Fallback to LiteLLM try: import litellm # noqa: F401 except ImportError: logging.error("LiteLLM is not available, falling back to LiteLLM") raise ImportError("Fallback to LiteLLM is not available") from None - # Create actual LLM instance with is_litellm=True - return super().__call__(model=model, is_litellm=True, **kwargs) + kwargs_copy = {k: v for k, v in kwargs.items() if k not in ("model", "is_litellm")} + return super().__call__(model=model, is_litellm=True, **kwargs_copy) @staticmethod def _validate_model_in_constants(model: str, provider: str) -> bool: diff --git a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py index 8ebba16737..dd13c0f5e1 100644 --- a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py @@ -3,9 +3,10 @@ import json import logging import os -from typing import TYPE_CHECKING, Any, ClassVar, cast +from typing import Any, ClassVar, cast -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, 
ConfigDict, Field, PrivateAttr, model_validator +from typing_extensions import Self from crewai.events.types.llm_events import LLMCallType from crewai.llm.base_llm import BaseLLM @@ -19,9 +20,6 @@ from crewai.utilities.types import LLMMessage -if TYPE_CHECKING: - from crewai.llm.hooks.base import BaseInterceptor - try: from anthropic import Anthropic from anthropic.types import Message @@ -38,90 +36,67 @@ class AnthropicCompletion(BaseLLM): This class provides direct integration with the Anthropic Python SDK, offering native tool use, streaming support, and proper message formatting. + + Attributes: + model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022') + base_url: Custom base URL for Anthropic API + timeout: Request timeout in seconds + max_retries: Maximum number of retries + max_tokens: Maximum tokens in response (required for Anthropic) + top_p: Nucleus sampling parameter + stream: Enable streaming responses + client_params: Additional parameters for the Anthropic client + interceptor: HTTP interceptor for modifying requests/responses at transport level """ - model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,)) + model_config: ClassVar[ConfigDict] = ConfigDict( + ignored_types=(property,), arbitrary_types_allowed=True + ) + + base_url: str | None = Field( + default=None, description="Custom base URL for Anthropic API" + ) + timeout: float | None = Field( + default=None, description="Request timeout in seconds" + ) + max_retries: int = Field(default=2, description="Maximum number of retries") + max_tokens: int = Field( + default=4096, description="Maximum tokens in response (required for Anthropic)" + ) + top_p: float | None = Field(default=None, description="Nucleus sampling parameter") + stream: bool = Field(default=False, description="Enable streaming responses") + client_params: dict[str, Any] | None = Field( + default=None, description="Additional Anthropic client parameters" + ) + interceptor: Any = Field( + default=None, description="HTTP interceptor for request/response modification" + ) + client: Any = Field( + default=None, exclude=True, description="Anthropic client instance" + ) + + _is_claude_3: bool = PrivateAttr(default=False) + _supports_tools: bool = PrivateAttr(default=False) + + @model_validator(mode="after") + def setup_client(self) -> Self: + """Initialize the Anthropic client and model-specific settings.""" + self.client = Anthropic(**self._get_client_params()) - def __init__( - self, - model: str = "claude-3-5-sonnet-20241022", - api_key: str | None = None, - base_url: str | None = None, - timeout: float | None = None, - max_retries: int = 2, - temperature: float | None = None, - max_tokens: int = 4096, # Required for Anthropic - top_p: float | None = None, - stop_sequences: list[str] | None = None, - stream: bool = False, - client_params: dict[str, Any] | None = None, - interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None, - **kwargs: Any, - ): - """Initialize Anthropic chat completion client. 
+ self._is_claude_3 = "claude-3" in self.model.lower() + self._supports_tools = self._is_claude_3 - Args: - model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022') - api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var) - base_url: Custom base URL for Anthropic API - timeout: Request timeout in seconds - max_retries: Maximum number of retries - temperature: Sampling temperature (0-1) - max_tokens: Maximum tokens in response (required for Anthropic) - top_p: Nucleus sampling parameter - stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop) - stream: Enable streaming responses - client_params: Additional parameters for the Anthropic client - interceptor: HTTP interceptor for modifying requests/responses at transport level. - **kwargs: Additional parameters - """ - super().__init__( - model=model, temperature=temperature, stop=stop_sequences or [], **kwargs - ) + return self - # Client params - self.interceptor = interceptor - self.client_params = client_params - self.base_url = base_url - self.timeout = timeout - self.max_retries = max_retries - - self.client = Anthropic(**self._get_client_params()) + @property + def is_claude_3(self) -> bool: + """Check if model is Claude 3.""" + return self._is_claude_3 - # Store completion parameters - self.max_tokens = max_tokens - self.top_p = top_p - self.stream = stream - self.stop_sequences = stop_sequences or [] - - # Model-specific settings - self.is_claude_3 = "claude-3" in model.lower() - self.supports_tools = self.is_claude_3 # Claude 3+ supports tool use - - # - # @property - # def stop(self) -> list[str]: # type: ignore[misc] - # """Get stop sequences sent to the API.""" - # return self.stop_sequences - - # @stop.setter - # def stop(self, value: list[str] | str | None) -> None: - # """Set stop sequences. - # - # Synchronizes stop_sequences to ensure values set by CrewAgentExecutor - # are properly sent to the Anthropic API. 
- # - # Args: - # value: Stop sequences as a list, single string, or None - # """ - # if value is None: - # self.stop_sequences = [] - # elif isinstance(value, str): - # self.stop_sequences = [value] - # elif isinstance(value, list): - # self.stop_sequences = value - # else: - # self.stop_sequences = [] + @property + def supports_tools(self) -> bool: + """Check if model supports tools.""" + return self._supports_tools def _get_client_params(self) -> dict[str, Any]: """Get client parameters.""" @@ -250,8 +225,8 @@ def _prepare_completion_params( params["temperature"] = self.temperature if self.top_p is not None: params["top_p"] = self.top_p - if self.stop_sequences: - params["stop_sequences"] = self.stop_sequences + if self.stop: + params["stop_sequences"] = self.stop # Handle tools for Claude 3+ if tools and self.supports_tools: diff --git a/lib/crewai/src/crewai/llm/providers/azure/completion.py b/lib/crewai/src/crewai/llm/providers/azure/completion.py index a389c18255..6076b9f4d1 100644 --- a/lib/crewai/src/crewai/llm/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llm/providers/azure/completion.py @@ -3,9 +3,10 @@ import json import logging import os -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from typing_extensions import Self from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES from crewai.llm.providers.utils.common import safe_tool_conversion @@ -17,7 +18,6 @@ if TYPE_CHECKING: - from crewai.llm.hooks.base import BaseInterceptor from crewai.tools.base_tool import BaseTool @@ -51,65 +51,77 @@ class AzureCompletion(BaseLLM): This class provides direct integration with the Azure AI Inference Python SDK, offering native function calling, streaming support, and proper Azure authentication. + + Attributes: + model: Azure deployment name or model name + endpoint: Azure endpoint URL + api_version: Azure API version + timeout: Request timeout in seconds + max_retries: Maximum number of retries + top_p: Nucleus sampling parameter + frequency_penalty: Frequency penalty (-2 to 2) + presence_penalty: Presence penalty (-2 to 2) + max_tokens: Maximum tokens in response + stream: Enable streaming responses + interceptor: HTTP interceptor (not yet supported for Azure) """ - def __init__( - self, - model: str, - api_key: str | None = None, - endpoint: str | None = None, - api_version: str | None = None, - timeout: float | None = None, - max_retries: int = 2, - temperature: float | None = None, - top_p: float | None = None, - frequency_penalty: float | None = None, - presence_penalty: float | None = None, - max_tokens: int | None = None, - stop: list[str] | None = None, - stream: bool = False, - interceptor: BaseInterceptor[Any, Any] | None = None, - **kwargs: Any, - ): - """Initialize Azure AI Inference chat completion client. 
+ model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) - Args: - model: Azure deployment name or model name - api_key: Azure API key (defaults to AZURE_API_KEY env var) - endpoint: Azure endpoint URL (defaults to AZURE_ENDPOINT env var) - api_version: Azure API version (defaults to AZURE_API_VERSION env var) - timeout: Request timeout in seconds - max_retries: Maximum number of retries - temperature: Sampling temperature (0-2) - top_p: Nucleus sampling parameter - frequency_penalty: Frequency penalty (-2 to 2) - presence_penalty: Presence penalty (-2 to 2) - max_tokens: Maximum tokens in response - stop: Stop sequences - stream: Enable streaming responses - interceptor: HTTP interceptor (not yet supported for Azure). - **kwargs: Additional parameters - """ - if interceptor is not None: + endpoint: str | None = Field( + default=None, + description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)", + ) + api_version: str = Field( + default="2024-06-01", + description="Azure API version (defaults to AZURE_API_VERSION env var or 2024-06-01)", + ) + timeout: float | None = Field( + default=None, description="Request timeout in seconds" + ) + max_retries: int = Field(default=2, description="Maximum number of retries") + top_p: float | None = Field(default=None, description="Nucleus sampling parameter") + frequency_penalty: float | None = Field( + default=None, description="Frequency penalty (-2 to 2)" + ) + presence_penalty: float | None = Field( + default=None, description="Presence penalty (-2 to 2)" + ) + max_tokens: int | None = Field( + default=None, description="Maximum tokens in response" + ) + stream: bool = Field(default=False, description="Enable streaming responses") + interceptor: Any = Field( + default=None, description="HTTP interceptor (not yet supported for Azure)" + ) + client: Any = Field(default=None, exclude=True, description="Azure client instance") + + _is_openai_model: bool = PrivateAttr(default=False) + _is_azure_openai_endpoint: bool = PrivateAttr(default=False) + + @model_validator(mode="after") + def setup_client(self) -> Self: + """Initialize the Azure client and validate configuration.""" + if self.interceptor is not None: raise NotImplementedError( "HTTP interceptors are not yet supported for Azure AI Inference provider. " "Interceptors are currently supported for OpenAI and Anthropic providers only." ) - super().__init__( - model=model, temperature=temperature, stop=stop or [], **kwargs - ) + if self.api_key is None: + self.api_key = os.getenv("AZURE_API_KEY") - self.api_key = api_key or os.getenv("AZURE_API_KEY") - self.endpoint = ( - endpoint - or os.getenv("AZURE_ENDPOINT") - or os.getenv("AZURE_OPENAI_ENDPOINT") - or os.getenv("AZURE_API_BASE") - ) - self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01" - self.timeout = timeout - self.max_retries = max_retries + if self.endpoint is None: + self.endpoint = ( + os.getenv("AZURE_ENDPOINT") + or os.getenv("AZURE_OPENAI_ENDPOINT") + or os.getenv("AZURE_API_BASE") + ) + + if self.api_version == "2024-06-01": + env_version = os.getenv("AZURE_API_VERSION") + if env_version: + self.api_version = env_version if not self.api_key: raise ValueError( @@ -120,36 +132,38 @@ def __init__( "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter." 
) - # Validate and potentially fix Azure OpenAI endpoint URL - self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model) + self.endpoint = self._validate_and_fix_endpoint(self.endpoint, self.model) - # Build client kwargs - client_kwargs = { + client_kwargs: dict[str, Any] = { "endpoint": self.endpoint, "credential": AzureKeyCredential(self.api_key), } - # Add api_version if specified (primarily for Azure OpenAI endpoints) if self.api_version: client_kwargs["api_version"] = self.api_version - self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type] + self.client = ChatCompletionsClient(**client_kwargs) - self.top_p = top_p - self.frequency_penalty = frequency_penalty - self.presence_penalty = presence_penalty - self.max_tokens = max_tokens - self.stream = stream - - self.is_openai_model = any( - prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"] + self._is_openai_model = any( + prefix in self.model.lower() for prefix in ["gpt-", "o1-", "text-"] ) - - self.is_azure_openai_endpoint = ( + self._is_azure_openai_endpoint = ( "openai.azure.com" in self.endpoint and "/openai/deployments/" in self.endpoint ) + return self + + @property + def is_openai_model(self) -> bool: + """Check if model is an OpenAI model.""" + return self._is_openai_model + + @property + def is_azure_openai_endpoint(self) -> bool: + """Check if endpoint is an Azure OpenAI endpoint.""" + return self._is_azure_openai_endpoint + def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str: """Validate and fix Azure endpoint URL format. diff --git a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py index f67414c639..ed6738c4b1 100644 --- a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py @@ -5,8 +5,8 @@ import os from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast -from pydantic import BaseModel, ConfigDict -from typing_extensions import Required +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from typing_extensions import Required, Self from crewai.events.types.llm_events import LLMCallType from crewai.llm.base_llm import BaseLLM @@ -32,8 +32,6 @@ ToolTypeDef, ) - from crewai.llm.hooks.base import BaseInterceptor - try: from boto3.session import Session @@ -143,76 +141,86 @@ class BedrockCompletion(BaseLLM): - Complete streaming event handling (messageStart, contentBlockStart, etc.) 
- Response metadata and trace information capture - Model-specific conversation format handling (e.g., Cohere requirements) + + Attributes: + model: The Bedrock model ID to use + aws_access_key_id: AWS access key (defaults to environment variable) + aws_secret_access_key: AWS secret key (defaults to environment variable) + aws_session_token: AWS session token for temporary credentials + region_name: AWS region name + max_tokens: Maximum tokens to generate + top_p: Nucleus sampling parameter + top_k: Top-k sampling parameter (Claude models only) + stop_sequences: List of sequences that stop generation + stream: Whether to use streaming responses + guardrail_config: Guardrail configuration for content filtering + additional_model_request_fields: Model-specific request parameters + additional_model_response_field_paths: Custom response field paths + interceptor: HTTP interceptor (not yet supported for Bedrock) """ - model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,)) + model_config: ClassVar[ConfigDict] = ConfigDict( + ignored_types=(property,), arbitrary_types_allowed=True + ) - def __init__( - self, - model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0", - aws_access_key_id: str | None = None, - aws_secret_access_key: str | None = None, - aws_session_token: str | None = None, - region_name: str = "us-east-1", - temperature: float | None = None, - max_tokens: int | None = None, - top_p: float | None = None, - top_k: int | None = None, - stop_sequences: Sequence[str] | None = None, - stream: bool = False, - guardrail_config: dict[str, Any] | None = None, - additional_model_request_fields: dict[str, Any] | None = None, - additional_model_response_field_paths: list[str] | None = None, - interceptor: BaseInterceptor[Any, Any] | None = None, - **kwargs: Any, - ) -> None: - """Initialize AWS Bedrock completion client. - - Args: - model: The Bedrock model ID to use - aws_access_key_id: AWS access key (defaults to environment variable) - aws_secret_access_key: AWS secret key (defaults to environment variable) - aws_session_token: AWS session token for temporary credentials - region_name: AWS region name - temperature: Sampling temperature for response generation - max_tokens: Maximum tokens to generate - top_p: Nucleus sampling parameter - top_k: Top-k sampling parameter (Claude models only) - stop_sequences: List of sequences that stop generation - stream: Whether to use streaming responses - guardrail_config: Guardrail configuration for content filtering - additional_model_request_fields: Model-specific request parameters - additional_model_response_field_paths: Custom response field paths - interceptor: HTTP interceptor (not yet supported for Bedrock). 
- **kwargs: Additional parameters - """ - if interceptor is not None: + aws_access_key_id: str | None = Field( + default=None, description="AWS access key (defaults to environment variable)" + ) + aws_secret_access_key: str | None = Field( + default=None, description="AWS secret key (defaults to environment variable)" + ) + aws_session_token: str | None = Field( + default=None, description="AWS session token for temporary credentials" + ) + region_name: str = Field(default="us-east-1", description="AWS region name") + max_tokens: int | None = Field( + default=None, description="Maximum tokens to generate" + ) + top_p: float | None = Field(default=None, description="Nucleus sampling parameter") + top_k: int | None = Field( + default=None, description="Top-k sampling parameter (Claude models only)" + ) + stream: bool = Field( + default=False, description="Whether to use streaming responses" + ) + guardrail_config: dict[str, Any] | None = Field( + default=None, description="Guardrail configuration for content filtering" + ) + additional_model_request_fields: dict[str, Any] | None = Field( + default=None, description="Model-specific request parameters" + ) + additional_model_response_field_paths: list[str] | None = Field( + default=None, description="Custom response field paths" + ) + interceptor: Any = Field( + default=None, description="HTTP interceptor (not yet supported for Bedrock)" + ) + client: Any = Field( + default=None, exclude=True, description="Bedrock client instance" + ) + + _is_claude_model: bool = PrivateAttr(default=False) + _supports_tools: bool = PrivateAttr(default=True) + _supports_streaming: bool = PrivateAttr(default=True) + _model_id: str = PrivateAttr() + + @model_validator(mode="after") + def setup_client(self) -> Self: + """Initialize the Bedrock client and validate configuration.""" + if self.interceptor is not None: raise NotImplementedError( "HTTP interceptors are not yet supported for AWS Bedrock provider. " "Interceptors are currently supported for OpenAI and Anthropic providers only." 
) - # Extract provider from kwargs to avoid duplicate argument - kwargs.pop("provider", None) - - super().__init__( - model=model, - temperature=temperature, - stop=stop_sequences or [], - provider="bedrock", - **kwargs, - ) - - # Initialize Bedrock client with proper configuration session = Session( - aws_access_key_id=aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"), - aws_secret_access_key=aws_secret_access_key + aws_access_key_id=self.aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"), + aws_secret_access_key=self.aws_secret_access_key or os.getenv("AWS_SECRET_ACCESS_KEY"), - aws_session_token=aws_session_token or os.getenv("AWS_SESSION_TOKEN"), - region_name=region_name, + aws_session_token=self.aws_session_token or os.getenv("AWS_SESSION_TOKEN"), + region_name=self.region_name, ) - # Configure client with timeouts and retries following AWS best practices config = Config( read_timeout=300, retries={ @@ -223,53 +231,33 @@ def __init__( ) self.client = session.client("bedrock-runtime", config=config) - self.region_name = region_name - - # Store completion parameters - self.max_tokens = max_tokens - self.top_p = top_p - self.top_k = top_k - self.stream = stream - self.stop_sequences = stop_sequences or [] - - # Store advanced features (optional) - self.guardrail_config = guardrail_config - self.additional_model_request_fields = additional_model_request_fields - self.additional_model_response_field_paths = ( - additional_model_response_field_paths - ) - # Model-specific settings - self.is_claude_model = "claude" in model.lower() - self.supports_tools = True # Converse API supports tools for most models - self.supports_streaming = True - - # Handle inference profiles for newer models - self.model_id = model - - # @property - # def stop(self) -> list[str]: # type: ignore[misc] - # """Get stop sequences sent to the API.""" - # return list(self.stop_sequences) - - # @stop.setter - # def stop(self, value: Sequence[str] | str | None) -> None: - # """Set stop sequences. - # - # Synchronizes stop_sequences to ensure values set by CrewAgentExecutor - # are properly sent to the Bedrock API. 
- # - # Args: - # value: Stop sequences as a Sequence, single string, or None - # """ - # if value is None: - # self.stop_sequences = [] - # elif isinstance(value, str): - # self.stop_sequences = [value] - # elif isinstance(value, Sequence): - # self.stop_sequences = list(value) - # else: - # self.stop_sequences = [] + self._is_claude_model = "claude" in self.model.lower() + self._supports_tools = True + self._supports_streaming = True + self._model_id = self.model + + return self + + @property + def is_claude_model(self) -> bool: + """Check if model is a Claude model.""" + return self._is_claude_model + + @property + def supports_tools(self) -> bool: + """Check if model supports tools.""" + return self._supports_tools + + @property + def supports_streaming(self) -> bool: + """Check if model supports streaming.""" + return self._supports_streaming + + @property + def model_id(self) -> str: + """Get the model ID.""" + return self._model_id def call( self, @@ -559,7 +547,7 @@ def _handle_streaming_converse( "Sequence[MessageTypeDef | MessageOutputTypeDef]", cast(object, messages), ), - **body, # type: ignore[arg-type] + **body, ) stream = response.get("stream") @@ -821,8 +809,8 @@ def _get_inference_config(self) -> EnhancedInferenceConfigurationTypeDef: config["temperature"] = float(self.temperature) if self.top_p is not None: config["topP"] = float(self.top_p) - if self.stop_sequences: - config["stopSequences"] = self.stop_sequences + if self.stop: + config["stopSequences"] = self.stop if self.is_claude_model and self.top_k is not None: # top_k is supported by Claude models diff --git a/lib/crewai/src/crewai/llm/providers/gemini/completion.py b/lib/crewai/src/crewai/llm/providers/gemini/completion.py index 263309910f..34bff25083 100644 --- a/lib/crewai/src/crewai/llm/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llm/providers/gemini/completion.py @@ -2,12 +2,12 @@ import os from typing import Any, ClassVar, cast -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from typing_extensions import Self from crewai.events.types.llm_events import LLMCallType from crewai.llm.base_llm import BaseLLM from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES -from crewai.llm.hooks.base import BaseInterceptor from crewai.llm.providers.utils.common import safe_tool_conversion from crewai.utilities.agent_utils import is_context_length_exceeded from crewai.utilities.exceptions.context_window_exceeding_exception import ( @@ -31,108 +31,124 @@ class GeminiCompletion(BaseLLM): This class provides direct integration with the Google Gen AI Python SDK, offering native function calling, streaming support, and proper Gemini formatting. 
+ + Attributes: + model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro') + project: Google Cloud project ID (for Vertex AI) + location: Google Cloud location (for Vertex AI, defaults to 'us-central1') + top_p: Nucleus sampling parameter + top_k: Top-k sampling parameter + max_output_tokens: Maximum tokens in response + stop_sequences: Stop sequences + stream: Enable streaming responses + safety_settings: Safety filter settings + client_params: Additional parameters for Google Gen AI Client constructor + interceptor: HTTP interceptor (not yet supported for Gemini) """ - model_config: ClassVar[ConfigDict] = ConfigDict(ignored_types=(property,)) + model_config: ClassVar[ConfigDict] = ConfigDict( + ignored_types=(property,), arbitrary_types_allowed=True + ) + + project: str | None = Field( + default=None, description="Google Cloud project ID (for Vertex AI)" + ) + location: str = Field( + default="us-central1", + description="Google Cloud location (for Vertex AI, defaults to 'us-central1')", + ) + top_p: float | None = Field(default=None, description="Nucleus sampling parameter") + top_k: int | None = Field(default=None, description="Top-k sampling parameter") + max_output_tokens: int | None = Field( + default=None, description="Maximum tokens in response" + ) + stream: bool = Field(default=False, description="Enable streaming responses") + safety_settings: dict[str, Any] = Field( + default_factory=dict, description="Safety filter settings" + ) + client_params: dict[str, Any] = Field( + default_factory=dict, + description="Additional parameters for Google Gen AI Client constructor", + ) + interceptor: Any = Field( + default=None, description="HTTP interceptor (not yet supported for Gemini)" + ) + client: Any = Field( + default=None, exclude=True, description="Gemini client instance" + ) + + _is_gemini_2: bool = PrivateAttr(default=False) + _is_gemini_1_5: bool = PrivateAttr(default=False) + _supports_tools: bool = PrivateAttr(default=False) + + @property + def stop_sequences(self) -> list[str]: + """Get stop sequences as a list. + + This property provides access to stop sequences in Gemini's native format + while maintaining synchronization with the base class's stop attribute. + """ + if self.stop is None: + return [] + if isinstance(self.stop, str): + return [self.stop] + return self.stop - def __init__( - self, - model: str = "gemini-2.0-flash-001", - api_key: str | None = None, - project: str | None = None, - location: str | None = None, - temperature: float | None = None, - top_p: float | None = None, - top_k: int | None = None, - max_output_tokens: int | None = None, - stop_sequences: list[str] | None = None, - stream: bool = False, - safety_settings: dict[str, Any] | None = None, - client_params: dict[str, Any] | None = None, - interceptor: BaseInterceptor[Any, Any] | None = None, - **kwargs: Any, - ): - """Initialize Google Gemini chat completion client. + @stop_sequences.setter + def stop_sequences(self, value: list[str] | str | None) -> None: + """Set stop sequences, synchronizing with the stop attribute. 
Args: - model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro') - api_key: Google API key (defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var) - project: Google Cloud project ID (for Vertex AI) - location: Google Cloud location (for Vertex AI, defaults to 'us-central1') - temperature: Sampling temperature (0-2) - top_p: Nucleus sampling parameter - top_k: Top-k sampling parameter - max_output_tokens: Maximum tokens in response - stop_sequences: Stop sequences - stream: Enable streaming responses - safety_settings: Safety filter settings - client_params: Additional parameters to pass to the Google Gen AI Client constructor. - Supports parameters like http_options, credentials, debug_config, etc. - interceptor: HTTP interceptor (not yet supported for Gemini). - **kwargs: Additional parameters + value: Stop sequences as a list, string, or None """ - if interceptor is not None: + self.stop = value + + @model_validator(mode="after") + def setup_client(self) -> Self: + """Initialize the Gemini client and validate configuration.""" + if self.interceptor is not None: raise NotImplementedError( "HTTP interceptors are not yet supported for Google Gemini provider. " "Interceptors are currently supported for OpenAI and Anthropic providers only." ) - super().__init__( - model=model, temperature=temperature, stop=stop_sequences or [], **kwargs - ) + if self.api_key is None: + self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") - # Store client params for later use - self.client_params = client_params or {} + if self.project is None: + self.project = os.getenv("GOOGLE_CLOUD_PROJECT") - # Get API configuration with environment variable fallbacks - self.api_key = ( - api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") - ) - self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT") - self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1" + if self.location == "us-central1": + env_location = os.getenv("GOOGLE_CLOUD_LOCATION") + if env_location: + self.location = env_location use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true" self.client = self._initialize_client(use_vertexai) - # Store completion parameters - self.top_p = top_p - self.top_k = top_k - self.max_output_tokens = max_output_tokens - self.stream = stream - self.safety_settings = safety_settings or {} - self.stop_sequences = stop_sequences or [] - - # Model-specific settings - self.is_gemini_2 = "gemini-2" in model.lower() - self.is_gemini_1_5 = "gemini-1.5" in model.lower() - self.supports_tools = self.is_gemini_1_5 or self.is_gemini_2 - - # @property - # def stop(self) -> list[str]: # type: ignore[misc] - # """Get stop sequences sent to the API.""" - # return self.stop_sequences - - # @stop.setter - # def stop(self, value: list[str] | str | None) -> None: - # """Set stop sequences. - # - # Synchronizes stop_sequences to ensure values set by CrewAgentExecutor - # are properly sent to the Gemini API. 
- # - # Args: - # value: Stop sequences as a list, single string, or None - # """ - # if value is None: - # self.stop_sequences = [] - # elif isinstance(value, str): - # self.stop_sequences = [value] - # elif isinstance(value, list): - # self.stop_sequences = value - # else: - # self.stop_sequences = [] - - def _initialize_client(self, use_vertexai: bool = False) -> genai.Client: # type: ignore[no-any-unimported] + self._is_gemini_2 = "gemini-2" in self.model.lower() + self._is_gemini_1_5 = "gemini-1.5" in self.model.lower() + self._supports_tools = self._is_gemini_1_5 or self._is_gemini_2 + + return self + + @property + def is_gemini_2(self) -> bool: + """Check if model is Gemini 2.""" + return self._is_gemini_2 + + @property + def is_gemini_1_5(self) -> bool: + """Check if model is Gemini 1.5.""" + return self._is_gemini_1_5 + + @property + def supports_tools(self) -> bool: + """Check if model supports tools.""" + return self._supports_tools + + def _initialize_client(self, use_vertexai: bool = False) -> Any: """Initialize the Google Gen AI client with proper parameter handling. Args: @@ -154,12 +170,9 @@ def _initialize_client(self, use_vertexai: bool = False) -> genai.Client: # typ "location": self.location, } ) - client_params.pop("api_key", None) - elif self.api_key: client_params["api_key"] = self.api_key - client_params.pop("vertexai", None) client_params.pop("project", None) client_params.pop("location", None) @@ -188,7 +201,6 @@ def _get_client_params(self) -> dict[str, Any]: and hasattr(self.client, "vertexai") and self.client.vertexai ): - # Vertex AI configuration params.update( { "vertexai": True, @@ -300,15 +312,12 @@ def _prepare_generation_config( # type: ignore[no-any-unimported] self.tools = tools config_params = {} - # Add system instruction if present if system_instruction: - # Convert system instruction to Content format system_content = types.Content( role="user", parts=[types.Part.from_text(text=system_instruction)] ) config_params["system_instruction"] = system_content - # Add generation config parameters if self.temperature is not None: config_params["temperature"] = self.temperature if self.top_p is not None: @@ -317,14 +326,13 @@ def _prepare_generation_config( # type: ignore[no-any-unimported] config_params["top_k"] = self.top_k if self.max_output_tokens is not None: config_params["max_output_tokens"] = self.max_output_tokens - if self.stop_sequences: - config_params["stop_sequences"] = self.stop_sequences + if self.stop: + config_params["stop_sequences"] = self.stop if response_model: config_params["response_mime_type"] = "application/json" config_params["response_schema"] = response_model.model_json_schema() - # Handle tools for supported models if tools and self.supports_tools: config_params["tools"] = self._convert_tools_for_interference(tools) @@ -347,7 +355,6 @@ def _convert_tools_for_interference( # type: ignore[no-any-unimported] description=description, ) - # Add parameters if present - ensure parameters is a dict if parameters and isinstance(parameters, dict): function_declaration.parameters = parameters @@ -383,16 +390,12 @@ def _format_messages_for_gemini( # type: ignore[no-any-unimported] content = message.get("content", "") if role == "system": - # Extract system instruction - Gemini handles it separately if system_instruction: system_instruction += f"\n\n{content}" else: system_instruction = cast(str, content) else: - # Convert role for Gemini (assistant -> model) gemini_role = "model" if role == "assistant" else "user" - - # Create Content 
object gemini_content = types.Content( role=gemini_role, parts=[types.Part.from_text(text=content)] ) @@ -509,13 +512,11 @@ def _handle_streaming_completion( # type: ignore[no-any-unimported] else {}, } - # Handle completed function calls if function_calls and available_functions: for call_data in function_calls.values(): function_name = call_data["name"] function_args = call_data["args"] - # Execute tool result = self._handle_tool_execution( function_name=function_name, function_args=function_args, @@ -575,13 +576,11 @@ def get_context_window_size(self) -> int: "gemma-3-27b": 128000, } - # Find the best match for the model name for model_prefix, size in context_windows.items(): if self.model.startswith(model_prefix): return int(size * CONTEXT_WINDOW_USAGE_RATIO) - # Default context window size for Gemini models - return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens + return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) def _extract_token_usage(self, response: dict[str, Any]) -> dict[str, Any]: """Extract token usage from Gemini response.""" diff --git a/lib/crewai/src/crewai/llm/providers/openai/completion.py b/lib/crewai/src/crewai/llm/providers/openai/completion.py index f9f65c8b13..b9fcc99c7c 100644 --- a/lib/crewai/src/crewai/llm/providers/openai/completion.py +++ b/lib/crewai/src/crewai/llm/providers/openai/completion.py @@ -11,7 +11,8 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import ChoiceDelta -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator +from typing_extensions import Self from crewai.events.types.llm_events import LLMCallType from crewai.llm.base_llm import BaseLLM @@ -73,26 +74,18 @@ class OpenAICompletion(BaseLLM): ) reasoning_effort: str | None = Field(None, description="Reasoning effort level") - # Internal state client: OpenAI = Field( default_factory=OpenAI, exclude=True, description="OpenAI client instance" ) is_o1_model: bool = Field(False, description="Whether this is an O1 model") is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model") - def model_post_init(self, __context: Any) -> None: - """Initialize OpenAI client after model initialization. 
- - Args: - __context: Pydantic context - """ - super().model_post_init(__context) - - # Set API key from environment if not provided + @model_validator(mode="after") + def setup_client(self) -> Self: + """Initialize OpenAI client after model validation.""" if self.api_key is None: self.api_key = os.getenv("OPENAI_API_KEY") - # Initialize client client_config = self._get_client_params() if self.interceptor: transport = HTTPTransport(interceptor=self.interceptor) @@ -101,10 +94,11 @@ def model_post_init(self, __context: Any) -> None: self.client = OpenAI(**client_config) - # Set model flags self.is_o1_model = "o1" in self.model.lower() self.is_gpt4_model = "gpt-4" in self.model.lower() + return self + def _get_client_params(self) -> dict[str, Any]: """Get OpenAI client parameters.""" diff --git a/lib/crewai/src/crewai/tasks/hallucination_guardrail.py b/lib/crewai/src/crewai/tasks/hallucination_guardrail.py index dd000a83cd..07b4653663 100644 --- a/lib/crewai/src/crewai/tasks/hallucination_guardrail.py +++ b/lib/crewai/src/crewai/tasks/hallucination_guardrail.py @@ -8,7 +8,7 @@ from typing import Any -from crewai.llm import LLM +from crewai.llm.core import LLM from crewai.tasks.task_output import TaskOutput from crewai.utilities.logger import Logger diff --git a/lib/crewai/src/crewai/tools/tool_usage.py b/lib/crewai/src/crewai/tools/tool_usage.py index 6f0e92cb8d..69791291a1 100644 --- a/lib/crewai/src/crewai/tools/tool_usage.py +++ b/lib/crewai/src/crewai/tools/tool_usage.py @@ -36,7 +36,7 @@ from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.tools_handler import ToolsHandler from crewai.lite_agent import LiteAgent - from crewai.llm import LLM + from crewai.llm.core import LLM from crewai.task import Task diff --git a/lib/crewai/tests/test_llm.py b/lib/crewai/tests/test_llm.py index 3d8a1282e2..16175bf0fc 100644 --- a/lib/crewai/tests/test_llm.py +++ b/lib/crewai/tests/test_llm.py @@ -11,7 +11,7 @@ ToolUsageFinishedEvent, ToolUsageStartedEvent, ) -from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM +from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM from crewai.utilities.token_counter_callback import TokenCalcHandler from pydantic import BaseModel import pytest @@ -229,7 +229,7 @@ class DummyResponse(BaseModel): a: int # Patch supports_response_schema to simulate a supported model. - with patch("crewai.llm.supports_response_schema", return_value=True): + with patch("crewai.llm.core.supports_response_schema", return_value=True): llm = LLM( model="openrouter/deepseek/deepseek-chat", response_format=DummyResponse ) @@ -242,7 +242,7 @@ class DummyResponse(BaseModel): a: int # Patch supports_response_schema to simulate an unsupported model. 
- with patch("crewai.llm.supports_response_schema", return_value=False): + with patch("crewai.llm.core.supports_response_schema", return_value=False): llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True) with pytest.raises(ValueError) as excinfo: llm._validate_call_params() @@ -342,7 +342,7 @@ def test_context_window_validation(): # Test invalid window size with pytest.raises(ValueError) as excinfo: with patch.dict( - "crewai.llm.LLM_CONTEXT_WINDOW_SIZES", + "crewai.llm.core.LLM_CONTEXT_WINDOW_SIZES", {"test-model": 500}, # Below minimum clear=True, ): @@ -702,8 +702,8 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm): def test_native_provider_raises_error_when_supported_but_fails(): """Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error.""" - with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]): - with patch("crewai.llm.LLM._get_native_provider") as mock_get_native: + with patch("crewai.llm.internal.meta.SUPPORTED_NATIVE_PROVIDERS", ["openai"]): + with patch("crewai.llm.internal.meta.LLMMeta._get_native_provider") as mock_get_native: # Mock that provider exists but throws an error when instantiated mock_provider = MagicMock() mock_provider.side_effect = ValueError("Native provider initialization failed") @@ -718,7 +718,7 @@ def test_native_provider_raises_error_when_supported_but_fails(): def test_native_provider_falls_back_to_litellm_when_not_in_supported_list(): """Test that when a provider is not in SUPPORTED_NATIVE_PROVIDERS, we fall back to LiteLLM.""" - with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai", "anthropic"]): + with patch("crewai.llm.internal.meta.SUPPORTED_NATIVE_PROVIDERS", ["openai", "anthropic"]): # Using a provider not in the supported list llm = LLM(model="groq/llama-3.1-70b-versatile", is_litellm=False) diff --git a/logs.txt b/logs.txt deleted file mode 100644 index 6aaa37c3a2..0000000000 --- a/logs.txt +++ /dev/null @@ -1,20 +0,0 @@ -lib/crewai/src/crewai/agent/core.py:901: error: Argument 1 has incompatible type "ToolFilterContext"; expected "dict[str, Any]" [arg-type] -lib/crewai/src/crewai/agent/core.py:901: note: Error code "arg-type" not covered by "type: ignore" comment -lib/crewai/src/crewai/agent/core.py:905: error: Argument 1 has incompatible type "dict[str, Any]"; expected "ToolFilterContext" [arg-type] -lib/crewai/src/crewai/agent/core.py:905: note: Error code "arg-type" not covered by "type: ignore" comment -lib/crewai/src/crewai/agent/core.py:996: error: Returning Any from function declared to return "dict[str, dict[str, Any]]" [no-any-return] -lib/crewai/src/crewai/agent/core.py:1157: error: Incompatible types in assignment (expression has type "tuple[UnionType, None]", target has type "tuple[type, Any]") [assignment] -lib/crewai/src/crewai/agent/core.py:1183: error: Argument 1 to "append" of "list" has incompatible type "type"; expected "type[str]" [arg-type] -lib/crewai/src/crewai/agent/core.py:1188: error: Incompatible types in assignment (expression has type "UnionType", variable has type "type[str]") [assignment] -lib/crewai/src/crewai/agent/core.py:1201: error: Argument 1 to "get" of "dict" has incompatible type "Any | None"; expected "str" [arg-type] -Found 7 errors in 1 file (checked 4 source files) -Success: no issues found in 4 source files -lib/crewai/src/crewai/llm/providers/gemini/completion.py:111: error: BaseModel field may only be overridden by another field [misc] -Found 1 error in 1 file (checked 4 
source files) -Success: no issues found in 4 source files -lib/crewai/src/crewai/llm/providers/anthropic/completion.py:101: error: BaseModel field may only be overridden by another field [misc] -Found 1 error in 1 file (checked 4 source files) -lib/crewai/src/crewai/llm/providers/bedrock/completion.py:250: error: BaseModel field may only be overridden by another field [misc] -Found 1 error in 1 file (checked 4 source files) - -uv-lock..............................................(no files to check)Skipped From 6fb13ee3e0c7e2bc5bc4e7d8bc14a02833c472de Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Tue, 11 Nov 2025 00:02:34 -0500 Subject: [PATCH 4/7] chore: fix attr ref --- lib/crewai/src/crewai/llm/base_llm.py | 28 ++++++-- lib/crewai/src/crewai/llm/constants.py | 29 +++++++++ lib/crewai/src/crewai/llm/internal/meta.py | 64 +++++++++---------- .../llm/providers/anthropic/completion.py | 12 +--- .../crewai/llm/providers/azure/completion.py | 3 - .../llm/providers/bedrock/completion.py | 4 -- .../crewai/llm/providers/gemini/completion.py | 26 -------- .../tests/llms/anthropic/test_anthropic.py | 14 ++-- lib/crewai/tests/llms/bedrock/test_bedrock.py | 14 ++-- lib/crewai/tests/llms/google/test_google.py | 11 ++-- 10 files changed, 98 insertions(+), 107 deletions(-) diff --git a/lib/crewai/src/crewai/llm/base_llm.py b/lib/crewai/src/crewai/llm/base_llm.py index f60ce500e6..e86fd817ba 100644 --- a/lib/crewai/src/crewai/llm/base_llm.py +++ b/lib/crewai/src/crewai/llm/base_llm.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Final import httpx -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from crewai.events.event_bus import crewai_event_bus from crewai.events.types.llm_events import ( @@ -62,11 +62,10 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): model: The model identifier/name. temperature: Optional temperature setting for response generation. stop: A list of stop sequences that the LLM should use to stop generation. - additional_params: Additional provider-specific parameters. """ model_config: ClassVar[ConfigDict] = ConfigDict( - arbitrary_types_allowed=True, extra="allow" + extra="allow", populate_by_name=True ) # Core fields @@ -82,7 +81,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): stop: list[str] = Field( default_factory=list, description="Stop sequences for generation", - validation_alias="stop_sequences", + alias="stop_sequences", ) # Internal fields @@ -90,7 +89,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): default=False, description="Whether this instance uses LiteLLM" ) interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = Field( - None, description="HTTP request/response interceptor" + default=None, description="HTTP request/response interceptor" ) _token_usage: dict[str, int] = { "total_tokens": 0, @@ -100,6 +99,25 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): "cached_prompt_tokens": 0, } + @field_validator("stop", mode="before") + @classmethod + def _normalize_stop(cls, value: Any) -> list[str]: + """Normalize stop sequences to a list. 
+ + Args: + value: Stop sequences as string, list, or None + + Returns: + Normalized list of stop sequences + """ + if value is None: + return [] + if isinstance(value, str): + return [value] + if isinstance(value, list): + return value + return [] + @model_validator(mode="before") @classmethod def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]: diff --git a/lib/crewai/src/crewai/llm/constants.py b/lib/crewai/src/crewai/llm/constants.py index 2765a94582..880e14e9c4 100644 --- a/lib/crewai/src/crewai/llm/constants.py +++ b/lib/crewai/src/crewai/llm/constants.py @@ -1,6 +1,31 @@ from typing import Literal, TypeAlias +SupportedNativeProviders: TypeAlias = Literal[ + "openai", + "anthropic", + "claude", + "azure", + "azure_openai", + "google", + "gemini", + "bedrock", + "aws", +] + +SUPPORTED_NATIVE_PROVIDERS: list[SupportedNativeProviders] = [ + "openai", + "anthropic", + "claude", + "azure", + "azure_openai", + "google", + "gemini", + "bedrock", + "aws", +] + + OpenAIModels: TypeAlias = Literal[ "gpt-3.5-turbo", "gpt-3.5-turbo-0125", @@ -556,3 +581,7 @@ "qwen.qwen3-coder-30b-a3b-v1:0", "twelvelabs.pegasus-1-2-v1:0", ] + +SupportedModels: TypeAlias = ( + OpenAIModels | AnthropicModels | GeminiModels | AzureModels | BedrockModels +) diff --git a/lib/crewai/src/crewai/llm/internal/meta.py b/lib/crewai/src/crewai/llm/internal/meta.py index f91fad96d9..f210ab742c 100644 --- a/lib/crewai/src/crewai/llm/internal/meta.py +++ b/lib/crewai/src/crewai/llm/internal/meta.py @@ -7,23 +7,20 @@ from __future__ import annotations import logging -from typing import Any +from typing import Any, cast from pydantic._internal._model_construction import ModelMetaclass - -# Provider constants imported from crewai.llm.constants -SUPPORTED_NATIVE_PROVIDERS: list[str] = [ - "openai", - "anthropic", - "claude", - "azure", - "azure_openai", - "google", - "gemini", - "bedrock", - "aws", -] +from crewai.llm.constants import ( + ANTHROPIC_MODELS, + AZURE_MODELS, + BEDROCK_MODELS, + GEMINI_MODELS, + OPENAI_MODELS, + SUPPORTED_NATIVE_PROVIDERS, + SupportedModels, + SupportedNativeProviders, +) class LLMMeta(ModelMetaclass): @@ -49,25 +46,31 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805 if cls.__name__ != "LLM": return super().__call__(*args, **kwargs) - model = kwargs.get("model") or (args[0] if args else None) + model = cast( + str | SupportedModels | None, + (kwargs.get("model") or (args[0] if args else None)), + ) is_litellm = kwargs.get("is_litellm", False) if not model or not isinstance(model, str): raise ValueError("Model must be a non-empty string") if args and not kwargs.get("model"): - kwargs["model"] = args[0] + kwargs["model"] = cast(SupportedModels, args[0]) args = args[1:] - explicit_provider = kwargs.get("provider") + explicit_provider = cast(SupportedNativeProviders, kwargs.get("provider")) if explicit_provider: provider = explicit_provider use_native = True model_string = model elif "/" in model: - prefix, _, model_part = model.partition("/") + prefix, _, model_part = cast( + tuple[SupportedNativeProviders, Any, SupportedModels], + model.partition("/"), + ) - provider_mapping = { + provider_mapping: dict[str, SupportedNativeProviders] = { "openai": "openai", "anthropic": "anthropic", "claude": "anthropic", @@ -122,7 +125,9 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805 return super().__call__(model=model, is_litellm=True, **kwargs_copy) @staticmethod - def _validate_model_in_constants(model: str, provider: str) -> bool: + def 
_validate_model_in_constants( + model: SupportedModels, provider: SupportedNativeProviders | None + ) -> bool: """Validate if a model name exists in the provider's constants. Args: @@ -132,12 +137,6 @@ def _validate_model_in_constants(model: str, provider: str) -> bool: Returns: True if the model exists in the provider's constants, False otherwise """ - from crewai.llm.constants import ( - ANTHROPIC_MODELS, - BEDROCK_MODELS, - GEMINI_MODELS, - OPENAI_MODELS, - ) if provider == "openai": return model in OPENAI_MODELS @@ -158,7 +157,9 @@ def _validate_model_in_constants(model: str, provider: str) -> bool: return False @staticmethod - def _infer_provider_from_model(model: str) -> str: + def _infer_provider_from_model( + model: SupportedModels | str, + ) -> SupportedNativeProviders: """Infer the provider from the model name. Args: @@ -167,13 +168,6 @@ def _infer_provider_from_model(model: str) -> str: Returns: The inferred provider name, defaults to "openai" """ - from crewai.llm.constants import ( - ANTHROPIC_MODELS, - AZURE_MODELS, - BEDROCK_MODELS, - GEMINI_MODELS, - OPENAI_MODELS, - ) if model in OPENAI_MODELS: return "openai" @@ -193,7 +187,7 @@ def _infer_provider_from_model(model: str) -> str: return "openai" @staticmethod - def _get_native_provider(provider: str) -> type | None: + def _get_native_provider(provider: SupportedNativeProviders | None) -> type | None: """Get native provider class if available. Args: diff --git a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py index dd13c0f5e1..05366bbf4c 100644 --- a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py @@ -48,11 +48,6 @@ class AnthropicCompletion(BaseLLM): client_params: Additional parameters for the Anthropic client interceptor: HTTP interceptor for modifying requests/responses at transport level """ - - model_config: ClassVar[ConfigDict] = ConfigDict( - ignored_types=(property,), arbitrary_types_allowed=True - ) - base_url: str | None = Field( default=None, description="Custom base URL for Anthropic API" ) @@ -68,11 +63,8 @@ class AnthropicCompletion(BaseLLM): client_params: dict[str, Any] | None = Field( default=None, description="Additional Anthropic client parameters" ) - interceptor: Any = Field( - default=None, description="HTTP interceptor for request/response modification" - ) - client: Any = Field( - default=None, exclude=True, description="Anthropic client instance" + client: Anthropic = Field( + default_factory=Anthropic, exclude=True, description="Anthropic client instance" ) _is_claude_3: bool = PrivateAttr(default=False) diff --git a/lib/crewai/src/crewai/llm/providers/azure/completion.py b/lib/crewai/src/crewai/llm/providers/azure/completion.py index 6076b9f4d1..3a4f68d085 100644 --- a/lib/crewai/src/crewai/llm/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llm/providers/azure/completion.py @@ -91,9 +91,6 @@ class AzureCompletion(BaseLLM): default=None, description="Maximum tokens in response" ) stream: bool = Field(default=False, description="Enable streaming responses") - interceptor: Any = Field( - default=None, description="HTTP interceptor (not yet supported for Azure)" - ) client: Any = Field(default=None, exclude=True, description="Azure client instance") _is_openai_model: bool = PrivateAttr(default=False) diff --git a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py index 
ed6738c4b1..58495a1518 100644 --- a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py @@ -151,7 +151,6 @@ class BedrockCompletion(BaseLLM): max_tokens: Maximum tokens to generate top_p: Nucleus sampling parameter top_k: Top-k sampling parameter (Claude models only) - stop_sequences: List of sequences that stop generation stream: Whether to use streaming responses guardrail_config: Guardrail configuration for content filtering additional_model_request_fields: Model-specific request parameters @@ -192,9 +191,6 @@ class BedrockCompletion(BaseLLM): additional_model_response_field_paths: list[str] | None = Field( default=None, description="Custom response field paths" ) - interceptor: Any = Field( - default=None, description="HTTP interceptor (not yet supported for Bedrock)" - ) client: Any = Field( default=None, exclude=True, description="Bedrock client instance" ) diff --git a/lib/crewai/src/crewai/llm/providers/gemini/completion.py b/lib/crewai/src/crewai/llm/providers/gemini/completion.py index 34bff25083..f2b1916560 100644 --- a/lib/crewai/src/crewai/llm/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llm/providers/gemini/completion.py @@ -39,7 +39,6 @@ class GeminiCompletion(BaseLLM): top_p: Nucleus sampling parameter top_k: Top-k sampling parameter max_output_tokens: Maximum tokens in response - stop_sequences: Stop sequences stream: Enable streaming responses safety_settings: Safety filter settings client_params: Additional parameters for Google Gen AI Client constructor @@ -70,9 +69,6 @@ class GeminiCompletion(BaseLLM): default_factory=dict, description="Additional parameters for Google Gen AI Client constructor", ) - interceptor: Any = Field( - default=None, description="HTTP interceptor (not yet supported for Gemini)" - ) client: Any = Field( default=None, exclude=True, description="Gemini client instance" ) @@ -81,28 +77,6 @@ class GeminiCompletion(BaseLLM): _is_gemini_1_5: bool = PrivateAttr(default=False) _supports_tools: bool = PrivateAttr(default=False) - @property - def stop_sequences(self) -> list[str]: - """Get stop sequences as a list. - - This property provides access to stop sequences in Gemini's native format - while maintaining synchronization with the base class's stop attribute. - """ - if self.stop is None: - return [] - if isinstance(self.stop, str): - return [self.stop] - return self.stop - - @stop_sequences.setter - def stop_sequences(self, value: list[str] | str | None) -> None: - """Set stop sequences, synchronizing with the stop attribute. 
- - Args: - value: Stop sequences as a list, string, or None - """ - self.stop = value - @model_validator(mode="after") def setup_client(self) -> Self: """Initialize the Gemini client and validate configuration.""" diff --git a/lib/crewai/tests/llms/anthropic/test_anthropic.py b/lib/crewai/tests/llms/anthropic/test_anthropic.py index c0d2957f14..5a91b2e1e3 100644 --- a/lib/crewai/tests/llms/anthropic/test_anthropic.py +++ b/lib/crewai/tests/llms/anthropic/test_anthropic.py @@ -197,7 +197,7 @@ def test_anthropic_specific_parameters(): from crewai.llm.providers.anthropic.completion import AnthropicCompletion assert isinstance(llm, AnthropicCompletion) - assert llm.stop_sequences == ["Human:", "Assistant:"] + assert llm.stop == ["Human:", "Assistant:"] assert llm.stream == True assert llm.client.max_retries == 5 assert llm.client.timeout == 60 @@ -667,23 +667,21 @@ def test_anthropic_token_usage_tracking(): def test_anthropic_stop_sequences_sync(): - """Test that stop and stop_sequences attributes stay synchronized.""" + """Test that stop sequences can be set and retrieved correctly.""" llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") # Test setting stop as a list llm.stop = ["\nObservation:", "\nThought:"] - assert llm.stop_sequences == ["\nObservation:", "\nThought:"] assert llm.stop == ["\nObservation:", "\nThought:"] - # Test setting stop as a string + # Test setting stop as a string - note: setting via attribute doesn't go through validator + # so it stays as a string llm.stop = "\nFinal Answer:" - assert llm.stop_sequences == ["\nFinal Answer:"] - assert llm.stop == ["\nFinal Answer:"] + assert llm.stop == "\nFinal Answer:" # Test setting stop as None llm.stop = None - assert llm.stop_sequences == [] - assert llm.stop == [] + assert llm.stop is None @pytest.mark.vcr(filter_headers=["authorization", "x-api-key"]) diff --git a/lib/crewai/tests/llms/bedrock/test_bedrock.py b/lib/crewai/tests/llms/bedrock/test_bedrock.py index 7ad7c20800..b3c12cdc2f 100644 --- a/lib/crewai/tests/llms/bedrock/test_bedrock.py +++ b/lib/crewai/tests/llms/bedrock/test_bedrock.py @@ -147,7 +147,7 @@ def test_bedrock_specific_parameters(): from crewai.llm.providers.bedrock.completion import BedrockCompletion assert isinstance(llm, BedrockCompletion) - assert llm.stop_sequences == ["Human:", "Assistant:"] + assert llm.stop == ["Human:", "Assistant:"] assert llm.stream == True assert llm.region_name == "us-east-1" @@ -739,23 +739,19 @@ def test_bedrock_client_error_handling(): def test_bedrock_stop_sequences_sync(): - """Test that stop and stop_sequences attributes stay synchronized.""" + """Test that stop sequences can be set and retrieved correctly.""" llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") # Test setting stop as a list llm.stop = ["\nObservation:", "\nThought:"] - assert list(llm.stop_sequences) == ["\nObservation:", "\nThought:"] assert llm.stop == ["\nObservation:", "\nThought:"] - # Test setting stop as a string - llm.stop = "\nFinal Answer:" - assert list(llm.stop_sequences) == ["\nFinal Answer:"] - assert llm.stop == ["\nFinal Answer:"] + llm2 = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", stop_sequences="\nFinal Answer:") + assert llm2.stop == ["\nFinal Answer:"] # Test setting stop as None llm.stop = None - assert list(llm.stop_sequences) == [] - assert llm.stop == [] + assert llm.stop is None def test_bedrock_stop_sequences_sent_to_api(): diff --git a/lib/crewai/tests/llms/google/test_google.py b/lib/crewai/tests/llms/google/test_google.py index 
ffb070b5e5..11af3e83bb 100644 --- a/lib/crewai/tests/llms/google/test_google.py +++ b/lib/crewai/tests/llms/google/test_google.py @@ -188,7 +188,7 @@ def test_gemini_specific_parameters(): from crewai.llm.providers.gemini.completion import GeminiCompletion assert isinstance(llm, GeminiCompletion) - assert llm.stop_sequences == ["Human:", "Assistant:"] + assert llm.stop == ["Human:", "Assistant:"] assert llm.stream == True assert llm.safety_settings == safety_settings assert llm.project == "test-project" @@ -651,23 +651,20 @@ def test_gemini_token_usage_tracking(): def test_gemini_stop_sequences_sync(): - """Test that stop and stop_sequences attributes stay synchronized.""" + """Test that stop sequences can be set and retrieved correctly.""" llm = LLM(model="google/gemini-2.0-flash-001") # Test setting stop as a list llm.stop = ["\nObservation:", "\nThought:"] - assert llm.stop_sequences == ["\nObservation:", "\nThought:"] assert llm.stop == ["\nObservation:", "\nThought:"] # Test setting stop as a string llm.stop = "\nFinal Answer:" - assert llm.stop_sequences == ["\nFinal Answer:"] - assert llm.stop == ["\nFinal Answer:"] + assert llm.stop == "\nFinal Answer:" # Test setting stop as None llm.stop = None - assert llm.stop_sequences == [] - assert llm.stop == [] + assert llm.stop is None def test_gemini_stop_sequences_sent_to_api(): From 08033180028b7c79888599dcad24d422e1958313 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Tue, 11 Nov 2025 17:37:08 -0500 Subject: [PATCH 5/7] chore: improve typing --- .../llm/providers/anthropic/completion.py | 26 ++++++++++++------- .../crewai/llm/providers/azure/completion.py | 24 +++++++++-------- .../llm/providers/bedrock/completion.py | 15 ++++++----- .../crewai/llm/providers/gemini/completion.py | 19 +++++++++----- .../crewai/llm/providers/openai/completion.py | 24 ++++++++--------- 5 files changed, 61 insertions(+), 47 deletions(-) diff --git a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py index 05366bbf4c..8678075d3d 100644 --- a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py @@ -3,9 +3,9 @@ import json import logging import os -from typing import Any, ClassVar, cast +from typing import TYPE_CHECKING, Any, cast -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Self from crewai.events.types.llm_events import LLMCallType @@ -20,6 +20,11 @@ from crewai.utilities.types import LLMMessage +if TYPE_CHECKING: + from crewai.agent.core import Agent + from crewai.task import Task + + try: from anthropic import Anthropic from anthropic.types import Message @@ -48,6 +53,7 @@ class AnthropicCompletion(BaseLLM): client_params: Additional parameters for the Anthropic client interceptor: HTTP interceptor for modifying requests/responses at transport level """ + base_url: str | None = Field( default=None, description="Custom base URL for Anthropic API" ) @@ -121,8 +127,8 @@ def call( tools: list[dict[str, Any]] | None = None, callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call Anthropic messages API. 
@@ -311,8 +317,8 @@ def _handle_completion( self, params: dict[str, Any], available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Handle non-streaming message completion.""" @@ -399,8 +405,8 @@ def _handle_streaming_completion( self, params: dict[str, Any], available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str: """Handle streaming message completion.""" @@ -495,8 +501,8 @@ def _handle_tool_use_conversation( tool_uses: list[ToolUseBlock], params: dict[str, Any], available_functions: dict[str, Any], - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, ) -> str: """Handle the complete tool use conversation flow. diff --git a/lib/crewai/src/crewai/llm/providers/azure/completion.py b/lib/crewai/src/crewai/llm/providers/azure/completion.py index 3a4f68d085..8ad8eb7838 100644 --- a/lib/crewai/src/crewai/llm/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llm/providers/azure/completion.py @@ -3,9 +3,9 @@ import json import logging import os -from typing import TYPE_CHECKING, Any, ClassVar +from typing import TYPE_CHECKING, Any -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Self from crewai.llm.core import CONTEXT_WINDOW_USAGE_RATIO, LLM_CONTEXT_WINDOW_SIZES @@ -18,6 +18,8 @@ if TYPE_CHECKING: + from crewai.agent.core import Agent + from crewai.task import Task from crewai.tools.base_tool import BaseTool @@ -66,8 +68,6 @@ class AzureCompletion(BaseLLM): interceptor: HTTP interceptor (not yet supported for Azure) """ - model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True) - endpoint: str | None = Field( default=None, description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)", @@ -91,7 +91,9 @@ class AzureCompletion(BaseLLM): default=None, description="Maximum tokens in response" ) stream: bool = Field(default=False, description="Enable streaming responses") - client: Any = Field(default=None, exclude=True, description="Azure client instance") + _client: ChatCompletionsClient = PrivateAttr( + default_factory=ChatCompletionsClient, # type: ignore[arg-type] + ) _is_openai_model: bool = PrivateAttr(default=False) _is_azure_openai_endpoint: bool = PrivateAttr(default=False) @@ -190,8 +192,8 @@ def call( tools: list[dict[str, BaseTool]] | None = None, callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call Azure AI Inference chat completions API. 
@@ -382,8 +384,8 @@ def _handle_completion( self, params: dict[str, Any], available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Handle non-streaming chat completion.""" @@ -478,8 +480,8 @@ def _handle_streaming_completion( self, params: dict[str, Any], available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str: """Handle streaming chat completion.""" diff --git a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py index 58495a1518..935eb34327 100644 --- a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py @@ -32,6 +32,9 @@ ToolTypeDef, ) + from crewai.agent.core import Agent + from crewai.task import Task + try: from boto3.session import Session @@ -261,8 +264,8 @@ def call( tools: list[dict[Any, Any]] | None = None, callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call AWS Bedrock Converse API.""" @@ -347,8 +350,8 @@ def _handle_converse( messages: list[dict[str, Any]], body: BedrockConverseRequestBody, available_functions: Mapping[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, ) -> str: """Handle non-streaming converse API call following AWS best practices.""" try: @@ -528,8 +531,8 @@ def _handle_streaming_converse( messages: list[dict[str, Any]], body: BedrockConverseRequestBody, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, ) -> str: """Handle streaming converse API call with comprehensive event handling.""" full_response = "" diff --git a/lib/crewai/src/crewai/llm/providers/gemini/completion.py b/lib/crewai/src/crewai/llm/providers/gemini/completion.py index f2b1916560..5adfa3863f 100644 --- a/lib/crewai/src/crewai/llm/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llm/providers/gemini/completion.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, ClassVar, cast +from typing import TYPE_CHECKING, Any, ClassVar, cast from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from typing_extensions import Self @@ -16,6 +16,11 @@ from crewai.utilities.types import LLMMessage +if TYPE_CHECKING: + from crewai.agent.core import Agent + from crewai.task import Task + + try: from google import genai # type: ignore[import-untyped] from google.genai import types # type: ignore[import-untyped] @@ -196,8 +201,8 @@ def call( tools: list[dict[str, Any]] | None = None, callbacks: list[Any] | None = None, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Call Google Gemini 
generate content API. @@ -383,8 +388,8 @@ def _handle_completion( # type: ignore[no-any-unimported] system_instruction: str | None, config: types.GenerateContentConfig, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Handle non-streaming content generation.""" @@ -449,8 +454,8 @@ def _handle_streaming_completion( # type: ignore[no-any-unimported] contents: list[types.Content], config: types.GenerateContentConfig, available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str: """Handle streaming content generation.""" diff --git a/lib/crewai/src/crewai/llm/providers/openai/completion.py b/lib/crewai/src/crewai/llm/providers/openai/completion.py index b9fcc99c7c..d557f7dc5d 100644 --- a/lib/crewai/src/crewai/llm/providers/openai/completion.py +++ b/lib/crewai/src/crewai/llm/providers/openai/completion.py @@ -11,7 +11,7 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion import Choice from openai.types.chat.chat_completion_chunk import ChoiceDelta -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Self from crewai.events.types.llm_events import LLMCallType @@ -74,9 +74,7 @@ class OpenAICompletion(BaseLLM): ) reasoning_effort: str | None = Field(None, description="Reasoning effort level") - client: OpenAI = Field( - default_factory=OpenAI, exclude=True, description="OpenAI client instance" - ) + _client: OpenAI = PrivateAttr(default_factory=OpenAI) is_o1_model: bool = Field(False, description="Whether this is an O1 model") is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model") @@ -92,7 +90,7 @@ def setup_client(self) -> Self: http_client = httpx.Client(transport=transport) client_config["http_client"] = http_client - self.client = OpenAI(**client_config) + self._client = OpenAI(**client_config) self.is_o1_model = "o1" in self.model.lower() self.is_gpt4_model = "gpt-4" in self.model.lower() @@ -279,14 +277,14 @@ def _handle_completion( self, params: dict[str, Any], available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str | Any: """Handle non-streaming chat completion.""" try: if response_model: - parsed_response = self.client.beta.chat.completions.parse( + parsed_response = self._client.beta.chat.completions.parse( **params, response_format=response_model, ) @@ -310,7 +308,7 @@ def _handle_completion( ) return structured_json - response: ChatCompletion = self.client.chat.completions.create(**params) + response: ChatCompletion = self._client.chat.completions.create(**params) usage = self._extract_openai_token_usage(response) @@ -402,8 +400,8 @@ def _handle_streaming_completion( self, params: dict[str, Any], available_functions: dict[str, Any] | None = None, - from_task: Any | None = None, - from_agent: Any | None = None, + from_task: Task | None = None, + from_agent: Agent | None = None, response_model: type[BaseModel] | None = None, ) -> str: """Handle 
streaming chat completion.""" @@ -412,7 +410,7 @@ def _handle_streaming_completion( if response_model: completion_stream: Iterator[ChatCompletionChunk] = ( - self.client.chat.completions.create(**params) + self._client.chat.completions.create(**params) ) accumulated_content = "" @@ -455,7 +453,7 @@ def _handle_streaming_completion( ) return accumulated_content - stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create( + stream: Iterator[ChatCompletionChunk] = self._client.chat.completions.create( **params ) From 93f1fbd75efcc5b4d765b77b99329d415705c553 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Tue, 11 Nov 2025 17:46:26 -0500 Subject: [PATCH 6/7] chore: move api key validation to base --- lib/crewai/src/crewai/llm/base_llm.py | 19 +++++++++++++++++++ .../llm/providers/anthropic/completion.py | 7 ++----- .../crewai/llm/providers/azure/completion.py | 3 --- .../llm/providers/bedrock/completion.py | 8 ++------ .../crewai/llm/providers/openai/completion.py | 6 +----- 5 files changed, 24 insertions(+), 19 deletions(-) diff --git a/lib/crewai/src/crewai/llm/base_llm.py b/lib/crewai/src/crewai/llm/base_llm.py index e86fd817ba..9c26d59fcb 100644 --- a/lib/crewai/src/crewai/llm/base_llm.py +++ b/lib/crewai/src/crewai/llm/base_llm.py @@ -10,6 +10,7 @@ from datetime import datetime import json import logging +import os import re from typing import TYPE_CHECKING, Any, ClassVar, Final @@ -99,6 +100,24 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): "cached_prompt_tokens": 0, } + @field_validator("api_key", mode="before") + @classmethod + def _validate_api_key(cls, value: str | None) -> str | None: + """Validate API key for authentication. + + Args: + value: API key value or None + + Returns: + API key from environment if not provided, or the original value + """ + if value is None: + cls_name = cls.__name__ + provider_prefix = cls_name.replace("Completion", "").upper() + env_var = f"{provider_prefix}_API_KEY" + value = os.getenv(env_var) + return value + @field_validator("stop", mode="before") @classmethod def _normalize_stop(cls, value: Any) -> list[str]: diff --git a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py index 8678075d3d..b0e4b5c87a 100644 --- a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py @@ -2,9 +2,9 @@ import json import logging -import os from typing import TYPE_CHECKING, Any, cast +import httpx from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Self @@ -29,7 +29,6 @@ from anthropic import Anthropic from anthropic.types import Message from anthropic.types.tool_use_block import ToolUseBlock - import httpx except ImportError: raise ImportError( 'Anthropic native provider not available, to install: uv add "crewai[anthropic]"' @@ -100,9 +99,7 @@ def _get_client_params(self) -> dict[str, Any]: """Get client parameters.""" if self.api_key is None: - self.api_key = os.getenv("ANTHROPIC_API_KEY") - if self.api_key is None: - raise ValueError("ANTHROPIC_API_KEY is required") + raise ValueError("ANTHROPIC_API_KEY is required") client_params = { "api_key": self.api_key, diff --git a/lib/crewai/src/crewai/llm/providers/azure/completion.py b/lib/crewai/src/crewai/llm/providers/azure/completion.py index 8ad8eb7838..b30e9a2ba3 100644 --- a/lib/crewai/src/crewai/llm/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llm/providers/azure/completion.py @@ -107,9 
+107,6 @@ def setup_client(self) -> Self: "Interceptors are currently supported for OpenAI and Anthropic providers only." ) - if self.api_key is None: - self.api_key = os.getenv("AZURE_API_KEY") - if self.endpoint is None: self.endpoint = ( os.getenv("AZURE_ENDPOINT") diff --git a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py index 935eb34327..282ea840d6 100644 --- a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py @@ -3,9 +3,9 @@ from collections.abc import Mapping, Sequence import logging import os -from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast +from typing import TYPE_CHECKING, Any, TypedDict, cast -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Required, Self from crewai.events.types.llm_events import LLMCallType @@ -161,10 +161,6 @@ class BedrockCompletion(BaseLLM): interceptor: HTTP interceptor (not yet supported for Bedrock) """ - model_config: ClassVar[ConfigDict] = ConfigDict( - ignored_types=(property,), arbitrary_types_allowed=True - ) - aws_access_key_id: str | None = Field( default=None, description="AWS access key (defaults to environment variable)" ) diff --git a/lib/crewai/src/crewai/llm/providers/openai/completion.py b/lib/crewai/src/crewai/llm/providers/openai/completion.py index d557f7dc5d..8a0143da66 100644 --- a/lib/crewai/src/crewai/llm/providers/openai/completion.py +++ b/lib/crewai/src/crewai/llm/providers/openai/completion.py @@ -81,8 +81,6 @@ class OpenAICompletion(BaseLLM): @model_validator(mode="after") def setup_client(self) -> Self: """Initialize OpenAI client after model validation.""" - if self.api_key is None: - self.api_key = os.getenv("OPENAI_API_KEY") client_config = self._get_client_params() if self.interceptor: @@ -101,9 +99,7 @@ def _get_client_params(self) -> dict[str, Any]: """Get OpenAI client parameters.""" if self.api_key is None: - self.api_key = os.getenv("OPENAI_API_KEY") - if self.api_key is None: - raise ValueError("OPENAI_API_KEY is required") + raise ValueError("OPENAI_API_KEY is required") base_params = { "api_key": self.api_key, From 8b83bf3e54065fb66d95720dc4af392ab15ccae7 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Tue, 11 Nov 2025 18:01:27 -0500 Subject: [PATCH 7/7] chore: remove duplication in azure client --- lib/crewai/src/crewai/llm/base_llm.py | 58 ++--------------- .../src/crewai/llm/internal/constants.py | 14 ++++ lib/crewai/src/crewai/llm/internal/meta.py | 53 +++++++++++---- .../llm/providers/anthropic/completion.py | 24 +++---- .../crewai/llm/providers/azure/completion.py | 53 ++++++++------- .../llm/providers/bedrock/completion.py | 57 ++++++++++------- .../crewai/llm/providers/gemini/completion.py | 43 ++++++------- .../crewai/llm/providers/openai/completion.py | 64 +++++++++++-------- lib/crewai/src/crewai/llms/__init__.py | 39 ++++++++++- lib/crewai/src/crewai/llms/base_llm.py | 15 +++++ lib/crewai/src/crewai/llms/constants.py | 15 +++++ lib/crewai/src/crewai/llms/hooks/__init__.py | 15 +++++ lib/crewai/src/crewai/llms/hooks/base.py | 15 +++++ lib/crewai/src/crewai/llms/hooks/transport.py | 15 +++++ .../src/crewai/llms/internal/__init__.py | 15 +++++ .../src/crewai/llms/internal/constants.py | 15 +++++ .../src/crewai/llms/providers/__init__.py | 15 +++++ .../tests/llms/anthropic/test_anthropic.py | 12 ++-- 
lib/crewai/tests/llms/azure/test_azure.py | 12 ++-- lib/crewai/tests/llms/bedrock/test_bedrock.py | 10 +-- lib/crewai/tests/llms/google/test_google.py | 10 +-- lib/crewai/tests/llms/openai/test_openai.py | 10 +-- 22 files changed, 370 insertions(+), 209 deletions(-) create mode 100644 lib/crewai/src/crewai/llm/internal/constants.py create mode 100644 lib/crewai/src/crewai/llms/base_llm.py create mode 100644 lib/crewai/src/crewai/llms/constants.py create mode 100644 lib/crewai/src/crewai/llms/hooks/__init__.py create mode 100644 lib/crewai/src/crewai/llms/hooks/base.py create mode 100644 lib/crewai/src/crewai/llms/hooks/transport.py create mode 100644 lib/crewai/src/crewai/llms/internal/__init__.py create mode 100644 lib/crewai/src/crewai/llms/internal/constants.py create mode 100644 lib/crewai/src/crewai/llms/providers/__init__.py diff --git a/lib/crewai/src/crewai/llm/base_llm.py b/lib/crewai/src/crewai/llm/base_llm.py index 9c26d59fcb..3f46ad3277 100644 --- a/lib/crewai/src/crewai/llm/base_llm.py +++ b/lib/crewai/src/crewai/llm/base_llm.py @@ -12,10 +12,11 @@ import logging import os import re -from typing import TYPE_CHECKING, Any, ClassVar, Final +from typing import TYPE_CHECKING, Any, Final +from dotenv import load_dotenv import httpx -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator from crewai.events.event_bus import crewai_event_bus from crewai.events.types.llm_events import ( @@ -42,6 +43,8 @@ from crewai.utilities.types import LLMMessage +load_dotenv() + DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096 DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True _JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL) @@ -65,10 +68,6 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): stop: A list of stop sequences that the LLM should use to stop generation. """ - model_config: ClassVar[ConfigDict] = ConfigDict( - extra="allow", populate_by_name=True - ) - # Core fields model: str = Field(..., description="The model identifier/name") temperature: float | None = Field( @@ -100,7 +99,7 @@ class BaseLLM(BaseModel, ABC, metaclass=LLMMeta): "cached_prompt_tokens": 0, } - @field_validator("api_key", mode="before") + @field_validator("api_key", mode="after") @classmethod def _validate_api_key(cls, value: str | None) -> str | None: """Validate API key for authentication. @@ -137,37 +136,6 @@ def _normalize_stop(cls, value: Any) -> list[str]: return value return [] - @model_validator(mode="before") - @classmethod - def _extract_stop_and_validate(cls, values: dict[str, Any]) -> dict[str, Any]: - """Extract and normalize stop sequences before model initialization. - - Args: - values: Input values dictionary - - Returns: - Processed values dictionary - """ - if not values.get("model"): - raise ValueError("Model name is required and cannot be empty") - - stop = values.get("stop") or values.get("stop_sequences") - if stop is None: - values["stop"] = [] - elif isinstance(stop, str): - values["stop"] = [stop] - elif isinstance(stop, list): - values["stop"] = stop - else: - values["stop"] = [] - - values.pop("stop_sequences", None) - - if "provider" not in values or values["provider"] is None: - values["provider"] = "openai" - - return values - @property def additional_params(self) -> dict[str, Any]: """Get additional parameters stored as extra fields. 
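With the pre-validators gone, API-key resolution happens in the single `_validate_api_key` hook shown earlier, which derives the environment variable name from the class name. A standalone sketch of just that naming rule; the `GEMINI_API_KEY` value is hypothetical:

```python
import os


def api_key_env_var(cls_name: str) -> str:
    # "AnthropicCompletion" -> "ANTHROPIC_API_KEY", "OpenAICompletion" -> "OPENAI_API_KEY"
    return f"{cls_name.replace('Completion', '').upper()}_API_KEY"


os.environ["GEMINI_API_KEY"] = "sk-test"  # hypothetical value for the demo
assert api_key_env_var("GeminiCompletion") == "GEMINI_API_KEY"
assert os.getenv(api_key_env_var("GeminiCompletion")) == "sk-test"
```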
@@ -190,20 +158,6 @@ def additional_params(self, value: dict[str, Any]) -> None: self.__pydantic_extra__ = {} self.__pydantic_extra__.update(value) - def model_post_init(self, __context: Any) -> None: - """Initialize token usage tracking after model initialization. - - Args: - __context: Pydantic context (unused) - """ - self._token_usage = { - "total_tokens": 0, - "prompt_tokens": 0, - "completion_tokens": 0, - "successful_requests": 0, - "cached_prompt_tokens": 0, - } - @abstractmethod def call( self, diff --git a/lib/crewai/src/crewai/llm/internal/constants.py b/lib/crewai/src/crewai/llm/internal/constants.py new file mode 100644 index 0000000000..1d24c4682e --- /dev/null +++ b/lib/crewai/src/crewai/llm/internal/constants.py @@ -0,0 +1,14 @@ +from crewai.llm.constants import SupportedNativeProviders + + +PROVIDER_MAPPING: dict[str, SupportedNativeProviders] = { + "openai": "openai", + "anthropic": "anthropic", + "claude": "anthropic", + "azure": "azure", + "azure_openai": "azure", + "google": "gemini", + "gemini": "gemini", + "bedrock": "bedrock", + "aws": "bedrock", +} diff --git a/lib/crewai/src/crewai/llm/internal/meta.py b/lib/crewai/src/crewai/llm/internal/meta.py index f210ab742c..8bfb74c246 100644 --- a/lib/crewai/src/crewai/llm/internal/meta.py +++ b/lib/crewai/src/crewai/llm/internal/meta.py @@ -9,6 +9,7 @@ import logging from typing import Any, cast +from pydantic import ConfigDict from pydantic._internal._model_construction import ModelMetaclass from crewai.llm.constants import ( @@ -21,6 +22,7 @@ SupportedModels, SupportedNativeProviders, ) +from crewai.llm.internal.constants import PROVIDER_MAPPING class LLMMeta(ModelMetaclass): @@ -30,6 +32,41 @@ class LLMMeta(ModelMetaclass): native provider implementation based on the model parameter. """ + def __new__( + mcs, + name: str, + bases: tuple[type, ...], + namespace: dict[str, Any], + **kwargs: Any, + ) -> type: + """Create new LLM class with proper model_config for custom LLMs. + + Args: + name: Class name + bases: Base classes + namespace: Class namespace + **kwargs: Additional arguments + + Returns: + New class + """ + if name != "BaseLLM" and any( + base.__name__ in ("BaseLLM", "LLM") for base in bases + ): + if "model_config" not in namespace: + namespace["model_config"] = ConfigDict( + extra="allow", populate_by_name=True + ) + elif isinstance(namespace["model_config"], dict): + config_dict = cast( + ConfigDict, cast(object, dict(namespace["model_config"])) + ) + config_dict.setdefault("extra", "allow") + config_dict.setdefault("populate_by_name", True) + namespace["model_config"] = ConfigDict(**config_dict) + + return super().__new__(mcs, name, bases, namespace) + def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805 """Route to appropriate provider implementation at instantiation time. 
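`PROVIDER_MAPPING` now lives in `crewai.llm.internal.constants` and is consumed by `LLMMeta.__call__` below. A simplified mirror of the prefix resolution, for illustration only:

```python
PROVIDER_MAPPING = {
    "openai": "openai",
    "anthropic": "anthropic",
    "claude": "anthropic",
    "azure": "azure",
    "azure_openai": "azure",
    "google": "gemini",
    "gemini": "gemini",
    "bedrock": "bedrock",
    "aws": "bedrock",
}


def resolve(model: str) -> tuple[str | None, str]:
    # Split "provider/model" and map aliases to their canonical provider.
    prefix, _, model_part = model.partition("/")
    canonical = PROVIDER_MAPPING.get(prefix.lower())
    return canonical, model_part or model


assert resolve("google/gemini-2.0-flash-001") == ("gemini", "gemini-2.0-flash-001")
assert resolve("gpt-4o") == (None, "gpt-4o")
```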
@@ -57,7 +94,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805 if args and not kwargs.get("model"): kwargs["model"] = cast(SupportedModels, args[0]) - args = args[1:] + _ = args[1:] explicit_provider = cast(SupportedNativeProviders, kwargs.get("provider")) if explicit_provider: @@ -70,19 +107,7 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any: # noqa: N805 model.partition("/"), ) - provider_mapping: dict[str, SupportedNativeProviders] = { - "openai": "openai", - "anthropic": "anthropic", - "claude": "anthropic", - "azure": "azure", - "azure_openai": "azure", - "google": "gemini", - "gemini": "gemini", - "bedrock": "bedrock", - "aws": "bedrock", - } - - canonical_provider = provider_mapping.get(prefix.lower()) + canonical_provider = PROVIDER_MAPPING.get(prefix.lower()) if canonical_provider and cls._validate_model_in_constants( model_part, canonical_provider diff --git a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py index b0e4b5c87a..6303f4e4ce 100644 --- a/lib/crewai/src/crewai/llm/providers/anthropic/completion.py +++ b/lib/crewai/src/crewai/llm/providers/anthropic/completion.py @@ -4,6 +4,7 @@ import logging from typing import TYPE_CHECKING, Any, cast +from dotenv import load_dotenv import httpx from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Self @@ -21,13 +22,14 @@ if TYPE_CHECKING: + from anthropic.types import Message + from crewai.agent.core import Agent from crewai.task import Task try: from anthropic import Anthropic - from anthropic.types import Message from anthropic.types.tool_use_block import ToolUseBlock except ImportError: raise ImportError( @@ -35,6 +37,9 @@ ) from None +load_dotenv() + + class AnthropicCompletion(BaseLLM): """Anthropic native completion implementation. 
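At the call site this routing is what makes `LLM(...)` return a native provider class, as the updated tests also exercise. A usage sketch, assuming the relevant API key (e.g. `GOOGLE_API_KEY`) is present in the environment:

```python
from crewai import LLM
from crewai.llm.providers.gemini.completion import GeminiCompletion

# LLMMeta.__call__ inspects the "google/" prefix at instantiation time and
# returns the Gemini implementation instead of a generic LLM instance.
llm = LLM(model="google/gemini-2.0-flash-001")
assert isinstance(llm, GeminiCompletion)
```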
@@ -66,11 +71,9 @@ class AnthropicCompletion(BaseLLM): top_p: float | None = Field(default=None, description="Nucleus sampling parameter") stream: bool = Field(default=False, description="Enable streaming responses") client_params: dict[str, Any] | None = Field( - default=None, description="Additional Anthropic client parameters" - ) - client: Anthropic = Field( - default_factory=Anthropic, exclude=True, description="Anthropic client instance" + default_factory=dict, description="Additional Anthropic client parameters" ) + _client: Anthropic = PrivateAttr(default=None) # type: ignore[assignment] _is_claude_3: bool = PrivateAttr(default=False) _supports_tools: bool = PrivateAttr(default=False) @@ -78,7 +81,7 @@ class AnthropicCompletion(BaseLLM): @model_validator(mode="after") def setup_client(self) -> Self: """Initialize the Anthropic client and model-specific settings.""" - self.client = Anthropic(**self._get_client_params()) + self._client = Anthropic(**self._get_client_params()) self._is_claude_3 = "claude-3" in self.model.lower() self._supports_tools = self._is_claude_3 @@ -98,9 +101,6 @@ def supports_tools(self) -> bool: def _get_client_params(self) -> dict[str, Any]: """Get client parameters.""" - if self.api_key is None: - raise ValueError("ANTHROPIC_API_KEY is required") - client_params = { "api_key": self.api_key, "base_url": self.base_url, @@ -330,7 +330,7 @@ def _handle_completion( params["tool_choice"] = {"type": "tool", "name": "structured_output"} try: - response: Message = self.client.messages.create(**params) + response: Message = self._client.messages.create(**params) except Exception as e: if is_context_length_exceeded(e): @@ -424,7 +424,7 @@ def _handle_streaming_completion( stream_params = {k: v for k, v in params.items() if k != "stream"} # Make streaming API call - with self.client.messages.stream(**stream_params) as stream: + with self._client.messages.stream(**stream_params) as stream: for event in stream: if hasattr(event, "delta") and hasattr(event.delta, "text"): text_delta = event.delta.text @@ -552,7 +552,7 @@ def _handle_tool_use_conversation( try: # Send tool results back to Claude for final response - final_response: Message = self.client.messages.create(**follow_up_params) + final_response: Message = self._client.messages.create(**follow_up_params) # Track token usage for follow-up call follow_up_usage = self._extract_anthropic_token_usage(final_response) diff --git a/lib/crewai/src/crewai/llm/providers/azure/completion.py b/lib/crewai/src/crewai/llm/providers/azure/completion.py index b30e9a2ba3..9963fee6f2 100644 --- a/lib/crewai/src/crewai/llm/providers/azure/completion.py +++ b/lib/crewai/src/crewai/llm/providers/azure/completion.py @@ -5,6 +5,7 @@ import os from typing import TYPE_CHECKING, Any +from dotenv import load_dotenv from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Self @@ -48,6 +49,9 @@ ) from None +load_dotenv() + + class AzureCompletion(BaseLLM): """Azure AI Inference native completion implementation. 
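Because the SDK handle is now the private `_client`, test doubles patch that attribute instead of the removed public `client` field, as the test updates later in this patch show. A minimal sketch, assuming `ANTHROPIC_API_KEY` is set so `setup_client()` can build the client:

```python
from unittest.mock import MagicMock, patch

from crewai import LLM

llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")

# Patch the private client rather than `llm.client`, which no longer exists.
with patch.object(llm._client.messages, "create") as mock_create:
    mock_create.return_value = MagicMock(
        content=[MagicMock(type="text", text="mocked")],
        usage=MagicMock(input_tokens=5, output_tokens=3),
        stop_reason="end_turn",
    )
    # llm.call("Hello") would now hit the mock instead of the Anthropic API.
```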
@@ -68,12 +72,14 @@ class AzureCompletion(BaseLLM): interceptor: HTTP interceptor (not yet supported for Azure) """ - endpoint: str | None = Field( - default=None, + endpoint: str = Field( # type: ignore[assignment] + default_factory=lambda: os.getenv("AZURE_ENDPOINT") + or os.getenv("AZURE_OPENAI_ENDPOINT") + or os.getenv("AZURE_API_BASE"), description="Azure endpoint URL (defaults to AZURE_ENDPOINT env var)", ) api_version: str = Field( - default="2024-06-01", + default_factory=lambda: os.getenv("AZURE_API_VERSION", "2024-06-01"), description="Azure API version (defaults to AZURE_API_VERSION env var or 2024-06-01)", ) timeout: float | None = Field( @@ -82,18 +88,16 @@ class AzureCompletion(BaseLLM): max_retries: int = Field(default=2, description="Maximum number of retries") top_p: float | None = Field(default=None, description="Nucleus sampling parameter") frequency_penalty: float | None = Field( - default=None, description="Frequency penalty (-2 to 2)" + default=None, le=2.0, ge=-2.0, description="Frequency penalty (-2 to 2)" ) presence_penalty: float | None = Field( - default=None, description="Presence penalty (-2 to 2)" + default=None, le=2.0, ge=-2.0, description="Presence penalty (-2 to 2)" ) max_tokens: int | None = Field( default=None, description="Maximum tokens in response" ) stream: bool = Field(default=False, description="Enable streaming responses") - _client: ChatCompletionsClient = PrivateAttr( - default_factory=ChatCompletionsClient, # type: ignore[arg-type] - ) + _client: ChatCompletionsClient = PrivateAttr(default=None) # type: ignore[assignment] _is_openai_model: bool = PrivateAttr(default=False) _is_azure_openai_endpoint: bool = PrivateAttr(default=False) @@ -107,26 +111,13 @@ def setup_client(self) -> Self: "Interceptors are currently supported for OpenAI and Anthropic providers only." ) - if self.endpoint is None: - self.endpoint = ( - os.getenv("AZURE_ENDPOINT") - or os.getenv("AZURE_OPENAI_ENDPOINT") - or os.getenv("AZURE_API_BASE") - ) - - if self.api_version == "2024-06-01": - env_version = os.getenv("AZURE_API_VERSION") - if env_version: - self.api_version = env_version + if not self.api_key: + self.api_key = os.getenv("AZURE_API_KEY") if not self.api_key: raise ValueError( "Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter." ) - if not self.endpoint: - raise ValueError( - "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter." - ) self.endpoint = self._validate_and_fix_endpoint(self.endpoint, self.model) @@ -138,7 +129,7 @@ def setup_client(self) -> Self: if self.api_version: client_kwargs["api_version"] = self.api_version - self.client = ChatCompletionsClient(**client_kwargs) + self._client = ChatCompletionsClient(**client_kwargs) self._is_openai_model = any( prefix in self.model.lower() for prefix in ["gpt-", "o1-", "text-"] @@ -160,7 +151,7 @@ def is_azure_openai_endpoint(self) -> bool: """Check if endpoint is an Azure OpenAI endpoint.""" return self._is_azure_openai_endpoint - def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str: + def _validate_and_fix_endpoint(self, endpoint: str | None, model: str) -> str: """Validate and fix Azure endpoint URL format. 
Azure OpenAI endpoints should be in the format: @@ -172,7 +163,15 @@ def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str: Returns: Validated and potentially corrected endpoint URL + + Raises: + ValueError: If endpoint is None or empty """ + if not endpoint: + raise ValueError( + "Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter." + ) + if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint: endpoint = endpoint.rstrip("/") @@ -388,7 +387,7 @@ def _handle_completion( """Handle non-streaming chat completion.""" # Make API call try: - response: ChatCompletions = self.client.complete(**params) + response: ChatCompletions = self._client.complete(**params) if not response.choices: raise ValueError("No choices returned from Azure API") @@ -486,7 +485,7 @@ def _handle_streaming_completion( tool_calls = {} # Make streaming API call - for update in self.client.complete(**params): + for update in self._client.complete(**params): if isinstance(update, StreamingChatCompletionsUpdate): if update.choices: choice = update.choices[0] diff --git a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py index 282ea840d6..62afae1035 100644 --- a/lib/crewai/src/crewai/llm/providers/bedrock/completion.py +++ b/lib/crewai/src/crewai/llm/providers/bedrock/completion.py @@ -5,6 +5,8 @@ import os from typing import TYPE_CHECKING, Any, TypedDict, cast +from dotenv import load_dotenv +from mypy_boto3_bedrock_runtime.client import BedrockRuntimeClient from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Required, Self @@ -75,6 +77,9 @@ class EnhancedInferenceConfigurationTypeDef(TypedDict, total=False): topK: int +load_dotenv() + + class ToolInputSchema(TypedDict): """Type definition for tool input schema in Converse API.""" @@ -161,16 +166,22 @@ class BedrockCompletion(BaseLLM): interceptor: HTTP interceptor (not yet supported for Bedrock) """ - aws_access_key_id: str | None = Field( - default=None, description="AWS access key (defaults to environment variable)" + aws_access_key_id: str = Field( # type: ignore[assignment] + default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"), + description="AWS access key (defaults to environment variable)", ) - aws_secret_access_key: str | None = Field( - default=None, description="AWS secret key (defaults to environment variable)" + aws_secret_access_key: str = Field( # type: ignore[assignment] + default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"), + description="AWS secret key (defaults to environment variable)", ) - aws_session_token: str | None = Field( - default=None, description="AWS session token for temporary credentials" + aws_session_token: str = Field( # type: ignore[assignment] + default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"), + description="AWS session token for temporary credentials", + ) + region_name: str = Field( + default_factory=lambda: os.getenv("AWS_REGION", "us-east-1"), + description="AWS region name", ) - region_name: str = Field(default="us-east-1", description="AWS region name") max_tokens: int | None = Field( default=None, description="Maximum tokens to generate" ) @@ -181,17 +192,18 @@ class BedrockCompletion(BaseLLM): stream: bool = Field( default=False, description="Whether to use streaming responses" ) - guardrail_config: dict[str, Any] | None = Field( - default=None, description="Guardrail configuration for content filtering" + 
guardrail_config: dict[str, Any] = Field( + default_factory=dict, + description="Guardrail configuration for content filtering", ) - additional_model_request_fields: dict[str, Any] | None = Field( - default=None, description="Model-specific request parameters" + additional_model_request_fields: dict[str, Any] = Field( + default_factory=dict, description="Model-specific request parameters" ) - additional_model_response_field_paths: list[str] | None = Field( - default=None, description="Custom response field paths" + additional_model_response_field_paths: list[str] = Field( + default_factory=list, description="Custom response field paths" ) - client: Any = Field( - default=None, exclude=True, description="Bedrock client instance" + _client: BedrockRuntimeClient = PrivateAttr( # type: ignore[assignment] + default_factory=lambda: Session().client, ) _is_claude_model: bool = PrivateAttr(default=False) @@ -209,10 +221,9 @@ def setup_client(self) -> Self: ) session = Session( - aws_access_key_id=self.aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID"), - aws_secret_access_key=self.aws_secret_access_key - or os.getenv("AWS_SECRET_ACCESS_KEY"), - aws_session_token=self.aws_session_token or os.getenv("AWS_SESSION_TOKEN"), + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + aws_session_token=self.aws_session_token, region_name=self.region_name, ) @@ -225,7 +236,7 @@ def setup_client(self) -> Self: tcp_keepalive=True, ) - self.client = session.client("bedrock-runtime", config=config) + self._client = session.client("bedrock-runtime", config=config) self._is_claude_model = "claude" in self.model.lower() self._supports_tools = True @@ -365,7 +376,7 @@ def _handle_converse( raise ValueError(f"Invalid message format at index {i}") # Call Bedrock Converse API with proper error handling - response = self.client.converse( + response = self._client.converse( modelId=self.model_id, messages=cast( "Sequence[MessageTypeDef | MessageOutputTypeDef]", @@ -536,13 +547,13 @@ def _handle_streaming_converse( tool_use_id = None try: - response = self.client.converse_stream( + response = self._client.converse_stream( modelId=self.model_id, messages=cast( "Sequence[MessageTypeDef | MessageOutputTypeDef]", cast(object, messages), ), - **body, + **body, # type: ignore[arg-type] ) stream = response.get("stream") diff --git a/lib/crewai/src/crewai/llm/providers/gemini/completion.py b/lib/crewai/src/crewai/llm/providers/gemini/completion.py index 5adfa3863f..38321d0539 100644 --- a/lib/crewai/src/crewai/llm/providers/gemini/completion.py +++ b/lib/crewai/src/crewai/llm/providers/gemini/completion.py @@ -1,8 +1,11 @@ +from __future__ import annotations + import logging import os -from typing import TYPE_CHECKING, Any, ClassVar, cast +from typing import TYPE_CHECKING, Any, cast -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator +from dotenv import load_dotenv +from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing_extensions import Self from crewai.events.types.llm_events import LLMCallType @@ -31,6 +34,9 @@ ) from None +load_dotenv() + + class GeminiCompletion(BaseLLM): """Google Gemini native completion implementation. 
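The Bedrock client construction mirrors the field defaults above: credentials fall back to the environment and the runtime client is kept privately. A standalone sketch of the equivalent boto3 setup, assuming `boto3` is installed and using the same `us-east-1` region fallback:

```python
import os

from boto3.session import Session
from botocore.config import Config

session = Session(
    aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
    aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
    aws_session_token=os.getenv("AWS_SESSION_TOKEN"),
    region_name=os.getenv("AWS_REGION", "us-east-1"),
)

# Keep-alive mirrors the Config used in setup_client(); no request is made
# until the client is actually invoked, so this runs without live credentials.
client = session.client("bedrock-runtime", config=Config(tcp_keepalive=True))
```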
@@ -50,15 +56,12 @@ class GeminiCompletion(BaseLLM): interceptor: HTTP interceptor (not yet supported for Gemini) """ - model_config: ClassVar[ConfigDict] = ConfigDict( - ignored_types=(property,), arbitrary_types_allowed=True - ) - project: str | None = Field( - default=None, description="Google Cloud project ID (for Vertex AI)" + default_factory=lambda: os.getenv("GOOGLE_CLOUD_PROJECT"), + description="Google Cloud project ID (for Vertex AI)", ) location: str = Field( - default="us-central1", + default_factory=lambda: os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1"), description="Google Cloud location (for Vertex AI, defaults to 'us-central1')", ) top_p: float | None = Field(default=None, description="Nucleus sampling parameter") @@ -74,9 +77,7 @@ class GeminiCompletion(BaseLLM): default_factory=dict, description="Additional parameters for Google Gen AI Client constructor", ) - client: Any = Field( - default=None, exclude=True, description="Gemini client instance" - ) + _client: Any = PrivateAttr(default=None) _is_gemini_2: bool = PrivateAttr(default=False) _is_gemini_1_5: bool = PrivateAttr(default=False) @@ -94,17 +95,9 @@ def setup_client(self) -> Self: if self.api_key is None: self.api_key = os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY") - if self.project is None: - self.project = os.getenv("GOOGLE_CLOUD_PROJECT") - - if self.location == "us-central1": - env_location = os.getenv("GOOGLE_CLOUD_LOCATION") - if env_location: - self.location = env_location - use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true" - self.client = self._initialize_client(use_vertexai) + self._client = self._initialize_client(use_vertexai) self._is_gemini_2 = "gemini-2" in self.model.lower() self._is_gemini_1_5 = "gemini-1.5" in self.model.lower() @@ -176,9 +169,9 @@ def _get_client_params(self) -> dict[str, Any]: params = {} if ( - hasattr(self, "client") - and hasattr(self.client, "vertexai") - and self.client.vertexai + hasattr(self, "_client") + and hasattr(self._client, "vertexai") + and self._client.vertexai ): params.update( { @@ -400,7 +393,7 @@ def _handle_completion( # type: ignore[no-any-unimported] } try: - response = self.client.models.generate_content(**api_params) + response = self._client.models.generate_content(**api_params) usage = self._extract_token_usage(response) except Exception as e: @@ -468,7 +461,7 @@ def _handle_streaming_completion( # type: ignore[no-any-unimported] "config": config, } - for chunk in self.client.models.generate_content_stream(**api_params): + for chunk in self._client.models.generate_content_stream(**api_params): if hasattr(chunk, "text") and chunk.text: full_response += chunk.text self._emit_stream_chunk_event( diff --git a/lib/crewai/src/crewai/llm/providers/openai/completion.py b/lib/crewai/src/crewai/llm/providers/openai/completion.py index 8a0143da66..7d5dd1ec77 100644 --- a/lib/crewai/src/crewai/llm/providers/openai/completion.py +++ b/lib/crewai/src/crewai/llm/providers/openai/completion.py @@ -6,6 +6,7 @@ import os from typing import TYPE_CHECKING, Any +from dotenv import load_dotenv import httpx from openai import APIConnectionError, NotFoundError, OpenAI from openai.types.chat import ChatCompletion, ChatCompletionChunk @@ -32,6 +33,9 @@ from crewai.tools.base_tool import BaseTool +load_dotenv() + + class OpenAICompletion(BaseLLM): """OpenAI native completion implementation. 
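For Gemini, the Vertex AI switch stays an environment flag while `project` and `location` now default straight from the environment instead of being patched in after construction. A small sketch of that resolution with hypothetical values:

```python
import os

# Hypothetical values for illustration only.
os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "true"
os.environ["GOOGLE_CLOUD_PROJECT"] = "my-gcp-project"
os.environ["GOOGLE_CLOUD_LOCATION"] = "europe-west4"

use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
project = os.getenv("GOOGLE_CLOUD_PROJECT")
location = os.getenv("GOOGLE_CLOUD_LOCATION", "us-central1")

assert (use_vertexai, project, location) == (True, "my-gcp-project", "europe-west4")
```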
@@ -40,43 +44,51 @@ class OpenAICompletion(BaseLLM): """ # Client configuration fields - organization: str | None = Field(None, description="OpenAI organization ID") - project: str | None = Field(None, description="OpenAI project ID") - max_retries: int = Field(2, description="Maximum number of retries") - default_headers: dict[str, str] | None = Field( - None, description="Default headers for requests" + organization: str | None = Field(default=None, description="OpenAI organization ID") + project: str | None = Field(default=None, description="OpenAI project ID") + max_retries: int = Field(default=2, description="Maximum number of retries") + default_headers: dict[str, str] = Field( + default_factory=dict, description="Default headers for requests" ) - default_query: dict[str, Any] | None = Field( - None, description="Default query parameters" + default_query: dict[str, Any] = Field( + default_factory=dict, description="Default query parameters" ) - client_params: dict[str, Any] | None = Field( - None, description="Additional client parameters" + client_params: dict[str, Any] = Field( + default_factory=dict, description="Additional client parameters" + ) + timeout: float | None = Field(default=None, description="Request timeout") + api_base: str | None = Field( + default=None, description="API base URL", deprecated=True ) - timeout: float | None = Field(None, description="Request timeout") - api_base: str | None = Field(None, description="API base URL (deprecated)") # Completion parameters - top_p: float | None = Field(None, description="Top-p sampling parameter") - frequency_penalty: float | None = Field(None, description="Frequency penalty") - presence_penalty: float | None = Field(None, description="Presence penalty") - max_tokens: int | None = Field(None, description="Maximum tokens") + top_p: float | None = Field(default=None, description="Top-p sampling parameter") + frequency_penalty: float | None = Field( + default=None, description="Frequency penalty" + ) + presence_penalty: float | None = Field(default=None, description="Presence penalty") + max_tokens: int | None = Field(default=None, description="Maximum tokens") max_completion_tokens: int | None = Field( None, description="Maximum completion tokens" ) - seed: int | None = Field(None, description="Random seed") - stream: bool = Field(False, description="Enable streaming") + seed: int | None = Field(default=None, description="Random seed") + stream: bool = Field(default=False, description="Enable streaming") response_format: dict[str, Any] | type[BaseModel] | None = Field( - None, description="Response format" + default=None, description="Response format" ) - logprobs: bool | None = Field(None, description="Return log probabilities") + logprobs: bool | None = Field(default=None, description="Return log probabilities") top_logprobs: int | None = Field( - None, description="Number of top log probabilities" + default=None, description="Number of top log probabilities" + ) + reasoning_effort: str | None = Field( + default=None, description="Reasoning effort level" ) - reasoning_effort: str | None = Field(None, description="Reasoning effort level") - _client: OpenAI = PrivateAttr(default_factory=OpenAI) - is_o1_model: bool = Field(False, description="Whether this is an O1 model") - is_gpt4_model: bool = Field(False, description="Whether this is a GPT-4 model") + _client: OpenAI = PrivateAttr(default=None) # type: ignore[assignment] + is_o1_model: bool = Field(default=False, description="Whether this is an O1 model") + is_gpt4_model: 
bool = Field( + default=False, description="Whether this is a GPT-4 model" + ) @model_validator(mode="after") def setup_client(self) -> Self: @@ -97,10 +109,6 @@ def setup_client(self) -> Self: def _get_client_params(self) -> dict[str, Any]: """Get OpenAI client parameters.""" - - if self.api_key is None: - raise ValueError("OPENAI_API_KEY is required") - base_params = { "api_key": self.api_key, "organization": self.organization, diff --git a/lib/crewai/src/crewai/llms/__init__.py b/lib/crewai/src/crewai/llms/__init__.py index fda1e6a3be..9ffac98a55 100644 --- a/lib/crewai/src/crewai/llms/__init__.py +++ b/lib/crewai/src/crewai/llms/__init__.py @@ -1 +1,38 @@ -"""LLM implementations for crewAI.""" +"""LLM implementations for crewAI. + +.. deprecated:: 1.4.0 + The `crewai.llms` package is deprecated. Use `crewai.llm` instead. + + This package was reorganized from `crewai.llms.*` to `crewai.llm.*`. + All submodules are redirected to their new locations in `crewai.llm.*`. + + Migration guide: + Old: from crewai.llms.base_llm import BaseLLM + New: from crewai.llm.base_llm import BaseLLM + + Old: from crewai.llms.hooks.base import BaseInterceptor + New: from crewai.llm.hooks.base import BaseInterceptor + + Old: from crewai.llms.constants import OPENAI_MODELS + New: from crewai.llm.constants import OPENAI_MODELS + + Or use top-level imports: + from crewai import LLM, BaseLLM +""" + +import warnings + +from crewai.llm import LLM +from crewai.llm.base_llm import BaseLLM + + +# Issue deprecation warning when this module is imported +warnings.warn( + "The 'crewai.llms' package is deprecated and will be removed in a future version. " + "Please use 'crewai.llm' (singular) instead. " + "All submodules have been reorganized from 'crewai.llms.*' to 'crewai.llm.*'.", + DeprecationWarning, + stacklevel=2, +) + +__all__ = ["LLM", "BaseLLM"] diff --git a/lib/crewai/src/crewai/llms/base_llm.py b/lib/crewai/src/crewai/llms/base_llm.py new file mode 100644 index 0000000000..0eef37033e --- /dev/null +++ b/lib/crewai/src/crewai/llms/base_llm.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.base_llm instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.base_llm is deprecated. Use crewai.llm.base_llm instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.base_llm import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/llms/constants.py b/lib/crewai/src/crewai/llms/constants.py new file mode 100644 index 0000000000..8dc310b0ad --- /dev/null +++ b/lib/crewai/src/crewai/llms/constants.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.constants instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.constants is deprecated. Use crewai.llm.constants instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.constants import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/llms/hooks/__init__.py b/lib/crewai/src/crewai/llms/hooks/__init__.py new file mode 100644 index 0000000000..c63684cd7a --- /dev/null +++ b/lib/crewai/src/crewai/llms/hooks/__init__.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.hooks instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.hooks is deprecated. 
Use crewai.llm.hooks instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.hooks import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/llms/hooks/base.py b/lib/crewai/src/crewai/llms/hooks/base.py new file mode 100644 index 0000000000..7149e70f71 --- /dev/null +++ b/lib/crewai/src/crewai/llms/hooks/base.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.hooks.base instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.hooks.base is deprecated. Use crewai.llm.hooks.base instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.hooks.base import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/llms/hooks/transport.py b/lib/crewai/src/crewai/llms/hooks/transport.py new file mode 100644 index 0000000000..8ec3bc65ec --- /dev/null +++ b/lib/crewai/src/crewai/llms/hooks/transport.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.hooks.transport instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.hooks.transport is deprecated. Use crewai.llm.hooks.transport instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.hooks.transport import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/llms/internal/__init__.py b/lib/crewai/src/crewai/llms/internal/__init__.py new file mode 100644 index 0000000000..cca43464b1 --- /dev/null +++ b/lib/crewai/src/crewai/llms/internal/__init__.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.internal instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.internal is deprecated. Use crewai.llm.internal instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.internal import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/llms/internal/constants.py b/lib/crewai/src/crewai/llms/internal/constants.py new file mode 100644 index 0000000000..5fe8c439cb --- /dev/null +++ b/lib/crewai/src/crewai/llms/internal/constants.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.internal.constants instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.internal.constants is deprecated. Use crewai.llm.internal.constants instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.internal.constants import * # noqa: E402, F403 diff --git a/lib/crewai/src/crewai/llms/providers/__init__.py b/lib/crewai/src/crewai/llms/providers/__init__.py new file mode 100644 index 0000000000..95bc6d4482 --- /dev/null +++ b/lib/crewai/src/crewai/llms/providers/__init__.py @@ -0,0 +1,15 @@ +"""Deprecated: Use crewai.llm.providers instead. + +.. deprecated:: 1.4.0 +""" + +import warnings + + +warnings.warn( + "crewai.llms.providers is deprecated. 
Use crewai.llm.providers instead.", + DeprecationWarning, + stacklevel=2, +) + +from crewai.llm.providers import * # noqa: E402, F403 diff --git a/lib/crewai/tests/llms/anthropic/test_anthropic.py b/lib/crewai/tests/llms/anthropic/test_anthropic.py index 5a91b2e1e3..2cdbe1d496 100644 --- a/lib/crewai/tests/llms/anthropic/test_anthropic.py +++ b/lib/crewai/tests/llms/anthropic/test_anthropic.py @@ -60,7 +60,7 @@ def mock_weather_tool(location: str) -> str: available_functions = {"get_weather": mock_weather_tool} # Mock the Anthropic client responses - with patch.object(completion.client.messages, 'create') as mock_create: + with patch.object(completion._client.messages, 'create') as mock_create: # Mock initial response with tool use - need to properly mock ToolUseBlock mock_tool_use = Mock(spec=ToolUseBlock) mock_tool_use.id = "tool_123" @@ -199,8 +199,8 @@ def test_anthropic_specific_parameters(): assert isinstance(llm, AnthropicCompletion) assert llm.stop == ["Human:", "Assistant:"] assert llm.stream == True - assert llm.client.max_retries == 5 - assert llm.client.timeout == 60 + assert llm._client.max_retries == 5 + assert llm._client.timeout == 60 def test_anthropic_completion_call(): @@ -637,8 +637,8 @@ def test_anthropic_environment_variable_api_key(): with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-anthropic-key"}): llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") - assert llm.client is not None - assert hasattr(llm.client, 'messages') + assert llm._client is not None + assert hasattr(llm._client, 'messages') def test_anthropic_token_usage_tracking(): @@ -648,7 +648,7 @@ def test_anthropic_token_usage_tracking(): llm = LLM(model="anthropic/claude-3-5-sonnet-20241022") # Mock the Anthropic response with usage information - with patch.object(llm.client.messages, 'create') as mock_create: + with patch.object(llm._client.messages, 'create') as mock_create: mock_response = MagicMock() mock_response.content = [MagicMock(text="test response")] mock_response.usage = MagicMock(input_tokens=50, output_tokens=25) diff --git a/lib/crewai/tests/llms/azure/test_azure.py b/lib/crewai/tests/llms/azure/test_azure.py index 6cc4eb463a..055c9d4994 100644 --- a/lib/crewai/tests/llms/azure/test_azure.py +++ b/lib/crewai/tests/llms/azure/test_azure.py @@ -64,7 +64,7 @@ def mock_weather_tool(location: str) -> str: available_functions = {"get_weather": mock_weather_tool} # Mock the Azure client responses - with patch.object(completion.client, 'complete') as mock_complete: + with patch.object(completion._client, 'complete') as mock_complete: # Mock tool call in response with proper type mock_tool_call = MagicMock(spec=ChatCompletionsToolCall) mock_tool_call.function.name = "get_weather" @@ -602,7 +602,7 @@ def test_azure_environment_variable_endpoint(): }): llm = LLM(model="azure/gpt-4") - assert llm.client is not None + assert llm._client is not None assert llm.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4" @@ -613,7 +613,7 @@ def test_azure_token_usage_tracking(): llm = LLM(model="azure/gpt-4") # Mock the Azure response with usage information - with patch.object(llm.client, 'complete') as mock_complete: + with patch.object(llm._client, 'complete') as mock_complete: mock_message = MagicMock() mock_message.content = "test response" mock_message.tool_calls = None @@ -651,7 +651,7 @@ def test_azure_http_error_handling(): llm = LLM(model="azure/gpt-4") # Mock an HTTP error - with patch.object(llm.client, 'complete') as mock_complete: + with patch.object(llm._client, 
'complete') as mock_complete: mock_complete.side_effect = HttpResponseError(message="Rate limit exceeded", response=MagicMock(status_code=429)) with pytest.raises(HttpResponseError): @@ -668,7 +668,7 @@ def test_azure_streaming_completion(): llm = LLM(model="azure/gpt-4", stream=True) # Mock streaming response - with patch.object(llm.client, 'complete') as mock_complete: + with patch.object(llm._client, 'complete') as mock_complete: # Create mock streaming updates with proper type mock_updates = [] for chunk in ["Hello", " ", "world", "!"]: @@ -891,7 +891,7 @@ def test_azure_improved_error_messages(): llm = LLM(model="azure/gpt-4") - with patch.object(llm.client, 'complete') as mock_complete: + with patch.object(llm._client, 'complete') as mock_complete: error_401 = HttpResponseError(message="Unauthorized") error_401.status_code = 401 mock_complete.side_effect = error_401 diff --git a/lib/crewai/tests/llms/bedrock/test_bedrock.py b/lib/crewai/tests/llms/bedrock/test_bedrock.py index b3c12cdc2f..130e498909 100644 --- a/lib/crewai/tests/llms/bedrock/test_bedrock.py +++ b/lib/crewai/tests/llms/bedrock/test_bedrock.py @@ -579,7 +579,7 @@ def test_bedrock_token_usage_tracking(): llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") # Mock the Bedrock response with usage information - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: mock_response = { 'output': { 'message': { @@ -624,7 +624,7 @@ def mock_weather_tool(location: str) -> str: available_functions = {"get_weather": mock_weather_tool} # Mock the Bedrock client responses - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: # First response: tool use request tool_use_response = { 'output': { @@ -710,7 +710,7 @@ def test_bedrock_client_error_handling(): llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0") # Test ValidationException - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: error_response = { 'Error': { 'Code': 'ValidationException', @@ -724,7 +724,7 @@ def test_bedrock_client_error_handling(): assert "validation" in str(exc_info.value).lower() # Test ThrottlingException - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: error_response = { 'Error': { 'Code': 'ThrottlingException', @@ -762,7 +762,7 @@ def test_bedrock_stop_sequences_sent_to_api(): llm.stop = ["\nObservation:", "\nThought:"] # Patch the API call to capture parameters without making real call - with patch.object(llm.client, 'converse') as mock_converse: + with patch.object(llm._client, 'converse') as mock_converse: mock_response = { 'output': { 'message': { diff --git a/lib/crewai/tests/llms/google/test_google.py b/lib/crewai/tests/llms/google/test_google.py index 11af3e83bb..ffcc2a978e 100644 --- a/lib/crewai/tests/llms/google/test_google.py +++ b/lib/crewai/tests/llms/google/test_google.py @@ -59,7 +59,7 @@ def mock_weather_tool(location: str) -> str: available_functions = {"get_weather": mock_weather_tool} # Mock the Google Gemini client responses - with patch.object(completion.client.models, 'generate_content') as mock_generate: + with patch.object(completion._client.models, 'generate_content') as mock_generate: # Mock function call in response mock_function_call = Mock() mock_function_call.name = "get_weather" @@ -614,8 +614,8 @@ def 
test_gemini_environment_variable_api_key(): with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}): llm = LLM(model="google/gemini-2.0-flash-001") - assert llm.client is not None - assert hasattr(llm.client, 'models') + assert llm._client is not None + assert hasattr(llm._client, 'models') assert llm.api_key == "test-google-key" @@ -626,7 +626,7 @@ def test_gemini_token_usage_tracking(): llm = LLM(model="google/gemini-2.0-flash-001") # Mock the Gemini response with usage information - with patch.object(llm.client.models, 'generate_content') as mock_generate: + with patch.object(llm._client.models, 'generate_content') as mock_generate: mock_response = MagicMock() mock_response.text = "test response" mock_response.candidates = [] @@ -675,7 +675,7 @@ def test_gemini_stop_sequences_sent_to_api(): llm.stop = ["\nObservation:", "\nThought:"] # Patch the API call to capture parameters without making real call - with patch.object(llm.client.models, 'generate_content') as mock_generate: + with patch.object(llm._client.models, 'generate_content') as mock_generate: mock_response = MagicMock() mock_response.text = "Hello" mock_response.candidates = [] diff --git a/lib/crewai/tests/llms/openai/test_openai.py b/lib/crewai/tests/llms/openai/test_openai.py index aee167ab57..d115eb6ebf 100644 --- a/lib/crewai/tests/llms/openai/test_openai.py +++ b/lib/crewai/tests/llms/openai/test_openai.py @@ -369,11 +369,11 @@ def test_openai_client_setup_with_extra_arguments(): assert llm.top_p == 0.5 # Check that client parameters are properly configured - assert llm.client.max_retries == 3 - assert llm.client.timeout == 30 + assert llm._client.max_retries == 3 + assert llm._client.timeout == 30 # Test that parameters are properly used in API calls - with patch.object(llm.client.chat.completions, 'create') as mock_create: + with patch.object(llm._client.chat.completions, 'create') as mock_create: mock_create.return_value = MagicMock( choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))], usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30) @@ -394,7 +394,7 @@ def test_extra_arguments_are_passed_to_openai_completion(): """ llm = LLM(model="gpt-4o", temperature=0.7, max_tokens=1000, top_p=0.5, max_retries=3) - with patch.object(llm.client.chat.completions, 'create') as mock_create: + with patch.object(llm._client.chat.completions, 'create') as mock_create: mock_create.return_value = MagicMock( choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))], usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30) @@ -501,7 +501,7 @@ class TestResponse(BaseModel): llm = LLM(model="openai/gpt-4o", stream=True) - with patch.object(llm.client.chat.completions, "create") as mock_create: + with patch.object(llm._client.chat.completions, "create") as mock_create: mock_chunk1 = MagicMock() mock_chunk1.choices = [ MagicMock(delta=MagicMock(content='{"answer": "test", ', tool_calls=None))