From 15a304514b770eb0110a354c4aa4d30ff140298c Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Tue, 17 Mar 2026 09:11:21 +0900 Subject: [PATCH 1/2] fix: stop streamed tool execution after known input guardrail tripwire --- src/agents/result.py | 1 + src/agents/run_internal/guardrails.py | 1 + src/agents/run_internal/run_loop.py | 23 ++++++++- src/agents/run_internal/turn_resolution.py | 4 ++ tests/test_guardrails.py | 56 ++++++++++++++++++++++ 5 files changed, 84 insertions(+), 1 deletion(-) diff --git a/src/agents/result.py b/src/agents/result.py index 774c90dc4e..559504aa4b 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -454,6 +454,7 @@ class RunResultStreaming(RunResultBase): # Store the asyncio tasks that we're waiting on run_loop_task: asyncio.Task[Any] | None = field(default=None, repr=False) _input_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False) + _triggered_input_guardrail_result: InputGuardrailResult | None = field(default=None, repr=False) _output_guardrails_task: asyncio.Task[Any] | None = field(default=None, repr=False) _stored_exception: Exception | None = field(default=None, repr=False) _cancel_mode: Literal["none", "immediate", "after_turn"] = field(default="none", repr=False) diff --git a/src/agents/run_internal/guardrails.py b/src/agents/run_internal/guardrails.py index 375cc37c25..78b372f94f 100644 --- a/src/agents/run_internal/guardrails.py +++ b/src/agents/run_internal/guardrails.py @@ -84,6 +84,7 @@ async def run_input_guardrails_with_queue( }, ), ) + streamed_result._triggered_input_guardrail_result = result queue.put_nowait(result) guardrail_results.append(result) break diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 3d21d89fda..3abf337d28 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -1014,7 +1014,9 @@ async def _save_stream_items_without_count( streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) break except Exception as e: - if current_span and not isinstance(e, ModelBehaviorError): + if current_span and not isinstance( + e, (ModelBehaviorError, InputGuardrailTripwireTriggered) + ): _error_tracing.attach_error_to_span( current_span, SpanError( @@ -1100,6 +1102,24 @@ async def run_single_turn_streamed( reasoning_item_id_policy: ReasoningItemIdPolicy | None = None, ) -> SingleStepResult: """Run a single streamed turn and emit events as results arrive.""" + + async def raise_if_input_guardrail_tripwire_known() -> None: + tripwire_result = streamed_result._triggered_input_guardrail_result + if tripwire_result is not None: + raise InputGuardrailTripwireTriggered(tripwire_result) + + task = streamed_result._input_guardrails_task + if task is None or not task.done(): + return + + guardrail_exception = task.exception() + if guardrail_exception is not None: + raise guardrail_exception + + tripwire_result = streamed_result._triggered_input_guardrail_result + if tripwire_result is not None: + raise InputGuardrailTripwireTriggered(tripwire_result) + emitted_tool_call_ids: set[str] = set() emitted_reasoning_item_ids: set[str] = set() emitted_tool_search_fingerprints: set[str] = set() @@ -1433,6 +1453,7 @@ async def rewind_model_request() -> None: run_config=run_config, tool_use_tracker=tool_use_tracker, event_queue=streamed_result._event_queue, + before_side_effects=raise_if_input_guardrail_tripwire_known, ) items_to_filter = session_items_for_turn(single_step_result) diff --git a/src/agents/run_internal/turn_resolution.py b/src/agents/run_internal/turn_resolution.py index 2b3f98b55b..e715854f48 100644 --- a/src/agents/run_internal/turn_resolution.py +++ b/src/agents/run_internal/turn_resolution.py @@ -1696,6 +1696,7 @@ async def get_single_step_result_from_response( run_config: RunConfig, tool_use_tracker, event_queue: asyncio.Queue[StreamEvent | QueueCompleteSentinel] | None = None, + before_side_effects: Callable[[], Awaitable[None]] | None = None, ) -> SingleStepResult: processed_response = process_model_response( agent=agent, @@ -1706,6 +1707,9 @@ async def get_single_step_result_from_response( existing_items=pre_step_items, ) + if before_side_effects is not None: + await before_side_effects() + tool_use_tracker.record_processed_response(agent, processed_response) if event_queue is not None and processed_response.new_items: diff --git a/tests/test_guardrails.py b/tests/test_guardrails.py index 1b7d4a4225..6ec0ceed5e 100644 --- a/tests/test_guardrails.py +++ b/tests/test_guardrails.py @@ -658,6 +658,62 @@ async def slow_parallel_check( assert model.first_turn_args is not None, "Model should have been called in parallel mode" +@pytest.mark.asyncio +async def test_parallel_guardrail_trip_before_tool_execution_stops_streaming_turn(): + tool_was_executed = False + model_started = asyncio.Event() + guardrail_tripped = asyncio.Event() + + @function_tool + def dangerous_tool() -> str: + nonlocal tool_was_executed + tool_was_executed = True + return "tool_executed" + + @input_guardrail(run_in_parallel=True) + async def tripwire_before_tool_execution( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + await asyncio.wait_for(model_started.wait(), timeout=1) + guardrail_tripped.set() + return GuardrailFunctionOutput( + output_info="parallel_trip_before_tool_execution", + tripwire_triggered=True, + ) + + model = FakeModel() + original_stream_response = model.stream_response + + async def delayed_stream_response(*args, **kwargs): + model_started.set() + await asyncio.wait_for(guardrail_tripped.wait(), timeout=1) + await asyncio.sleep(SHORT_DELAY) + async for event in original_stream_response(*args, **kwargs): + yield event + + agent = Agent( + name="streaming_guardrail_hardening_agent", + instructions="Call the dangerous_tool immediately", + tools=[dangerous_tool], + input_guardrails=[tripwire_before_tool_execution], + model=model, + ) + model.set_next_output([get_function_tool_call("dangerous_tool", arguments="{}")]) + model.set_next_output([get_text_message("done")]) + + with patch.object(model, "stream_response", side_effect=delayed_stream_response): + result = Runner.run_streamed(agent, "trigger guardrail") + + with pytest.raises(InputGuardrailTripwireTriggered): + async for _event in result.stream_events(): + pass + + assert model_started.is_set() is True + assert guardrail_tripped.is_set() is True + assert tool_was_executed is False + assert model.first_turn_args is not None, "Model should have been called in parallel mode" + + @pytest.mark.asyncio async def test_blocking_guardrail_prevents_tool_execution(): tool_was_executed = False From 60856f69551202937706f3c4124c3623298831f4 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Tue, 17 Mar 2026 09:39:21 +0900 Subject: [PATCH 2/2] fix review comments --- src/agents/run_internal/guardrails.py | 2 +- tests/test_guardrails.py | 76 +++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/src/agents/run_internal/guardrails.py b/src/agents/run_internal/guardrails.py index 78b372f94f..59e008180c 100644 --- a/src/agents/run_internal/guardrails.py +++ b/src/agents/run_internal/guardrails.py @@ -71,6 +71,7 @@ async def run_input_guardrails_with_queue( for done in asyncio.as_completed(guardrail_tasks): result = await done if result.output.tripwire_triggered: + streamed_result._triggered_input_guardrail_result = result for t in guardrail_tasks: t.cancel() await asyncio.gather(*guardrail_tasks, return_exceptions=True) @@ -84,7 +85,6 @@ async def run_input_guardrails_with_queue( }, ), ) - streamed_result._triggered_input_guardrail_result = result queue.put_nowait(result) guardrail_results.append(result) break diff --git a/tests/test_guardrails.py b/tests/test_guardrails.py index 6ec0ceed5e..0789bdcfa0 100644 --- a/tests/test_guardrails.py +++ b/tests/test_guardrails.py @@ -714,6 +714,82 @@ async def delayed_stream_response(*args, **kwargs): assert model.first_turn_args is not None, "Model should have been called in parallel mode" +@pytest.mark.asyncio +async def test_parallel_guardrail_trip_with_slow_cancel_sibling_stops_streaming_turn(): + tool_was_executed = False + model_started = asyncio.Event() + guardrail_tripped = asyncio.Event() + slow_cancel_started = asyncio.Event() + slow_cancel_finished = asyncio.Event() + + @function_tool + def dangerous_tool() -> str: + nonlocal tool_was_executed + tool_was_executed = True + return "tool_executed" + + @input_guardrail(run_in_parallel=True) + async def tripwire_before_tool_execution( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + await asyncio.wait_for(model_started.wait(), timeout=1) + guardrail_tripped.set() + return GuardrailFunctionOutput( + output_info="parallel_trip_before_tool_execution_with_slow_cancel", + tripwire_triggered=True, + ) + + @input_guardrail(run_in_parallel=True) + async def slow_to_cancel_guardrail( + ctx: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem] + ) -> GuardrailFunctionOutput: + try: + await asyncio.Event().wait() + return GuardrailFunctionOutput( + output_info="slow_to_cancel_guardrail_completed", + tripwire_triggered=False, + ) + except asyncio.CancelledError: + slow_cancel_started.set() + await asyncio.sleep(SHORT_DELAY) + slow_cancel_finished.set() + raise + + model = FakeModel() + original_stream_response = model.stream_response + + async def delayed_stream_response(*args, **kwargs): + model_started.set() + await asyncio.wait_for(guardrail_tripped.wait(), timeout=1) + await asyncio.wait_for(slow_cancel_started.wait(), timeout=1) + async for event in original_stream_response(*args, **kwargs): + yield event + + agent = Agent( + name="streaming_guardrail_slow_cancel_agent", + instructions="Call the dangerous_tool immediately", + tools=[dangerous_tool], + input_guardrails=[tripwire_before_tool_execution, slow_to_cancel_guardrail], + model=model, + ) + model.set_next_output([get_function_tool_call("dangerous_tool", arguments="{}")]) + model.set_next_output([get_text_message("done")]) + + with patch.object(model, "stream_response", side_effect=delayed_stream_response): + result = Runner.run_streamed(agent, "trigger guardrail") + + with pytest.raises(InputGuardrailTripwireTriggered): + async for _event in result.stream_events(): + pass + + assert model_started.is_set() is True + assert guardrail_tripped.is_set() is True + assert slow_cancel_started.is_set() is True + assert slow_cancel_finished.is_set() is True + assert tool_was_executed is False + assert model.first_turn_args is not None, "Model should have been called in parallel mode" + + @pytest.mark.asyncio async def test_blocking_guardrail_prevents_tool_execution(): tool_was_executed = False