Skip to content

Commit

Permalink
feat: add indictor for tool failure to FunctionExecutionResult (#5428)
Browse files Browse the repository at this point in the history
Some LLMs recieve an explicit signal about tool use failures. 

Closes #5273

Co-authored-by: Eric Zhu <[email protected]>
  • Loading branch information
wistuba and ekzhu authored Feb 10, 2025
1 parent b8c5e49 commit 7a772a2
Show file tree
Hide file tree
Showing 11 changed files with 43 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -557,9 +557,9 @@ async def _execute_tool_call(
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
result_as_str = tool.return_value_as_string(result)
return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id)
return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id, is_error=False)
except Exception as e:
return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id)
return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id, is_error=True)

async def on_reset(self, cancellation_token: CancellationToken) -> None:
"""Reset the assistant agent to its initialization state."""
Expand Down
10 changes: 5 additions & 5 deletions python/packages/autogen-agentchat/tests/test_assistant_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,9 +394,9 @@ async def test_run_with_parallel_tools(monkeypatch: pytest.MonkeyPatch) -> None:
assert result.messages[1].models_usage.prompt_tokens == 10
assert isinstance(result.messages[2], ToolCallExecutionEvent)
expected_content = [
FunctionExecutionResult(call_id="1", content="pass"),
FunctionExecutionResult(call_id="2", content="pass"),
FunctionExecutionResult(call_id="3", content="task3"),
FunctionExecutionResult(call_id="1", content="pass", is_error=False),
FunctionExecutionResult(call_id="2", content="pass", is_error=False),
FunctionExecutionResult(call_id="3", content="task3", is_error=False),
]
for expected in expected_content:
assert expected in result.messages[2].content
Expand Down Expand Up @@ -877,8 +877,8 @@ async def test_model_client_stream_with_tool_calls() -> None:
FunctionCall(id="3", name="_echo_function", arguments=r'{"input": "task"}'),
]
assert message.messages[2].content == [
FunctionExecutionResult(call_id="1", content="pass"),
FunctionExecutionResult(call_id="3", content="task"),
FunctionExecutionResult(call_id="1", content="pass", is_error=False),
FunctionExecutionResult(call_id="3", content="task", is_error=False),
]
elif isinstance(message, ModelClientStreamingChunkEvent):
chunks.append(message.content)
Expand Down
4 changes: 2 additions & 2 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1134,8 +1134,8 @@ async def test_swarm_with_parallel_tool_calls(monkeypatch: pytest.MonkeyPatch) -
),
FunctionExecutionResultMessage(
content=[
FunctionExecutionResult(content="tool1", call_id="1"),
FunctionExecutionResult(content="tool2", call_id="2"),
FunctionExecutionResult(content="tool1", call_id="1", is_error=False),
FunctionExecutionResult(content="tool2", call_id="2", is_error=False),
]
),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,7 @@ def convert_to_v04_message(message: Dict[str, Any]) -> AgentEvent | ChatMessage:
FunctionExecutionResult(
call_id=tool_response["tool_call_id"],
content=tool_response["content"],
is_error=False,
)
)
return ToolCallExecutionEvent(source="tools", content=tool_results)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,9 @@
" # Execute the tool directly.\n",
" result = await self._tools[call.name].run_json(arguments, ctx.cancellation_token)\n",
" result_as_str = self._tools[call.name].return_value_as_string(result)\n",
" tool_call_results.append(FunctionExecutionResult(call_id=call.id, content=result_as_str))\n",
" tool_call_results.append(\n",
" FunctionExecutionResult(call_id=call.id, content=result_as_str, is_error=False)\n",
" )\n",
" elif call.name in self._delegate_tools:\n",
" # Execute the tool to get the delegate agent's topic type.\n",
" result = await self._delegate_tools[call.name].run_json(arguments, ctx.cancellation_token)\n",
Expand All @@ -194,7 +196,9 @@
" FunctionExecutionResultMessage(\n",
" content=[\n",
" FunctionExecutionResult(\n",
" call_id=call.id, content=f\"Transfered to {topic_type}. Adopt persona immediately.\"\n",
" call_id=call.id,\n",
" content=f\"Transferred to {topic_type}. Adopt persona immediately.\",\n",
" is_error=False,\n",
" )\n",
" ]\n",
" ),\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class FunctionExecutionResult(BaseModel):

