Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -255,14 +255,14 @@ There are two hook scopes:
The callback context also changes depending on the event:

- Agent start/end hooks receive [`AgentHookContext`][agents.run_context.AgentHookContext], which wraps your original context and carries the shared run usage state.
- LLM, tool, and handoff hooks receive [`RunContextWrapper`][agents.run_context.RunContextWrapper].
- Tool start/end hooks receive [`ToolContext`][agents.tool_context.ToolContext], a subclass of `RunContextWrapper` that adds tool-call metadata such as `tool_call_id`, `tool_name`, and `tool_arguments`.
- LLM and handoff hooks receive [`RunContextWrapper`][agents.run_context.RunContextWrapper].

Typical hook timing:

- `on_agent_start` / `on_agent_end`: when a specific agent begins or finishes producing a final output.
- `on_llm_start` / `on_llm_end`: immediately around each model call.
- `on_tool_start` / `on_tool_end`: around each local tool invocation.
For function tools, the hook `context` is typically a `ToolContext`, so you can inspect tool-call metadata such as `tool_call_id`.
- `on_tool_start` / `on_tool_end`: around each local tool invocation. The hook `context` is a `ToolContext`, so you can inspect `tool_call_id` to correlate the start and end of the same call when tools run in parallel.
- `on_handoff`: when control moves from one agent to another.

Use `RunHooks` when you want a single observer for the whole workflow, and `AgentHooks` when one agent needs custom side effects.
Expand Down
41 changes: 21 additions & 20 deletions src/agents/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .items import ModelResponse, TResponseInputItem
from .run_context import AgentHookContext, RunContextWrapper, TContext
from .tool import Tool
from .tool_context import ToolContext

TAgent = TypeVar("TAgent", bound=AgentBase, default=AgentBase)

