Skip to content

Commit

Permalink
feat: select errors that will be caught when raised within a tool
Browse files Browse the repository at this point in the history
  • Loading branch information
wistuba committed Feb 11, 2025
1 parent a9db384 commit c7b5349
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -548,18 +548,25 @@ async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
) -> FunctionExecutionResult:
"""Execute a tool call and return the result."""
error_prefix_string = "Error: "
if not self._tools + self._handoff_tools:
return FunctionExecutionResult(
content=f"{error_prefix_string}No tools are available.", call_id=tool_call.id, is_error=True
)
tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None)
if tool is None:
return FunctionExecutionResult(
content=f"{error_prefix_string}The tool '{tool_call.name}' is not available.",
call_id=tool_call.id,
is_error=True,
)
try:
if not self._tools + self._handoff_tools:
raise ValueError("No tools are available.")
tool = next((t for t in self._tools + self._handoff_tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
result_as_str = tool.return_value_as_string(result)
return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id, is_error=False)
except Exception as e:
return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id, is_error=True)
except tool.returned_errors as e:
return FunctionExecutionResult(content=f"{error_prefix_string}{e}", call_id=tool_call.id, is_error=True)

async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Reset the assistant agent to its initialization state."""
Expand Down
9 changes: 9 additions & 0 deletions python/packages/autogen-core/src/autogen_core/tools/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ def description(self) -> str: ...
@property
def schema(self) -> ToolSchema: ...

@property
def returned_errors(self) -> tuple[type[Exception], ...]: ...

def args_type(self) -> Type[BaseModel]: ...

def return_type(self) -> Type[Any]: ...
Expand Down Expand Up @@ -66,12 +69,14 @@ def __init__(
return_type: Type[ReturnT],
name: str,
description: str,
returned_errors: tuple[type[Exception], ...] = (Exception,),
) -> None:
self._args_type = args_type
# Normalize Annotated to the base type.
self._return_type = normalize_annotated_type(return_type)
self._name = name
self._description = description
self._returned_errors = returned_errors

@property
def schema(self) -> ToolSchema:
Expand Down Expand Up @@ -103,6 +108,10 @@ def name(self) -> str:
def description(self) -> str:
return self._description

@property
def returned_errors(self) -> tuple[type[Exception], ...]:
return self._returned_errors

def args_type(self) -> Type[BaseModel]:
return self._args_type

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ async def example():
component_config_schema = FunctionToolConfig

def __init__(
self, func: Callable[..., Any], description: str, name: str | None = None, global_imports: Sequence[Import] = []
self,
func: Callable[..., Any],
description: str,
name: str | None = None,
global_imports: Sequence[Import] = [],
return_errors: tuple[type[Exception], ...] = (Exception,),
) -> None:
self._func = func
self._global_imports = global_imports
Expand All @@ -92,6 +97,7 @@ def __init__(
args_model = args_base_model_from_signature(func_name + "args", signature)
return_type = signature.return_annotation
self._has_cancellation_support = "cancellation_token" in signature.parameters
self._return_errors = return_errors

super().__init__(args_model, return_type, func_name, description)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,16 +379,24 @@ def _thread_id(self) -> str:
raise ValueError("Thread not initialized")
return self._thread.id

async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str:
async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
) -> FunctionExecutionResult:
"""Execute a tool call and return the result."""
if not self._original_tools:
raise ValueError("No tools are available.")
return FunctionExecutionResult(content="No tools are available.", call_id=tool_call.id, is_error=True)
tool = next((t for t in self._original_tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
return tool.return_value_as_string(result)
return FunctionExecutionResult(
content=f"The tool '{tool_call.name}' is not available.", call_id=tool_call.id, is_error=True
)
try:
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
result_as_str = tool.return_value_as_string(result)
return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id, is_error=False)
except tool.returned_errors as e:
return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id, is_error=True)

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
"""Handle incoming messages and return a response."""
Expand Down Expand Up @@ -460,15 +468,8 @@ async def on_messages_stream(
# Execute tool calls and get results
tool_outputs: List[FunctionExecutionResult] = []
for tool_call in tool_calls:
try:
result = await self._execute_tool_call(tool_call, cancellation_token)
is_error = False
except Exception as e:
result = f"Error: {e}"
is_error = True
tool_outputs.append(
FunctionExecutionResult(content=result, call_id=tool_call.id, is_error=is_error)
)
tool_output = await self._execute_tool_call(tool_call, cancellation_token)
tool_outputs.append(tool_output)

# Add tool result message to inner messages
tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs)
Expand Down

0 comments on commit c7b5349

Please sign in to comment.