content: str
call_id: str
is_error: bool | None = None


class FunctionExecutionResultMessage(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ async def tool_agent_caller_loop(
if isinstance(result, FunctionExecutionResult):
function_results.append(result)
elif isinstance(result, ToolException):
function_results.append(FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id))
function_results.append(
FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id, is_error=True)
)
elif isinstance(result, BaseException):
raise result # Unexpected exception.
generated_messages.append(FunctionExecutionResultMessage(content=function_results))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,4 @@ async def handle_function_call(self, message: FunctionCall, ctx: MessageContext)
) from e
except Exception as e:
raise ToolExecutionException(call_id=message.id, content=f"Error: {e}") from e
return FunctionExecutionResult(content=result_as_str, call_id=message.id)
return FunctionExecutionResult(content=result_as_str, call_id=message.id, is_error=False)
2 changes: 1 addition & 1 deletion python/packages/autogen-core/tests/test_tool_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def test_tool_agent() -> None:
result = await runtime.send_message(
FunctionCall(id="1", arguments=json.dumps({"input": "pass"}), name="pass"), agent
)
assert result == FunctionExecutionResult(call_id="1", content="pass")
assert result == FunctionExecutionResult(call_id="1", content="pass", is_error=False)

# Test raise function
with pytest.raises(ToolExecutionException):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -381,17 +381,14 @@ def _thread_id(self) -> str:

async def _execute_tool_call(self, tool_call: FunctionCall, cancellation_token: CancellationToken) -> str:
"""Execute a tool call and return the result."""
try:
if not self._original_tools:
raise ValueError("No tools are available.")
tool = next((t for t in self._original_tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
return tool.return_value_as_string(result)
except Exception as e:
return f"Error: {e}"
if not self._original_tools:
raise ValueError("No tools are available.")
tool = next((t for t in self._original_tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
return tool.return_value_as_string(result)

async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
"""Handle incoming messages and return a response."""
Expand Down Expand Up @@ -463,8 +460,15 @@ async def on_messages_stream(
# Execute tool calls and get results
tool_outputs: List[FunctionExecutionResult] = []
for tool_call in tool_calls:
result = await self._execute_tool_call(tool_call, cancellation_token)
tool_outputs.append(FunctionExecutionResult(content=result, call_id=tool_call.id))
try:
result = await self._execute_tool_call(tool_call, cancellation_token)
is_error = False
except Exception as e:
result = f"Error: {e}"
is_error = True
tool_outputs.append(
FunctionExecutionResult(content=result, call_id=tool_call.id, is_error=is_error)
)

# Add tool result message to inner messages
tool_result_msg = ToolCallExecutionEvent(source=self.name, content=tool_outputs)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ async def test_openai_chat_completion_client_count_tokens(monkeypatch: pytest.Mo
],
source="user",
),
FunctionExecutionResultMessage(content=[FunctionExecutionResult(content="Hello", call_id="1")]),
FunctionExecutionResultMessage(content=[FunctionExecutionResult(content="Hello", call_id="1", is_error=False)]),
]

def tool1(test: str, test2: str) -> str:
Expand Down Expand Up @@ -902,7 +902,7 @@ async def _test_model_client_with_function_calling(model_client: OpenAIChatCompl
messages.append(AssistantMessage(content=create_result.content, source="assistant"))
messages.append(
FunctionExecutionResultMessage(
content=[FunctionExecutionResult(content="passed", call_id=create_result.content[0].id)]
content=[FunctionExecutionResult(content="passed", call_id=create_result.content[0].id, is_error=False)]
)
)
create_result = await model_client.create(messages=messages)
Expand Down Expand Up @@ -932,8 +932,8 @@ async def _test_model_client_with_function_calling(model_client: OpenAIChatCompl
messages.append(
FunctionExecutionResultMessage(
content=[
FunctionExecutionResult(content="passed", call_id=create_result.content[0].id),
FunctionExecutionResult(content="failed", call_id=create_result.content[1].id),
FunctionExecutionResult(content="passed", call_id=create_result.content[0].id, is_error=False),
FunctionExecutionResult(content="failed", call_id=create_result.content[1].id, is_error=True),
]
)
)
Expand Down

0 comments on commit 7a772a2

Please sign in to comment.