Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ Typical hook timing:
- `on_agent_start` / `on_agent_end`: when a specific agent begins or finishes producing a final output.
- `on_llm_start` / `on_llm_end`: immediately around each model call.
- `on_tool_start` / `on_tool_end`: around each local tool invocation.
For function tools, the hook `context` is typically a `ToolContext`, so you can inspect tool-call metadata such as `tool_call_id`.
- `on_handoff`: when control moves from one agent to another.

Use `RunHooks` when you want a single observer for the whole workflow, and `AgentHooks` when one agent needs custom side effects.
Expand Down
2 changes: 2 additions & 0 deletions docs/context.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ This is represented via the [`RunContextWrapper`][agents.run_context.RunContextW
2. You pass that object to the various run methods (e.g. `Runner.run(..., context=whatever)`).
3. All your tool calls, lifecycle hooks etc will be passed a wrapper object, `RunContextWrapper[T]`, where `T` represents your context object type which you can access via `wrapper.context`.

For some runtime-specific callbacks, the SDK may pass a more specialized subclass of `RunContextWrapper[T]`. For example, function-tool lifecycle hooks typically receive `ToolContext`, which also exposes tool-call metadata like `tool_call_id`, `tool_name`, and `tool_arguments`.

The **most important** thing to be aware of: every agent, tool function, lifecycle etc for a given agent run must use the same _type_ of context.

You can use the context for things like:
Expand Down
32 changes: 28 additions & 4 deletions src/agents/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ async def on_tool_start(
agent: TAgent,
tool: Tool,
) -> None:
"""Called immediately before a local tool is invoked."""
"""Called immediately before a local tool is invoked.

For function-tool invocations, ``context`` is typically a ``ToolContext`` instance,
which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``,
and ``tool_arguments``. Other local tool families may provide a plain
``RunContextWrapper`` instead.
"""
pass

async def on_tool_end(
Expand All @@ -83,7 +89,13 @@ async def on_tool_end(
tool: Tool,
result: str,
) -> None:
"""Called immediately after a local tool is invoked."""
"""Called immediately after a local tool is invoked.

For function-tool invocations, ``context`` is typically a ``ToolContext`` instance,
which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``,
and ``tool_arguments``. Other local tool families may provide a plain
``RunContextWrapper`` instead.
"""
pass


Expand Down Expand Up @@ -135,7 +147,13 @@ async def on_tool_start(
agent: TAgent,
tool: Tool,
) -> None:
"""Called immediately before a local tool is invoked."""
"""Called immediately before a local tool is invoked.

For function-tool invocations, ``context`` is typically a ``ToolContext`` instance,
which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``,
and ``tool_arguments``. Other local tool families may provide a plain
``RunContextWrapper`` instead.
"""
pass

async def on_tool_end(
Expand All @@ -145,7 +163,13 @@ async def on_tool_end(
tool: Tool,
result: str,
) -> None:
"""Called immediately after a local tool is invoked."""
"""Called immediately after a local tool is invoked.

For function-tool invocations, ``context`` is typically a ``ToolContext`` instance,
which exposes tool-call-specific metadata such as ``tool_call_id``, ``tool_name``,
and ``tool_arguments``. Other local tool families may provide a plain
``RunContextWrapper`` instead.
"""
pass

async def on_llm_start(
Expand Down
18 changes: 18 additions & 0 deletions tests/test_agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from agents.run import Runner
from agents.run_context import AgentHookContext, RunContextWrapper, TContext
from agents.tool import Tool
from agents.tool_context import ToolContext

from .fake_model import FakeModel
from .test_responses import (
Expand All @@ -26,9 +27,11 @@
class AgentHooksForTests(AgentHooks):
def __init__(self):
self.events: dict[str, int] = defaultdict(int)
self.tool_context_ids: list[str] = []

def reset(self):
self.events.clear()
self.tool_context_ids.clear()

async def on_start(self, context: AgentHookContext[TContext], agent: Agent[TContext]) -> None:
self.events["on_start"] += 1
Expand Down Expand Up @@ -56,6 +59,8 @@ async def on_tool_start(
tool: Tool,
) -> None:
self.events["on_tool_start"] += 1
if isinstance(context, ToolContext):
self.tool_context_ids.append(context.tool_call_id)

async def on_tool_end(
self,
Expand All @@ -65,6 +70,8 @@ async def on_tool_end(
result: str,
) -> None:
self.events["on_tool_end"] += 1
if isinstance(context, ToolContext):
self.tool_context_ids.append(context.tool_call_id)


@pytest.mark.asyncio
Expand Down Expand Up @@ -94,6 +101,17 @@ async def test_non_streamed_agent_hooks():
assert hooks.events == {"on_start": 1, "on_end": 1}, f"{output}"
hooks.reset()

model.add_multiple_turn_outputs(
[
[get_function_tool_call("some_function", json.dumps({"a": "b"}))],
[get_text_message("done")],
]
)
await Runner.run(agent_3, input="user_message")
assert len(hooks.tool_context_ids) == 2
assert len(set(hooks.tool_context_ids)) == 1
hooks.reset()

model.add_multiple_turn_outputs(
[
# First turn: a tool call
Expand Down
18 changes: 18 additions & 0 deletions tests/test_global_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing_extensions import TypedDict

from agents import Agent, RunContextWrapper, RunHooks, Runner, TContext, Tool
from agents.tool_context import ToolContext

from .fake_model import FakeModel
from .test_responses import (
Expand All @@ -22,9 +23,11 @@
class RunHooksForTests(RunHooks):
def __init__(self):
self.events: dict[str, int] = defaultdict(int)
self.tool_context_ids: list[str] = []

def reset(self):
self.events.clear()
self.tool_context_ids.clear()

async def on_agent_start(
self, context: RunContextWrapper[TContext], agent: Agent[TContext]
Expand Down Expand Up @@ -54,6 +57,8 @@ async def on_tool_start(
tool: Tool,
) -> None:
self.events["on_tool_start"] += 1
if isinstance(context, ToolContext):
self.tool_context_ids.append(context.tool_call_id)

async def on_tool_end(
self,
Expand All @@ -63,6 +68,8 @@ async def on_tool_end(
result: str,
) -> None:
self.events["on_tool_end"] += 1
if isinstance(context, ToolContext):
self.tool_context_ids.append(context.tool_call_id)


@pytest.mark.asyncio
Expand All @@ -85,6 +92,17 @@ async def test_non_streamed_agent_hooks():
assert hooks.events == {"on_agent_start": 1, "on_agent_end": 1}, f"{output}"
hooks.reset()

model.add_multiple_turn_outputs(
[
[get_function_tool_call("some_function", json.dumps({"a": "b"}))],
[get_text_message("done")],
]
)
await Runner.run(agent_3, input="user_message", hooks=hooks)
assert len(hooks.tool_context_ids) == 2
assert len(set(hooks.tool_context_ids)) == 1
hooks.reset()

model.add_multiple_turn_outputs(
[
# First turn: a tool call
Expand Down
5 changes: 5 additions & 0 deletions tests/test_run_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from agents.run import Runner
from agents.run_context import AgentHookContext, RunContextWrapper, TContext
from agents.tool import Tool
from agents.tool_context import ToolContext
from tests.test_agent_llm_hooks import AgentHooksForTests

from .fake_model import FakeModel
Expand All @@ -22,9 +23,11 @@
class RunHooksForTests(RunHooks):
def __init__(self):
self.events: dict[str, int] = defaultdict(int)
self.tool_context_ids: list[str] = []

def reset(self):
self.events.clear()
self.tool_context_ids.clear()

async def on_agent_start(
self, context: AgentHookContext[TContext], agent: Agent[TContext]
Expand Down Expand Up @@ -57,6 +60,8 @@ async def on_tool_end(
result: str,
) -> None:
self.events["on_tool_end"] += 1
if isinstance(context, ToolContext):
self.tool_context_ids.append(context.tool_call_id)

async def on_llm_start(
self,
Expand Down