Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/crewai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,7 @@ async def kickoff_async(
i18n=self.i18n,
original_agent=self,
guardrail=self.guardrail,
guardrail_max_retries=self.guardrail_max_retries,
)

return await lite_agent.kickoff_async(messages)
154 changes: 114 additions & 40 deletions src/crewai/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import threading
import uuid
import warnings
from collections.abc import Callable
from collections.abc import Callable, Sequence
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
Expand Down Expand Up @@ -152,6 +152,15 @@ class Task(BaseModel):
default=None,
description="Function or string description of a guardrail to validate task output before proceeding to next task",
)
guardrails: (
Sequence[Callable[[TaskOutput], tuple[bool, Any]] | str]
| Callable[[TaskOutput], tuple[bool, Any]]
| str
| None
) = Field(
default=None,
description="List of guardrails to validate task output before proceeding to next task. Also supports a single guardrail function or string description of a guardrail to validate task output before proceeding to next task",
)
max_retries: int | None = Field(
default=None,
description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0",
Expand Down Expand Up @@ -268,6 +277,44 @@ def ensure_guardrail_is_callable(self) -> "Task":

return self

@model_validator(mode="after")
def ensure_guardrails_is_list_of_callables(self) -> "Task":
    """Normalize ``guardrails`` into a list of callables on ``_guardrails``.

    ``guardrails`` may be ``None``, a single callable, a single string, or
    a list/tuple of callables and strings. Callables are kept as they are;
    each string is wrapped in an ``LLMGuardrail`` built from that string
    and ``self.agent.llm``.

    When at least one guardrail is collected, the legacy ``guardrail`` and
    ``_guardrail`` attributes are reset to ``None``.

    Raises:
        ValueError: If a string guardrail is supplied without an agent, or
            if an entry is neither a callable nor a string.
    """
    raw = self.guardrails

    # Flatten every accepted input shape into one list of candidates.
    # An empty list/tuple simply yields no candidates.
    if raw is None:
        candidates = []
    elif callable(raw) or isinstance(raw, str):
        candidates = [raw]
    else:
        # Lists AND tuples (any non-string sequence) are iterated here;
        # checking only for `list` would silently drop tuple input.
        candidates = list(raw)

    guardrails = []
    for candidate in candidates:
        if callable(candidate):
            guardrails.append(candidate)
        elif isinstance(candidate, str):
            # Only string guardrails need an agent: its LLM evaluates the
            # textual description. Plain callables work without one.
            if self.agent is None:
                raise ValueError("Agent is required to use guardrails")

            # Imported locally, as the original code did.
            from crewai.tasks.llm_guardrail import LLMGuardrail

            guardrails.append(
                LLMGuardrail(description=candidate, llm=self.agent.llm)
            )
        else:
            raise ValueError("Guardrail must be a callable or a string")

    self._guardrails = guardrails
    if self._guardrails:
        self.guardrail = None
        self._guardrail = None

    return self

@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
Expand Down Expand Up @@ -456,48 +503,23 @@ def _execute_core(
output_format=self._get_output_format(),
)

if self._guardrails:
for guardrail in self._guardrails:
task_output = self._invoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=guardrail,
)

# Backwards compatibility: also run the legacy single guardrail, if one is set.
if self._guardrail:
guardrail_result = process_guardrail(
output=task_output,
task_output = self._invoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=self._guardrail,
retry_count=self.retry_count,
event_source=self,
from_task=self,
from_agent=agent,
)
if not guardrail_result.success:
if self.retry_count >= self.guardrail_max_retries:
raise Exception(
f"Task failed guardrail validation after {self.guardrail_max_retries} retries. "
f"Last error: {guardrail_result.error}"
)

self.retry_count += 1
context = self.i18n.errors("validation_error").format(
guardrail_result_error=guardrail_result.error,
task_output=task_output.raw,
)
printer = Printer()
printer.print(
content=f"Guardrail blocked, retrying, due to: {guardrail_result.error}\n",
color="yellow",
)
return self._execute_core(agent, context, tools)

if guardrail_result.result is None:
raise Exception(
"Task guardrail returned None as result. This is not allowed."
)

if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(
guardrail_result.result
)
task_output.pydantic = pydantic_output
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result