Expand Down Expand Up @@ -69,32 +70,32 @@ async def on_handoff(

async def on_tool_start(
self,
context: RunContextWrapper[TContext],
context: ToolContext[TContext],
agent: TAgent,
tool: Tool,
) -> None:
"""Called immediately before a local tool is invoked.

For function-tool invocations, ``context`` is typically a ``ToolContext`` instance,
which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``,
and ``tool_arguments``. Other local tool families may provide a plain
``RunContextWrapper`` instead.
``context`` is always a ``ToolContext`` instance, which exposes tool-call-specific
metadata such as ``tool_call_id``, ``tool_name``, and ``tool_arguments``. The
``tool_call_id`` lets hooks correlate ``on_tool_start`` with the matching
``on_tool_end`` even when tools execute in parallel.
"""
pass

async def on_tool_end(
self,
context: RunContextWrapper[TContext],
context: ToolContext[TContext],
agent: TAgent,
tool: Tool,
result: str,
) -> None:
"""Called immediately after a local tool is invoked.

For function-tool invocations, ``context`` is typically a ``ToolContext`` instance,
which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``,
and ``tool_arguments``. Other local tool families may provide a plain
``RunContextWrapper`` instead.
``context`` is always a ``ToolContext`` instance, which exposes tool-call-specific
metadata such as ``tool_call_id``, ``tool_name``, and ``tool_arguments``. The
``tool_call_id`` lets hooks correlate ``on_tool_start`` with the matching
``on_tool_end`` even when tools execute in parallel.
"""
pass

Expand Down Expand Up @@ -143,32 +144,32 @@ async def on_handoff(

async def on_tool_start(
self,
context: RunContextWrapper[TContext],
context: ToolContext[TContext],
agent: TAgent,
tool: Tool,
) -> None:
"""Called immediately before a local tool is invoked.

For function-tool invocations, ``context`` is typically a ``ToolContext`` instance,
which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``,
and ``tool_arguments``. Other local tool families may provide a plain
``RunContextWrapper`` instead.
``context`` is always a ``ToolContext`` instance, which exposes tool-call-specific
metadata such as ``tool_call_id``, ``tool_name``, and ``tool_arguments``. The
``tool_call_id`` lets hooks correlate ``on_tool_start`` with the matching
``on_tool_end`` even when tools execute in parallel.
"""
pass

async def on_tool_end(
self,
context: RunContextWrapper[TContext],
context: ToolContext[TContext],
agent: TAgent,
tool: Tool,
result: str,
) -> None:
"""Called immediately after a local tool is invoked.

For function-tool invocations, ``context`` is typically a ``ToolContext`` instance,
which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``,
and ``tool_arguments``. Other local tool families may provide a plain
``RunContextWrapper`` instead.
``context`` is always a ``ToolContext`` instance, which exposes tool-call-specific
metadata such as ``tool_call_id``, ``tool_name``, and ``tool_arguments``. The
``tool_call_id`` lets hooks correlate ``on_tool_start`` with the matching
``on_tool_end`` even when tools execute in parallel.
"""
pass

Expand Down
90 changes: 74 additions & 16 deletions src/agents/run_internal/tool_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,19 @@ async def execute(
"""Run a computer action, capturing a screenshot and notifying hooks."""
trace_tool_name = get_tool_trace_name_for_tool(action.computer_tool) or cls.TRACE_TOOL_NAME

# Build a ToolContext so lifecycle hooks can attribute parallel tool calls via
# ``tool_call_id``. Computer tools encode their arguments in structured action
# fields rather than a JSON string, so serialize the action payload for
# ``tool_arguments`` visibility.
tool_arguments = _serialize_trace_payload(cls._get_trace_input_payload(action.tool_call))
tool_context = ToolContext.from_agent_context(
context_wrapper,
action.tool_call.call_id,
tool_name=action.computer_tool.name,
tool_arguments=tool_arguments,
agent=agent,
run_config=config,
)

async def _run_action(span: Any | None) -> RunItem:
if span and config.trace_include_sensitive_data:
span.span_data.input = _serialize_trace_payload(
Expand All @@ -121,9 +134,9 @@ async def _run_action(span: Any | None) -> RunItem:
)
agent_hooks = agent.hooks
await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, action.computer_tool),
hooks.on_tool_start(tool_context, agent, action.computer_tool),
(
agent_hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
agent_hooks.on_tool_start(tool_context, agent, action.computer_tool)
if agent_hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -151,9 +164,9 @@ async def _run_action(span: Any | None) -> RunItem:
raise

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output),
hooks.on_tool_end(tool_context, agent, action.computer_tool, output),
(
agent_hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output)
agent_hooks.on_tool_end(tool_context, agent, action.computer_tool, output)
if agent_hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -374,10 +387,22 @@ async def execute(
) -> RunItem:
"""Run a local shell tool call and wrap the result as a ToolCallOutputItem."""
agent_hooks = agent.hooks
# Build a ToolContext so lifecycle hooks can attribute parallel tool calls via
# ``tool_call_id``. Local shell calls carry structured action fields, so we
# serialize the action payload for ``tool_arguments`` visibility.
tool_arguments = _serialize_trace_payload(call.tool_call.action)
tool_context = ToolContext.from_agent_context(
context_wrapper,
call.tool_call.call_id,
tool_name=call.local_shell_tool.name,
tool_arguments=tool_arguments,
agent=agent,
run_config=config,
)
await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool),
hooks.on_tool_start(tool_context, agent, call.local_shell_tool),
(
agent_hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool)
agent_hooks.on_tool_start(tool_context, agent, call.local_shell_tool)
if agent_hooks
else _coro.noop_coroutine()
),
Expand All @@ -391,9 +416,9 @@ async def execute(
result = await output if inspect.isawaitable(output) else output

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result),
hooks.on_tool_end(tool_context, agent, call.local_shell_tool, result),
(
agent_hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result)
agent_hooks.on_tool_end(tool_context, agent, call.local_shell_tool, result)
if agent_hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -428,6 +453,18 @@ async def execute(
shell_call = coerce_shell_call(call.tool_call)
shell_tool = call.shell_tool
agent_hooks = agent.hooks
# Build a ToolContext so lifecycle hooks can attribute parallel tool calls via
# ``tool_call_id``. Shell calls carry a structured action payload, so serialize
# it for ``tool_arguments`` visibility.
tool_arguments = _serialize_trace_payload(dataclasses.asdict(shell_call.action))
tool_context = ToolContext.from_agent_context(
context_wrapper,
shell_call.call_id,
tool_name=shell_tool.name,
tool_arguments=tool_arguments,
agent=agent,
run_config=config,
)

async def _run_call(span: Any | None) -> RunItem:
if span and config.trace_include_sensitive_data:
Expand Down Expand Up @@ -467,9 +504,9 @@ async def _run_call(span: Any | None) -> RunItem:
return approval_item

await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, shell_tool),
hooks.on_tool_start(tool_context, agent, shell_tool),
(
agent_hooks.on_tool_start(context_wrapper, agent, shell_tool)
agent_hooks.on_tool_start(tool_context, agent, shell_tool)
if agent_hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -541,9 +578,9 @@ async def _run_call(span: Any | None) -> RunItem:
logger.error("Shell executor failed: %s", exc, exc_info=True)

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text),
hooks.on_tool_end(tool_context, agent, call.shell_tool, output_text),
(
agent_hooks.on_tool_end(context_wrapper, agent, call.shell_tool, output_text)
agent_hooks.on_tool_end(tool_context, agent, call.shell_tool, output_text)
if agent_hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -747,6 +784,27 @@ async def execute(
context_wrapper=context_wrapper,
)
call_id = extract_apply_patch_call_id(call.tool_call)
# Build a ToolContext so lifecycle hooks can attribute parallel tool calls via
# ``tool_call_id``. Apply patch operations carry structured fields, so serialize
# the operation list for ``tool_arguments`` visibility.
tool_arguments = _serialize_trace_payload(
[
{
"type": operation.type,
"path": operation.path,
"diff": operation.diff,
}
for operation in operations
]
)
tool_context = ToolContext.from_agent_context(
context_wrapper,
call_id,
tool_name=apply_patch_tool.name,
tool_arguments=tool_arguments,
agent=agent,
run_config=config,
)

async def _run_call(span: Any | None) -> RunItem:
if span and config.trace_include_sensitive_data:
Expand Down Expand Up @@ -798,9 +856,9 @@ async def _run_call(span: Any | None) -> RunItem:
return approval_item

await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, apply_patch_tool),
hooks.on_tool_start(tool_context, agent, apply_patch_tool),
(
agent_hooks.on_tool_start(context_wrapper, agent, apply_patch_tool)
agent_hooks.on_tool_start(tool_context, agent, apply_patch_tool)
if agent_hooks
else _coro.noop_coroutine()
),
Expand Down Expand Up @@ -854,9 +912,9 @@ async def _run_call(span: Any | None) -> RunItem:
logger.error("Apply patch editor failed: %s", exc, exc_info=True)

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text),
hooks.on_tool_end(tool_context, agent, apply_patch_tool, output_text),
(
agent_hooks.on_tool_end(context_wrapper, agent, apply_patch_tool, output_text)
agent_hooks.on_tool_end(tool_context, agent, apply_patch_tool, output_text)
if agent_hooks
else _coro.noop_coroutine()
),
Expand Down
46 changes: 46 additions & 0 deletions tests/test_apply_patch_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,49 @@ class MultiOpCall:
assert raw_item["status"] == "failed"
assert "Failed a.md" in result.output
assert "Created b.md" in result.output


