Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 76 additions & 23 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import contextlib
import dataclasses
import inspect
from collections.abc import Awaitable
Expand Down Expand Up @@ -226,6 +227,29 @@ def get_model_tracing_impl(
return ModelTracing.ENABLED_WITHOUT_DATA


# Helpers for cancellable tool execution


async def _await_cancellable(awaitable):
"""Await an awaitable in its own task so CancelledError interrupts promptly."""
task = asyncio.create_task(awaitable)
try:
return await task
except asyncio.CancelledError:
# propagate so run.py can handle terminal cancel
raise


def _maybe_call_cancel_hook(tool_obj) -> None:
"""Best-effort: call a cancel/terminate hook on the tool if present."""
for name in ("cancel", "terminate", "stop"):
cb = getattr(tool_obj, name, None)
if callable(cb):
with contextlib.suppress(Exception):
cb()
break


class RunImpl:
@classmethod
async def execute_tools_and_side_effects(
Expand Down Expand Up @@ -572,16 +596,24 @@ async def run_single_tool(
if config.trace_include_sensitive_data:
span_fn.span_data.input = tool_call.arguments
try:
_, _, result = await asyncio.gather(
# run start hooks first (don’t tie them to the cancellable task)
await asyncio.gather(
hooks.on_tool_start(tool_context, agent, func_tool),
(
agent.hooks.on_tool_start(tool_context, agent, func_tool)
if agent.hooks
else _coro.noop_coroutine()
),
func_tool.on_invoke_tool(tool_context, tool_call.arguments),
)

try:
result = await _await_cancellable(
func_tool.on_invoke_tool(tool_context, tool_call.arguments)
)
except asyncio.CancelledError:
_maybe_call_cancel_hook(func_tool)
raise

await asyncio.gather(
hooks.on_tool_end(tool_context, agent, func_tool, result),
(
Expand All @@ -590,6 +622,7 @@ async def run_single_tool(
else _coro.noop_coroutine()
),
)

except Exception as e:
_error_tracing.attach_error_to_current_span(
SpanError(
Expand Down Expand Up @@ -660,7 +693,6 @@ async def execute_computer_actions(
config: RunConfig,
) -> list[RunItem]:
results: list[RunItem] = []
# Need to run these serially, because each action can affect the computer state
for action in actions:
acknowledged: list[ComputerCallOutputAcknowledgedSafetyCheck] | None = None
if action.tool_call.pending_safety_checks and action.computer_tool.on_safety_check:
Expand All @@ -677,24 +709,28 @@ async def execute_computer_actions(
if ack:
acknowledged.append(
ComputerCallOutputAcknowledgedSafetyCheck(
id=check.id,
code=check.code,
message=check.message,
id=check.id, code=check.code, message=check.message
)
)
else:
raise UserError("Computer tool safety check was not acknowledged")

results.append(
await ComputerAction.execute(
agent=agent,
action=action,
hooks=hooks,
context_wrapper=context_wrapper,
config=config,
acknowledged_safety_checks=acknowledged,
try:
item = await _await_cancellable(
ComputerAction.execute(
agent=agent,
action=action,
hooks=hooks,
context_wrapper=context_wrapper,
config=config,
acknowledged_safety_checks=acknowledged,
)
)
)
except asyncio.CancelledError:
_maybe_call_cancel_hook(action.computer_tool)
raise

results.append(item)

return results

Expand Down Expand Up @@ -1068,16 +1104,23 @@ async def execute(
else cls._get_screenshot_sync(action.computer_tool.computer, action.tool_call)
)

_, _, output = await asyncio.gather(
# start hooks first
await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, action.computer_tool),
(
agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
if agent.hooks
else _coro.noop_coroutine()
),
output_func,
)

# run the action (screenshot/etc) in a cancellable task
try:
output = await _await_cancellable(output_func)
except asyncio.CancelledError:
_maybe_call_cancel_hook(action.computer_tool)
raise

# end hooks
await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output),
(
Expand Down Expand Up @@ -1185,10 +1228,20 @@ async def execute(
data=call.tool_call,
)
output = call.local_shell_tool.executor(request)
if inspect.isawaitable(output):
result = await output
else:
result = output
try:
if inspect.isawaitable(output):
result = await _await_cancellable(output)
else:
# If executor returns a sync result, just use it (can’t cancel mid-call)
result = output
except asyncio.CancelledError:
# Best-effort: if the executor or tool exposes a cancel/terminate, call it
_maybe_call_cancel_hook(call.local_shell_tool)
# If your executor returns a proc handle (common pattern), adddress it here if needed:
# with contextlib.suppress(Exception):
# proc.terminate(); await asyncio.wait_for(proc.wait(), 1.0)
# proc.kill()
raise

await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result),
Expand All @@ -1201,7 +1254,7 @@ async def execute(

return ToolCallOutputItem(
agent=agent,
output=output,
output=result,
raw_item={
"type": "local_shell_call_output",
"id": call.tool_call.call_id,
Expand Down
24 changes: 20 additions & 4 deletions src/agents/models/openai_responses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import json
from collections.abc import AsyncIterator
from dataclasses import dataclass
Expand Down Expand Up @@ -175,15 +176,30 @@ async def stream_response(

final_response: Response | None = None

async for chunk in stream:
if isinstance(chunk, ResponseCompletedEvent):
final_response = chunk.response
yield chunk
try:
async for chunk in stream: # ensure type checkers relax here
if isinstance(chunk, ResponseCompletedEvent):
final_response = chunk.response
yield chunk
except asyncio.CancelledError:
# Cooperative cancel: ensure the HTTP stream is closed, then propagate
try:
await stream.close()
except Exception:
pass
raise
finally:
# Always close the stream if the async iterator exits (normal or error)
try:
await stream.close()
except Exception:
pass

if final_response and tracing.include_data():
span_response.span_data.response = final_response
span_response.span_data.input = input


except Exception as e:
span_response.set_error(
SpanError(
Expand Down
73 changes: 60 additions & 13 deletions src/agents/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import abc
import asyncio
import contextlib
from collections.abc import AsyncIterator
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, cast
Expand Down Expand Up @@ -143,6 +144,12 @@ class RunResultStreaming(RunResultBase):
is_complete: bool = False
"""Whether the agent has finished running."""

_emit_status_events: bool = False
"""Whether to emit RunUpdatedStreamEvent status updates.

Defaults to False for backward compatibility.
"""

# Queues that the background run_loop writes to
_event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] = field(
default_factory=asyncio.Queue, repr=False
Expand All @@ -164,17 +171,45 @@ def last_agent(self) -> Agent[Any]:
"""
return self.current_agent

def cancel(self) -> None:
"""Cancels the streaming run, stopping all background tasks and marking the run as
complete."""
self._cleanup_tasks() # Cancel all running tasks
self.is_complete = True # Mark the run as complete to stop event streaming
def cancel(self, reason: str | None = None) -> None:
# 1) Signal cooperative cancel to the runner
active = getattr(self, "_active_run", None)
if active:
with contextlib.suppress(Exception):
active.cancel(reason)
# 2) Do NOT cancel the background task; let the loop unwind cooperatively
# task = getattr(self, "_run_impl_task", None)
# if task and not task.done():
# with contextlib.suppress(Exception):
# task.cancel()

# 4) Mark complete; flushing only when status events are disabled
self.is_complete = True
if not getattr(self, "_emit_status_events", False):
with contextlib.suppress(Exception):
while not self._event_queue.empty():
self._event_queue.get_nowait()
self._event_queue.task_done()
with contextlib.suppress(Exception):
while not self._input_guardrail_queue.empty():
self._input_guardrail_queue.get_nowait()
self._input_guardrail_queue.task_done()

def inject(self, items: list[TResponseInputItem]) -> None:
"""
Inject new input items mid-run. They will be consumed at the start of the next step.
"""
active = getattr(self, "_active_run", None)
if active is not None:
try:
active.inject(items)
except Exception:
pass

# Optionally, clear the event queue to prevent processing stale events
while not self._event_queue.empty():
self._event_queue.get_nowait()
while not self._input_guardrail_queue.empty():
self._input_guardrail_queue.get_nowait()
@property
def active_run(self):
"""Access the underlying ActiveRun handle (may be None early in startup)."""
return getattr(self, "_active_run", None)

async def stream_events(self) -> AsyncIterator[StreamEvent]:
"""Stream deltas for new items as they are generated. We're using the types from the
Expand Down Expand Up @@ -243,21 +278,33 @@ def _check_errors(self):
# Check the tasks for any exceptions
if self._run_impl_task and self._run_impl_task.done():
run_impl_exc = self._run_impl_task.exception()
if run_impl_exc and isinstance(run_impl_exc, Exception):
if (
run_impl_exc
and isinstance(run_impl_exc, Exception)
and not isinstance(run_impl_exc, asyncio.CancelledError)
):
if isinstance(run_impl_exc, AgentsException) and run_impl_exc.run_data is None:
run_impl_exc.run_data = self._create_error_details()
self._stored_exception = run_impl_exc

if self._input_guardrails_task and self._input_guardrails_task.done():
in_guard_exc = self._input_guardrails_task.exception()
if in_guard_exc and isinstance(in_guard_exc, Exception):
if (
in_guard_exc
and isinstance(in_guard_exc, Exception)
and not isinstance(in_guard_exc, asyncio.CancelledError)
):
if isinstance(in_guard_exc, AgentsException) and in_guard_exc.run_data is None:
in_guard_exc.run_data = self._create_error_details()
self._stored_exception = in_guard_exc

if self._output_guardrails_task and self._output_guardrails_task.done():
out_guard_exc = self._output_guardrails_task.exception()
if out_guard_exc and isinstance(out_guard_exc, Exception):
if (
out_guard_exc
and isinstance(out_guard_exc, Exception)
and not isinstance(out_guard_exc, asyncio.CancelledError)
):
if isinstance(out_guard_exc, AgentsException) and out_guard_exc.run_data is None:
out_guard_exc.run_data = self._create_error_details()
self._stored_exception = out_guard_exc
Expand Down
Loading