Skip to content

Commit 999e654

Browse files
authored
fix: Don't bail out if there are no tool_uses (strands-agents#1087)
Partial fix to strands-agents#1069 - previously the agent would prematurely exit if the agent generated a tool with an invalid name; this avoids that by ensuring the agent loop continues with zero tool-uses. --------- Co-authored-by: Mackenzie Zastrow <[email protected]>
1 parent 1544384 commit 999e654

File tree

5 files changed

+104
-7
lines changed

5 files changed

+104
-7
lines changed

src/strands/event_loop/event_loop.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,9 +427,6 @@ async def _handle_tool_execution(
427427

428428
validate_and_prepare_tools(message, tool_uses, tool_results, invalid_tool_use_ids)
429429
tool_uses = [tool_use for tool_use in tool_uses if tool_use.get("toolUseId") not in invalid_tool_use_ids]
430-
if not tool_uses:
431-
yield EventLoopStopEvent(stop_reason, message, agent.event_loop_metrics, invocation_state["request_state"])
432-
return
433430

434431
if agent._interrupt_state.activated:
435432
tool_results.extend(agent._interrupt_state.context["tool_results"])

tests/fixtures/mocked_model_provider.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import json
2-
from typing import Any, AsyncGenerator, Iterable, Optional, Type, TypedDict, TypeVar, Union
2+
from typing import Any, AsyncGenerator, Iterable, Optional, Sequence, Type, TypedDict, TypeVar, Union
33

44
from pydantic import BaseModel
55

@@ -25,8 +25,8 @@ class MockedModelProvider(Model):
2525
to stream mock responses as events.
2626
"""
2727

28-
def __init__(self, agent_responses: list[Union[Message, RedactionMessage]]):
29-
self.agent_responses = agent_responses
28+
def __init__(self, agent_responses: Sequence[Union[Message, RedactionMessage]]):
29+
self.agent_responses = [*agent_responses]
3030
self.index = 0
3131

3232
def format_chunk(self, event: Any) -> StreamEvent:

tests/strands/agent/test_agent.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2065,3 +2065,50 @@ def test_agent_tool_caller_interrupt(user):
20652065
exp_message = r"cannot directly call tool during interrupt"
20662066
with pytest.raises(RuntimeError, match=exp_message):
20672067
agent.tool.test_tool()
2068+
2069+
2070+
def test_agent__call__invalid_tool_name():
2071+
@strands.tool
2072+
def shell(command: str):
2073+
pass
2074+
2075+
model = MockedModelProvider(
2076+
[
2077+
{
2078+
"role": "assistant",
2079+
"content": [
2080+
{
2081+
"toolUse": {
2082+
"toolUseId": "tool_use_id",
2083+
"name": "invalid tool",
2084+
"input": "{}",
2085+
}
2086+
}
2087+
],
2088+
},
2089+
{"role": "assistant", "content": [{"text": "I invoked a tool!"}]},
2090+
]
2091+
)
2092+
2093+
agent = Agent(tools=[shell], model=model)
2094+
result = agent("Test")
2095+
2096+
# Ensure the stop_reason is
2097+
assert result.stop_reason == "end_turn"
2098+
2099+
# Assert that there exists a message with a toolResponse
2100+
assert agent.messages[-2] == {
2101+
"content": [
2102+
{
2103+
"toolResult": {
2104+
"content": [{"text": "Error: tool_name=<invalid tool> | invalid tool name pattern"}],
2105+
"status": "error",
2106+
"toolUseId": "tool_use_id",
2107+
}
2108+
}
2109+
],
2110+
"role": "user",
2111+
}
2112+
2113+
# And that it continued to the LLM call
2114+
assert agent.messages[-1] == {"content": [{"text": "I invoked a tool!"}], "role": "assistant"}

tests/strands/event_loop/test_event_loop.py

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import concurrent
22
import unittest.mock
3-
from unittest.mock import MagicMock, call, patch
3+
from unittest.mock import ANY, MagicMock, call, patch
44

55
import pytest
66

@@ -18,13 +18,15 @@
1818
from strands.telemetry.metrics import EventLoopMetrics
1919
from strands.tools.executors import SequentialToolExecutor
2020
from strands.tools.registry import ToolRegistry
21+
from strands.types._events import EventLoopStopEvent
2122
from strands.types.exceptions import (
2223
ContextWindowOverflowException,
2324
EventLoopException,
2425
MaxTokensReachedException,
2526
ModelThrottledException,
2627
)
2728
from tests.fixtures.mock_hook_provider import MockHookProvider
29+
from tests.fixtures.mocked_model_provider import MockedModelProvider
2830

2931

3032
@pytest.fixture
@@ -744,6 +746,8 @@ async def test_event_loop_cycle_with_parent_span(
744746
async def test_request_state_initialization(alist):
745747
# Create a mock agent
746748
mock_agent = MagicMock()
749+
# not setting this to False results in endless recursion
750+
mock_agent._interrupt_state.activated = False
747751
mock_agent.event_loop_metrics.start_cycle.return_value = (0, MagicMock())
748752

749753
# Call without providing request_state
@@ -1011,3 +1015,52 @@ def interrupt_callback(event):
10111015
"interrupts": {},
10121016
}
10131017
assert tru_state == exp_state
1018+
1019+
1020+
@pytest.mark.asyncio
1021+
async def test_invalid_tool_names_adds_tool_uses(agent, model, alist):
1022+
model.stream = MockedModelProvider(
1023+
[
1024+
{
1025+
"role": "assistant",
1026+
"content": [
1027+
{
1028+
"toolUse": {
1029+
"toolUseId": "tool_use_id",
1030+
"name": "invalid tool",
1031+
"input": "{}",
1032+
}
1033+
}
1034+
],
1035+
},
1036+
{"role": "assistant", "content": [{"text": "I invoked a tool!"}]},
1037+
]
1038+
).stream
1039+
1040+
stream = strands.event_loop.event_loop.event_loop_cycle(
1041+
agent=agent,
1042+
invocation_state={},
1043+
)
1044+
events = await alist(stream)
1045+
1046+
# ensure that we got end_turn and not tool_use
1047+
assert events[-1] == EventLoopStopEvent(
1048+
stop_reason="end_turn",
1049+
message={"content": [{"text": "I invoked a tool!"}], "role": "assistant"},
1050+
metrics=ANY,
1051+
request_state={},
1052+
)
1053+
1054+
# Ensure that an "invalid tool name" message was added properly
1055+
assert agent.messages[-2] == {
1056+
"content": [
1057+
{
1058+
"toolResult": {
1059+
"content": [{"text": "Error: tool_name=<invalid tool> | invalid tool name pattern"}],
1060+
"status": "error",
1061+
"toolUseId": "tool_use_id",
1062+
}
1063+
}
1064+
],
1065+
"role": "user",
1066+
}

tests/strands/types/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)