self.output = task_output
self.end_time = datetime.datetime.now()
Expand Down Expand Up @@ -789,3 +811,55 @@ def fingerprint(self) -> Fingerprint:
Fingerprint: The fingerprint of the task
"""
return self.security_config.fingerprint

def _invoke_guardrail_function(
    self,
    task_output: TaskOutput,
    agent: BaseAgent,
    tools: list[BaseTool],
    guardrail: Callable | None,
) -> TaskOutput:
    """Apply one guardrail to ``task_output`` and return the resulting output.

    A falsy ``guardrail`` is a no-op and ``task_output`` is returned as is.
    Otherwise the guardrail is evaluated through ``process_guardrail``:

    * On failure, the task is re-executed via ``_execute_core`` with a
      validation-error context and ``retry_count`` is incremented, unless
      ``retry_count`` has already reached ``guardrail_max_retries``.
    * On success with a string result, the string replaces
      ``task_output.raw`` and the pydantic/JSON views are re-exported
      from it.
    * On success with a ``TaskOutput`` result, that object is returned.

    Raises:
        Exception: If validation fails once the retry budget is used up,
            or if the guardrail succeeds with ``None`` as its result.
    """
    if not guardrail:
        return task_output

    outcome = process_guardrail(
        output=task_output,
        guardrail=guardrail,
        retry_count=self.retry_count,
        event_source=self,
        from_task=self,
        from_agent=agent,
    )

    if not outcome.success:
        # Retry budget exhausted: surface the last guardrail error.
        if self.retry_count >= self.guardrail_max_retries:
            raise Exception(
                f"Task failed guardrail validation after {self.guardrail_max_retries} retries. "
                f"Last error: {outcome.error}"
            )

        self.retry_count += 1
        # Feed the guardrail's complaint back in as context for the rerun.
        context = self.i18n.errors("validation_error").format(
            guardrail_result_error=outcome.error,
            task_output=task_output.raw,
        )
        Printer().print(
            content=f"Guardrail blocked, retrying, due to: {outcome.error}\n",
            color="yellow",
        )
        return self._execute_core(agent, context, tools)

    validated = outcome.result
    if validated is None:
        raise Exception(
            "Task guardrail returned None as result. This is not allowed."
        )

    if isinstance(validated, str):
        # A string result overrides the raw output; refresh derived views.
        task_output.raw = validated
        pydantic_output, json_output = self._export_output(validated)
        task_output.pydantic = pydantic_output
        task_output.json_dict = json_output
        return task_output

    if isinstance(validated, TaskOutput):
        return validated

    # Any other result type leaves the output untouched.
    return task_output
61 changes: 32 additions & 29 deletions tests/test_crew_thread_safety.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import threading
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, Any, Callable
from typing import Any
from unittest.mock import patch

import pytest
Expand All @@ -24,20 +25,22 @@ def create_agent(name: str) -> Agent:

@pytest.fixture
def simple_task_factory():
def create_task(name: str, callback: Callable = None) -> Task:
def create_task(name: str, agent: Agent, callback: Callable | None = None) -> Task:
return Task(
description=f"Task for {name}", expected_output="Done", callback=callback
description=f"Task for {name}",
expected_output="Done",
agent=agent,
callback=callback,
)

return create_task


@pytest.fixture
def crew_factory(simple_agent_factory, simple_task_factory):
def create_crew(name: str, task_callback: Callable = None) -> Crew:
def create_crew(name: str, task_callback: Callable | None = None) -> Crew:
agent = simple_agent_factory(name)
task = simple_task_factory(name, callback=task_callback)
task.agent = agent
task = simple_task_factory(name, agent=agent, callback=task_callback)

return Crew(agents=[agent], tasks=[task], verbose=False)

Expand All @@ -50,7 +53,7 @@ def test_parallel_crews_thread_safety(self, mock_execute_task, crew_factory):
mock_execute_task.return_value = "Task completed"
num_crews = 5

def run_crew_with_context_check(crew_id: str) -> Dict[str, Any]:
def run_crew_with_context_check(crew_id: str) -> dict[str, Any]:
results = {"crew_id": crew_id, "contexts": []}

def check_context_task(output):
Expand Down Expand Up @@ -105,36 +108,36 @@ def check_context_task(output):
before_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff"
)
assert (
before_ctx["crew_id"] is None
), f"Context should be None before kickoff for {result['crew_id']}"
assert before_ctx["crew_id"] is None, (
f"Context should be None before kickoff for {result['crew_id']}"
)

task_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback"
)
assert (
task_ctx["crew_id"] == crew_uuid
), f"Context mismatch during task for {result['crew_id']}"
assert task_ctx["crew_id"] == crew_uuid, (
f"Context mismatch during task for {result['crew_id']}"
)

after_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff"
)
assert (
after_ctx["crew_id"] is None
), f"Context should be None after kickoff for {result['crew_id']}"
assert after_ctx["crew_id"] is None, (
f"Context should be None after kickoff for {result['crew_id']}"
)

thread_name = before_ctx["thread"]
assert (
"ThreadPoolExecutor" in thread_name
), f"Should run in thread pool for {result['crew_id']}"
assert "ThreadPoolExecutor" in thread_name, (
f"Should run in thread pool for {result['crew_id']}"
)

@pytest.mark.asyncio
@patch("crewai.Agent.execute_task")
async def test_async_crews_thread_safety(self, mock_execute_task, crew_factory):
mock_execute_task.return_value = "Task completed"
num_crews = 5

async def run_crew_async(crew_id: str) -> Dict[str, Any]:
async def run_crew_async(crew_id: str) -> dict[str, Any]:
task_context = {"crew_id": crew_id, "context": None}

def capture_context(output):
Expand Down Expand Up @@ -162,12 +165,12 @@ def capture_context(output):
crew_uuid = result["crew_uuid"]
task_ctx = result["task_context"]["context"]

assert (
task_ctx is not None
), f"Context should exist during task for {result['crew_id']}"
assert (
task_ctx["crew_id"] == crew_uuid
), f"Context mismatch for {result['crew_id']}"
assert task_ctx is not None, (
f"Context should exist during task for {result['crew_id']}"
)
assert task_ctx["crew_id"] == crew_uuid, (
f"Context mismatch for {result['crew_id']}"
)

@patch("crewai.Agent.execute_task")
def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory):
Expand All @@ -193,9 +196,9 @@ def capture_context(output):
assert len(contexts_captured) == len(inputs)

context_ids = [ctx["context_id"] for ctx in contexts_captured]
assert len(set(context_ids)) == len(
inputs
), "Each execution should have unique context"
assert len(set(context_ids)) == len(inputs), (
"Each execution should have unique context"
)

@patch("crewai.Agent.execute_task")
def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory):
Expand Down
Loading