@pytest.mark.asyncio
async def test_apply_patch_action_exposes_tool_call_id_on_lifecycle_hooks() -> None:
    """Apply patch lifecycle hooks must receive a ToolContext with tool_call_id set.

    Executes a single ``update_file`` apply-patch call with hooks that record every
    context object they are handed, then checks that both ``on_tool_start`` and
    ``on_tool_end`` saw a ``ToolContext`` carrying the call id and the tool name.
    """
    from agents.tool_context import ToolContext

    # Contexts observed by each hook, in invocation order.
    captured: dict[str, list[Any]] = {"start": [], "end": []}

    class CapturingHooks(RunHooks[Any]):
        """Hooks that record the context passed to each tool lifecycle callback."""

        # The overrides are annotated with ``ToolContext`` so they match the hook
        # contract this test is pinning; the isinstance assertions below verify the
        # runtime type independently of these annotations.
        async def on_tool_start(
            self, context: ToolContext[Any], agent: Agent[Any], tool: Any
        ) -> None:
            captured["start"].append(context)

        async def on_tool_end(
            self,
            context: ToolContext[Any],
            agent: Agent[Any],
            tool: Any,
            result: str,
        ) -> None:
            captured["end"].append(context)

    editor = RecordingEditor()
    tool = ApplyPatchTool(editor=editor)
    agent, context_wrapper, tool_run = build_apply_patch_call(
        tool,
        "call_apply_ctx",
        {"type": "update_file", "path": "tasks.md", "diff": "-a\n+b\n"},
    )

    await ApplyPatchAction.execute(
        agent=agent,
        call=tool_run,
        hooks=CapturingHooks(),
        context_wrapper=context_wrapper,
        config=RunConfig(),
    )

    # Exactly one start and one end notification for the single call.
    assert len(captured["start"]) == 1
    assert len(captured["end"]) == 1
    for context in (captured["start"][0], captured["end"][0]):
        assert isinstance(context, ToolContext)
        assert context.tool_call_id == "call_apply_ctx"
        assert context.tool_name == tool.name
Loading