diff --git a/docs/agents.md b/docs/agents.md index f1878559b2..fab7017286 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -255,14 +255,14 @@ There are two hook scopes: The callback context also changes depending on the event: - Agent start/end hooks receive [`AgentHookContext`][agents.run_context.AgentHookContext], which wraps your original context and carries the shared run usage state. -- LLM, tool, and handoff hooks receive [`RunContextWrapper`][agents.run_context.RunContextWrapper]. +- Tool start/end hooks receive [`ToolContext`][agents.tool_context.ToolContext], a subclass of `RunContextWrapper` that adds tool-call metadata such as `tool_call_id`, `tool_name`, and `tool_arguments`. +- LLM and handoff hooks receive [`RunContextWrapper`][agents.run_context.RunContextWrapper]. Typical hook timing: - `on_agent_start` / `on_agent_end`: when a specific agent begins or finishes producing a final output. - `on_llm_start` / `on_llm_end`: immediately around each model call. -- `on_tool_start` / `on_tool_end`: around each local tool invocation. - For function tools, the hook `context` is typically a `ToolContext`, so you can inspect tool-call metadata such as `tool_call_id`. +- `on_tool_start` / `on_tool_end`: around each local tool invocation. The hook `context` is a `ToolContext`, so you can inspect `tool_call_id` to correlate the start and end of the same call when tools run in parallel. - `on_handoff`: when control moves from one agent to another. Use `RunHooks` when you want a single observer for the whole workflow, and `AgentHooks` when one agent needs custom side effects. diff --git a/src/agents/lifecycle.py b/src/agents/lifecycle.py index 2ca7484739..688c5a86d9 100644 --- a/src/agents/lifecycle.py +++ b/src/agents/lifecycle.py @@ -6,6 +6,7 @@ from .items import ModelResponse, TResponseInputItem from .run_context import AgentHookContext, RunContextWrapper, TContext from .tool import Tool +from .tool_context import ToolContext TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase) @@ -69,32 +70,32 @@ async def on_handoff( async def on_tool_start( self, - context: RunContextWrapper[TContext], + context: ToolContext[TContext], agent: TAgent, tool: Tool, ) -> None: """Called immediately before a local tool is invoked. - For function-tool invocations, ``context`` is typically a ``ToolContext`` instance, - which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``, - and ``tool_arguments``. Other local tool families may provide a plain - ``RunContextWrapper`` instead. + ``context`` is always a ``ToolContext`` instance, which exposes tool-call-specific + metadata such as ``tool_call_id``, ``tool_name``, and ``tool_arguments``. The + ``tool_call_id`` lets hooks correlate ``on_tool_start`` with the matching + ``on_tool_end`` even when tools execute in parallel. """ pass async def on_tool_end( self, - context: RunContextWrapper[TContext], + context: ToolContext[TContext], agent: TAgent, tool: Tool, result: str, ) -> None: """Called immediately after a local tool is invoked. - For function-tool invocations, ``context`` is typically a ``ToolContext`` instance, - which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``, - and ``tool_arguments``. Other local tool families may provide a plain - ``RunContextWrapper`` instead. + ``context`` is always a ``ToolContext`` instance, which exposes tool-call-specific + metadata such as ``tool_call_id``, ``tool_name``, and ``tool_arguments``. The + ``tool_call_id`` lets hooks correlate ``on_tool_start`` with the matching + ``on_tool_end`` even when tools execute in parallel. """ pass @@ -143,32 +144,32 @@ async def on_handoff( async def on_tool_start( self, - context: RunContextWrapper[TContext], + context: ToolContext[TContext], agent: TAgent, tool: Tool, ) -> None: """Called immediately before a local tool is invoked. - For function-tool invocations, ``context`` is typically a ``ToolContext`` instance, - which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``, - and ``tool_arguments``. Other local tool families may provide a plain - ``RunContextWrapper`` instead. + ``context`` is always a ``ToolContext`` instance, which exposes tool-call-specific + metadata such as ``tool_call_id``, ``tool_name``, and ``tool_arguments``. The + ``tool_call_id`` lets hooks correlate ``on_tool_start`` with the matching + ``on_tool_end`` even when tools execute in parallel. """ pass async def on_tool_end( self, - context: RunContextWrapper[TContext], + context: ToolContext[TContext], agent: TAgent, tool: Tool, result: str, ) -> None: """Called immediately after a local tool is invoked. - For function-tool invocations, ``context`` is typically a ``ToolContext`` instance, - which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``, - and ``tool_arguments``. Other local tool families may provide a plain - ``RunContextWrapper`` instead. + ``context`` is always a ``ToolContext`` instance, which exposes tool-call-specific + metadata such as ``tool_call_id``, ``tool_name``, and ``tool_arguments``. The + ``tool_call_id`` lets hooks correlate ``on_tool_start`` with the matching + ``on_tool_end`` even when tools execute in parallel. """ pass diff --git a/src/agents/run_internal/tool_actions.py b/src/agents/run_internal/tool_actions.py index 310fdc2592..d2cfae8a44 100644 --- a/src/agents/run_internal/tool_actions.py +++ b/src/agents/run_internal/tool_actions.py @@ -110,6 +110,19 @@ async def execute( """Run a computer action, capturing a screenshot and notifying hooks.""" trace_tool_name = get_tool_trace_name_for_tool(action.computer_tool) or cls.TRACE_TOOL_NAME + # Build a ToolContext so lifecycle hooks can attribute parallel tool calls via + # ``tool_call_id``. Computer tools encode their arguments in structured action + # fields rather than a JSON string, so serialize the actions for visibility. + tool_arguments = _serialize_trace_payload(cls._get_trace_input_payload(action.tool_call)) + tool_context = ToolContext.from_agent_context( + context_wrapper, + action.tool_call.call_id, + tool_name=action.computer_tool.name, + tool_arguments=tool_arguments, + agent=agent, + run_config=config, + ) + async def _run_action(span: Any | None) -> RunItem: if span and config.trace_include_sensitive_data: span.span_data.input = _serialize_trace_payload( @@ -121,9 +134,9 @@ async def _run_action(span: Any | None) -> RunItem: ) agent_hooks = agent.hooks await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, action.computer_tool), + hooks.on_tool_start(tool_context, agent, action.computer_tool), ( - agent_hooks.on_tool_start(context_wrapper, agent, action.computer_tool) + agent_hooks.on_tool_start(tool_context, agent, action.computer_tool) if agent_hooks else _coro.noop_coroutine() ), @@ -151,9 +164,9 @@ async def _run_action(span: Any | None) -> RunItem: raise await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output), + hooks.on_tool_end(tool_context, agent, action.computer_tool, output), ( - agent_hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output) + agent_hooks.on_tool_end(tool_context, agent, action.computer_tool, output) if agent_hooks else _coro.noop_coroutine() ), @@ -374,10 +387,22 @@ async def execute( ) -> RunItem: """Run a local shell tool call and wrap the result as a ToolCallOutputItem.""" agent_hooks = agent.hooks + # Build a ToolContext so lifecycle hooks can attribute parallel tool calls via + # ``tool_call_id``. Local shell calls carry structured action fields, so we + # serialize the action payload for ``tool_arguments`` visibility. + tool_arguments = _serialize_trace_payload(call.tool_call.action) + tool_context = ToolContext.from_agent_context( + context_wrapper, + call.tool_call.call_id, + tool_name=call.local_shell_tool.name, + tool_arguments=tool_arguments, + agent=agent, + run_config=config, + ) await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool), + hooks.on_tool_start(tool_context, agent, call.local_shell_tool), ( - agent_hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool) + agent_hooks.on_tool_start(tool_context, agent, call.local_shell_tool) if agent_hooks else _coro.noop_coroutine() ), @@ -391,9 +416,9 @@ async def execute( result = await output if inspect.isawaitable(output) else output await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result), + hooks.on_tool_end(tool_context, agent, call.local_shell_tool, result), ( - agent_hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result) + agent_hooks.on_tool_end(tool_context, agent, call.local_shell_tool, result) if agent_hooks else _coro.noop_coroutine() ), @@ -428,6 +453,18 @@ async def execute( shell_call = coerce_shell_call(call.tool_call) shell_tool = call.shell_tool agent_hooks = agent.hooks + # Build a ToolContext so lifecycle hooks can attribute parallel tool calls via + # ``tool_call_id``. Shell calls carry a structured action payload, so serialize + # it for ``tool_arguments`` visibility. + tool_arguments = _serialize_trace_payload(dataclasses.asdict(shell_call.action)) + tool_context = ToolContext.from_agent_context( + context_wrapper, + shell_call.call_id, + tool_name=shell_tool.name, + tool_arguments=tool_arguments, + agent=agent, + run_config=config, + ) async def _run_call(span: Any | None) -> RunItem: if span and config.trace_include_sensitive_data: @@ -467,9 +504,9 @@ async def _run_call(span: Any | None) -> RunItem: return approval_item await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, shell_tool), + hooks.on_tool_start(tool_context, agent, shell_tool), ( - agent_hooks.on_tool_start(context_wrapper, agent, shell_tool) + agent_hooks.on_tool_start(tool_context, agent, shell_tool) if agent_hooks else _coro.noop_coroutine() ), @@ -541,9 +578,9 @@ async def _run_call(span: Any | None) -> RunItem: logger.error("Shell executor failed: %s", exc, exc_info=True) await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text), + hooks.on_tool_end(tool_context, agent, call.shell_tool, output_text), ( - agent_hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text) + agent_hooks.on_tool_end(tool_context, agent, call.shell_tool, output_text) if agent_hooks else _coro.noop_coroutine() ), @@ -747,6 +784,27 @@ async def execute( context_wrapper=context_wrapper, ) call_id = extract_apply_patch_call_id(call.tool_call) + # Build a ToolContext so lifecycle hooks can attribute parallel tool calls via + # ``tool_call_id``. Apply patch operations carry structured fields, so serialize + # the operation list for ``tool_arguments`` visibility. + tool_arguments = _serialize_trace_payload( + [ + { + "type": operation.type, + "path": operation.path, + "diff": operation.diff, + } + for operation in operations + ] + ) + tool_context = ToolContext.from_agent_context( + context_wrapper, + call_id, + tool_name=apply_patch_tool.name, + tool_arguments=tool_arguments, + agent=agent, + run_config=config, + ) async def _run_call(span: Any | None) -> RunItem: if span and config.trace_include_sensitive_data: @@ -798,9 +856,9 @@ async def _run_call(span: Any | None) -> RunItem: return approval_item await asyncio.gather( - hooks.on_tool_start(context_wrapper, agent, apply_patch_tool), + hooks.on_tool_start(tool_context, agent, apply_patch_tool), ( - agent_hooks.on_tool_start(context_wrapper, agent, apply_patch_tool) + agent_hooks.on_tool_start(tool_context, agent, apply_patch_tool) if agent_hooks else _coro.noop_coroutine() ), @@ -854,9 +912,9 @@ async def _run_call(span: Any | None) -> RunItem: logger.error("Apply patch editor failed: %s", exc, exc_info=True) await asyncio.gather( - hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text), + hooks.on_tool_end(tool_context, agent, apply_patch_tool, output_text), ( - agent_hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text) + agent_hooks.on_tool_end(tool_context, agent, apply_patch_tool, output_text) if agent_hooks else _coro.noop_coroutine() ), diff --git a/tests/test_apply_patch_tool.py b/tests/test_apply_patch_tool.py index 1e66312c4a..899bb00486 100644 --- a/tests/test_apply_patch_tool.py +++ b/tests/test_apply_patch_tool.py @@ -438,3 +438,49 @@ class MultiOpCall: assert raw_item["status"] == "failed" assert "Failed a.md" in result.output assert "Created b.md" in result.output + + +@pytest.mark.asyncio +async def test_apply_patch_action_exposes_tool_call_id_on_lifecycle_hooks() -> None: + """Apply patch lifecycle hooks must receive a ToolContext with tool_call_id set.""" + from agents.tool_context import ToolContext + + captured: dict[str, list[Any]] = {"start": [], "end": []} + + class CapturingHooks(RunHooks[Any]): + async def on_tool_start( + self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any + ) -> None: + captured["start"].append(context) + + async def on_tool_end( + self, + context: RunContextWrapper[Any], + agent: Agent[Any], + tool: Any, + result: str, + ) -> None: + captured["end"].append(context) + + editor = RecordingEditor() + tool = ApplyPatchTool(editor=editor) + agent, context_wrapper, tool_run = build_apply_patch_call( + tool, + "call_apply_ctx", + {"type": "update_file", "path": "tasks.md", "diff": "-a\n+b\n"}, + ) + + await ApplyPatchAction.execute( + agent=agent, + call=tool_run, + hooks=CapturingHooks(), + context_wrapper=context_wrapper, + config=RunConfig(), + ) + + assert len(captured["start"]) == 1 + assert len(captured["end"]) == 1 + for context in (captured["start"][0], captured["end"][0]): + assert isinstance(context, ToolContext) + assert context.tool_call_id == "call_apply_ctx" + assert context.tool_name == tool.name diff --git a/tests/test_computer_action.py b/tests/test_computer_action.py index 3aa908c66c..a58f76df21 100644 --- a/tests/test_computer_action.py +++ b/tests/test_computer_action.py @@ -503,16 +503,20 @@ def __init__(self) -> None: super().__init__() self.started: list[tuple[Agent[Any], Any]] = [] self.ended: list[tuple[Agent[Any], Any, str]] = [] + self.start_contexts: list[Any] = [] + self.end_contexts: list[Any] = [] async def on_tool_start( self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any ) -> None: self.started.append((agent, tool)) + self.start_contexts.append(context) async def on_tool_end( self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str ) -> None: self.ended.append((agent, tool, result)) + self.end_contexts.append(context) class LoggingAgentHooks(AgentHooks[Any]): @@ -522,16 +526,20 @@ def __init__(self) -> None: super().__init__() self.started: list[tuple[Agent[Any], Any]] = [] self.ended: list[tuple[Agent[Any], Any, str]] = [] + self.start_contexts: list[Any] = [] + self.end_contexts: list[Any] = [] async def on_tool_start( self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any ) -> None: self.started.append((agent, tool)) + self.start_contexts.append(context) async def on_tool_end( self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any, result: str ) -> None: self.ended.append((agent, tool, result)) + self.end_contexts.append(context) @pytest.mark.asyncio @@ -593,6 +601,50 @@ async def test_execute_invokes_hooks_and_returns_tool_call_output() -> None: assert raw["output"]["image_url"].endswith("xyz") +@pytest.mark.asyncio +async def test_execute_exposes_tool_call_id_on_lifecycle_hooks() -> None: + """Computer tool lifecycle hooks must receive a ToolContext with tool_call_id set.""" + from agents.tool_context import ToolContext + + computer = LoggingComputer(screenshot_return="hookimg") + comptool = ComputerTool(computer=computer) + action = ActionClick(type="click", x=4, y=5, button="left") + tool_call = ResponseComputerToolCall( + id="computer_call_42", + type="computer_call", + action=action, + call_id="computer_call_42", + pending_safety_checks=[], + status="completed", + ) + + tool_run = ToolRunComputerAction(tool_call=tool_call, computer_tool=comptool) + agent = Agent(name="ctx_id_agent", tools=[comptool]) + agent_hooks = LoggingAgentHooks() + agent.hooks = agent_hooks + run_hooks = LoggingRunHooks() + context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None) + + await ComputerAction.execute( + agent=agent, + action=tool_run, + hooks=run_hooks, + context_wrapper=context_wrapper, + config=RunConfig(), + ) + + # Each hook should receive a ToolContext whose tool_call_id matches the tool call. + for context in ( + run_hooks.start_contexts[0], + run_hooks.end_contexts[0], + agent_hooks.start_contexts[0], + agent_hooks.end_contexts[0], + ): + assert isinstance(context, ToolContext) + assert context.tool_call_id == "computer_call_42" + assert context.tool_name == comptool.name + + @pytest.mark.asyncio async def test_execute_emits_function_span() -> None: computer = LoggingComputer(screenshot_return="trace_img") diff --git a/tests/test_local_shell_tool.py b/tests/test_local_shell_tool.py index cdc0d9a7f1..a9b3cda600 100644 --- a/tests/test_local_shell_tool.py +++ b/tests/test_local_shell_tool.py @@ -21,6 +21,7 @@ ) from agents.items import ToolCallOutputItem from agents.run_internal.run_loop import LocalShellAction, ToolRunLocalShellCall +from agents.tool_context import ToolContext from .fake_model import FakeModel from .test_responses import get_text_message @@ -156,3 +157,60 @@ async def test_runner_executes_local_shell_calls() -> None: assert result.final_output == "shell complete" assert len(result.raw_responses) == 2 + + +@pytest.mark.asyncio +async def test_local_shell_action_exposes_tool_call_id_on_lifecycle_hooks() -> None: + """Local shell lifecycle hooks must receive a ToolContext with tool_call_id set.""" + + captured: dict[str, list[Any]] = {"start": [], "end": []} + + class CapturingHooks(RunHooks[Any]): + async def on_tool_start( + self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any + ) -> None: + captured["start"].append(context) + + async def on_tool_end( + self, + context: RunContextWrapper[Any], + agent: Agent[Any], + tool: Any, + result: str, + ) -> None: + captured["end"].append(context) + + executor = RecordingLocalShellExecutor(output="ok") + tool = LocalShellTool(executor=executor) + action = LocalShellCallAction( + command=["bash", "-c", "true"], + env={}, + type="exec", + timeout_ms=1000, + working_directory="/tmp", + ) + tool_call = LocalShellCall( + id="lsh_id_test", + action=action, + call_id="call_local_shell_id", + status="completed", + type="local_shell_call", + ) + tool_run = ToolRunLocalShellCall(tool_call=tool_call, local_shell_tool=tool) + agent = Agent(name="local_shell_id_agent", tools=[tool]) + context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None) + + await LocalShellAction.execute( + agent=agent, + call=tool_run, + hooks=CapturingHooks(), + context_wrapper=context_wrapper, + config=RunConfig(), + ) + + assert len(captured["start"]) == 1 + assert len(captured["end"]) == 1 + for context in (captured["start"][0], captured["end"][0]): + assert isinstance(context, ToolContext) + assert context.tool_call_id == "call_local_shell_id" + assert context.tool_name == tool.name diff --git a/tests/test_shell_tool.py b/tests/test_shell_tool.py index f9467e2f90..74afb3c03c 100644 --- a/tests/test_shell_tool.py +++ b/tests/test_shell_tool.py @@ -811,3 +811,46 @@ async def on_approval( assert isinstance(result, ToolCallOutputItem) assert result.output == HITL_REJECTION_MSG + + +@pytest.mark.asyncio +async def test_shell_action_exposes_tool_call_id_on_lifecycle_hooks() -> None: + """Shell lifecycle hooks must receive a ToolContext with tool_call_id set.""" + from agents.tool_context import ToolContext + + captured: dict[str, list[Any]] = {"start": [], "end": []} + + class CapturingHooks(RunHooks[Any]): + async def on_tool_start( + self, context: RunContextWrapper[Any], agent: Agent[Any], tool: Any + ) -> None: + captured["start"].append(context) + + async def on_tool_end( + self, + context: RunContextWrapper[Any], + agent: Agent[Any], + tool: Any, + result: str, + ) -> None: + captured["end"].append(context) + + shell_tool = ShellTool(executor=lambda request: "ok") + tool_run = ToolRunShellCall(tool_call=_shell_call("call_shell_ctx"), shell_tool=shell_tool) + agent = Agent(name="shell-ctx-agent", tools=[shell_tool]) + context_wrapper: RunContextWrapper[Any] = RunContextWrapper(context=None) + + await ShellAction.execute( + agent=agent, + call=tool_run, + hooks=CapturingHooks(), + context_wrapper=context_wrapper, + config=RunConfig(), + ) + + assert len(captured["start"]) == 1 + assert len(captured["end"]) == 1 + for context in (captured["start"][0], captured["end"][0]): + assert isinstance(context, ToolContext) + assert context.tool_call_id == "call_shell_ctx" + assert context.tool_name == shell_tool.name