From 1e8b70a05e10d10a87f68b6f7fa3b2a010d0797e Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Tue, 17 Mar 2026 15:27:43 +0900 Subject: [PATCH 1/5] feat: add approval override argument parity for RunState resumes --- src/agents/memory/__init__.py | 16 + .../memory/openai_conversations_session.py | 1 + .../openai_responses_compaction_session.py | 151 ++++++- src/agents/memory/session.py | 102 +++++ src/agents/memory/sqlite_session.py | 67 ++- src/agents/result.py | 16 + src/agents/run.py | 48 ++- .../run_internal/agent_runner_helpers.py | 46 +- src/agents/run_internal/run_loop.py | 1 + .../run_internal/session_persistence.py | 140 ++++-- src/agents/run_state.py | 398 +++++++++++++++++- ...est_openai_responses_compaction_session.py | 175 +++++++- tests/test_agent_runner.py | 154 ++++++- tests/test_agent_runner_streamed.py | 74 +++- tests/test_agent_tracing.py | 96 +++++ tests/test_run_state.py | 155 ++++++- tests/utils/simple_session.py | 20 +- 17 files changed, 1592 insertions(+), 68 deletions(-) diff --git a/src/agents/memory/__init__.py b/src/agents/memory/__init__.py index 909a907134..fb58eadb57 100644 --- a/src/agents/memory/__init__.py +++ b/src/agents/memory/__init__.py @@ -1,11 +1,19 @@ from .openai_conversations_session import OpenAIConversationsSession from .openai_responses_compaction_session import OpenAIResponsesCompactionSession from .session import ( + SERVER_MANAGED_CONVERSATION_SESSION_ATTR, OpenAIResponsesCompactionArgs, OpenAIResponsesCompactionAwareSession, + ServerManagedConversationSession, Session, SessionABC, + SessionHistoryMutation, + SessionHistoryRewriteArgs, + SessionHistoryRewriteAwareSession, + apply_session_history_mutations, is_openai_responses_compaction_aware_session, + is_server_managed_conversation_session, + is_session_history_rewrite_aware_session, ) from .session_settings import SessionSettings from .sqlite_session import SQLiteSession @@ -21,5 +29,13 @@ "OpenAIResponsesCompactionSession", "OpenAIResponsesCompactionArgs", "OpenAIResponsesCompactionAwareSession", + "SERVER_MANAGED_CONVERSATION_SESSION_ATTR", + "SessionHistoryMutation", + "SessionHistoryRewriteArgs", + "SessionHistoryRewriteAwareSession", + "ServerManagedConversationSession", + "apply_session_history_mutations", + "is_server_managed_conversation_session", "is_openai_responses_compaction_aware_session", + "is_session_history_rewrite_aware_session", ] diff --git a/src/agents/memory/openai_conversations_session.py b/src/agents/memory/openai_conversations_session.py index 4d4fbaf635..5bd9288683 100644 --- a/src/agents/memory/openai_conversations_session.py +++ b/src/agents/memory/openai_conversations_session.py @@ -21,6 +21,7 @@ async def start_openai_conversations_session(openai_client: AsyncOpenAI | None = class OpenAIConversationsSession(SessionABC): + _server_managed_conversation_session = True session_settings: SessionSettings | None = None def __init__( diff --git a/src/agents/memory/openai_responses_compaction_session.py b/src/agents/memory/openai_responses_compaction_session.py index e2148f4868..a41f78a8f6 100644 --- a/src/agents/memory/openai_responses_compaction_session.py +++ b/src/agents/memory/openai_responses_compaction_session.py @@ -11,6 +11,9 @@ OpenAIResponsesCompactionArgs, OpenAIResponsesCompactionAwareSession, SessionABC, + SessionHistoryRewriteArgs, + apply_session_history_mutations, + is_session_history_rewrite_aware_session, ) if TYPE_CHECKING: @@ -129,7 +132,10 @@ def __init__( self._session_items: list[TResponseInputItem] | None = None self._response_id: str | None = None self._deferred_response_id: str | None = None - self._last_unstored_response_id: str | None = None + self._last_store: bool | None = None + self._has_pending_local_history_rewrite = False + self._local_history_rewrite_response_id: str | None = None + self._has_unacknowledged_local_session_adds = False @property def client(self) -> AsyncOpenAI: @@ -137,40 +143,76 @@ def client(self) -> AsyncOpenAI: self._client = get_default_openai_client() or AsyncOpenAI() return self._client - def _resolve_compaction_mode_for_response( + def _resolve_compaction_mode( self, *, + requested_mode: OpenAIResponsesCompactionMode, response_id: str | None, store: bool | None, - requested_mode: OpenAIResponsesCompactionMode | None, + turn_has_local_adds_without_new_response_id: bool, ) -> _ResolvedCompactionMode: - mode = requested_mode or self.compaction_mode + resolved_mode = _resolve_compaction_mode( + requested_mode, + response_id=response_id, + store=store, + ) + + if turn_has_local_adds_without_new_response_id and resolved_mode == "previous_response_id": + self._has_unacknowledged_local_session_adds = False + self._mark_local_history_rewrite() + logger.debug( + "compact: forcing input mode after local session delta without new response id" + ) + return "input" + + if not self._has_pending_local_history_rewrite: + return resolved_mode + if ( - mode == "auto" - and store is None + self._local_history_rewrite_response_id is not None and response_id is not None - and response_id == self._last_unstored_response_id + and response_id != self._local_history_rewrite_response_id ): + self._has_pending_local_history_rewrite = False + self._local_history_rewrite_response_id = None + return resolved_mode + + if resolved_mode == "previous_response_id": + if self._local_history_rewrite_response_id is None and response_id is not None: + self._local_history_rewrite_response_id = response_id + logger.debug("compact: forcing input mode after local history rewrite") return "input" - return _resolve_compaction_mode(mode, response_id=response_id, store=store) + + return resolved_mode async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None) -> None: """Run compaction using responses.compact API.""" + previous_response_id = self._response_id if args and args.get("response_id"): self._response_id = args["response_id"] requested_mode = args.get("compaction_mode") if args else None if args and "store" in args: - store = args["store"] - if store is False and self._response_id: - self._last_unstored_response_id = self._response_id - elif store is True and self._response_id == self._last_unstored_response_id: - self._last_unstored_response_id = None + store: bool | None = args["store"] + self._last_store = store else: - store = None - resolved_mode = self._resolve_compaction_mode_for_response( + store = self._last_store + turn_has_local_adds_without_new_response_id = ( + self._has_unacknowledged_local_session_adds + and (args is None or args.get("response_id") in {None, previous_response_id}) + ) + if ( + args + and args.get("response_id") is not None + and args["response_id"] != previous_response_id + ): + self._has_unacknowledged_local_session_adds = False + resolved_mode = self._resolve_compaction_mode( response_id=self._response_id, store=store, - requested_mode=requested_mode, + requested_mode=requested_mode or self.compaction_mode, + turn_has_local_adds_without_new_response_id=( + turn_has_local_adds_without_new_response_id + ), ) if resolved_mode == "previous_response_id" and not self._response_id: @@ -198,6 +240,15 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None ) return + unresolved_function_calls = _find_unresolved_function_calls_without_results(session_items) + if unresolved_function_calls: + logger.debug( + "compact: blocked unresolved function calls for %s: %s", + self._response_id, + unresolved_function_calls, + ) + return + self._deferred_response_id = None logger.debug( f"compact: start for {self._response_id} using {self.model} (mode={resolved_mode})" @@ -239,14 +290,37 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return await self.underlying_session.get_items(limit) + async def apply_history_mutations(self, args: SessionHistoryRewriteArgs) -> None: + """Rewrite persisted history and keep compaction caches aligned with the new transcript.""" + mutations = list(args.get("mutations", [])) + if not mutations: + return + + if is_session_history_rewrite_aware_session(self.underlying_session): + await self.underlying_session.apply_history_mutations({"mutations": mutations}) + await self._refresh_caches_from_underlying_session() + self._mark_local_history_rewrite() + return + + rewritten_items = apply_session_history_mutations( + await self.underlying_session.get_items(), + mutations, + ) + await self.underlying_session.clear_session() + if rewritten_items: + await self.underlying_session.add_items(rewritten_items) + self._session_items = rewritten_items + self._compaction_candidate_items = select_compaction_candidate_items(rewritten_items) + self._mark_local_history_rewrite() + async def _defer_compaction(self, response_id: str, store: bool | None = None) -> None: if self._deferred_response_id is not None: return compaction_candidate_items, session_items = await self._ensure_compaction_candidates() - resolved_mode = self._resolve_compaction_mode_for_response( + resolved_mode = _resolve_compaction_mode( + self.compaction_mode, response_id=response_id, - store=store, - requested_mode=None, + store=store if store is not None else self._last_store, ) should_compact = self.should_trigger_compaction( { @@ -266,7 +340,10 @@ def _clear_deferred_compaction(self) -> None: self._deferred_response_id = None async def add_items(self, items: list[TResponseInputItem]) -> None: + if not items: + return await self.underlying_session.add_items(items) + self._has_unacknowledged_local_session_adds = True if self._compaction_candidate_items is not None: new_candidates = select_compaction_candidate_items(items) if new_candidates: @@ -286,6 +363,15 @@ async def clear_session(self) -> None: self._compaction_candidate_items = [] self._session_items = [] self._deferred_response_id = None + self._has_pending_local_history_rewrite = False + self._local_history_rewrite_response_id = None + self._has_unacknowledged_local_session_adds = False + self._last_store = None + + async def _refresh_caches_from_underlying_session(self) -> None: + history = await self.underlying_session.get_items() + self._session_items = history + self._compaction_candidate_items = select_compaction_candidate_items(history) async def _ensure_compaction_candidates( self, @@ -304,10 +390,37 @@ async def _ensure_compaction_candidates( ) return (candidates[:], history[:]) + def _mark_local_history_rewrite(self) -> None: + self._has_pending_local_history_rewrite = True + self._local_history_rewrite_response_id = self._response_id + _ResolvedCompactionMode = Literal["previous_response_id", "input"] +def _find_unresolved_function_calls_without_results(items: list[TResponseInputItem]) -> list[str]: + """Return function-call ids that do not yet have matching outputs.""" + function_calls: dict[str, TResponseInputItem] = {} + resolved_call_ids: set[str] = set() + + for item in items: + if isinstance(item, dict): + item_type = item.get("type") + call_id = item.get("call_id") + else: + item_type = getattr(item, "type", None) + call_id = getattr(item, "call_id", None) + + if not isinstance(call_id, str): + continue + if item_type == "function_call": + function_calls[call_id] = item + elif item_type == "function_call_output": + resolved_call_ids.add(call_id) + + return [call_id for call_id in function_calls if call_id not in resolved_call_ids] + + def _resolve_compaction_mode( requested_mode: OpenAIResponsesCompactionMode, *, diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 85a65a1690..df9a864e06 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable @@ -9,6 +10,8 @@ from ..items import TResponseInputItem from .session_settings import SessionSettings +SERVER_MANAGED_CONVERSATION_SESSION_ATTR = "_server_managed_conversation_session" + @runtime_checkable class Session(Protocol): @@ -104,6 +107,105 @@ async def clear_session(self) -> None: ... +@runtime_checkable +class ServerManagedConversationSession(Session, Protocol): + """Protocol for sessions whose canonical history is managed by a remote service.""" + + _server_managed_conversation_session: Literal[True] + + +def is_server_managed_conversation_session( + session: Session | None, +) -> TypeGuard[ServerManagedConversationSession]: + """Check whether the session advertises server-managed history semantics.""" + if session is None: + return False + try: + marker = getattr(session, SERVER_MANAGED_CONVERSATION_SESSION_ATTR, False) + except Exception: + return False + return marker is True + + +class ReplaceFunctionCallSessionHistoryMutation(TypedDict): + """Replace the canonical persisted function call for a tool call.""" + + type: Literal["replace_function_call"] + call_id: str + replacement: TResponseInputItem + + +SessionHistoryMutation = ReplaceFunctionCallSessionHistoryMutation + + +class SessionHistoryRewriteArgs(TypedDict): + """Arguments for persisted-history rewrites.""" + + mutations: list[SessionHistoryMutation] + + +@runtime_checkable +class SessionHistoryRewriteAwareSession(Session, Protocol): + """Protocol for sessions that can rewrite previously persisted history.""" + + async def apply_history_mutations(self, args: SessionHistoryRewriteArgs) -> None: + """Apply structured history mutations to the persisted session history.""" + ... + + +def is_session_history_rewrite_aware_session( + session: Session | None, +) -> TypeGuard[SessionHistoryRewriteAwareSession]: + """Check whether a session supports persisted-history rewrites.""" + if session is None: + return False + try: + apply_history_mutations = getattr(session, "apply_history_mutations", None) + except Exception: + return False + return callable(apply_history_mutations) + + +def apply_session_history_mutations( + items: list[TResponseInputItem], + mutations: list[SessionHistoryMutation], +) -> list[TResponseInputItem]: + """Apply structured history mutations to a list of persisted session items.""" + next_items = [copy.deepcopy(item) for item in items] + for mutation in mutations: + if mutation["type"] == "replace_function_call": + next_items = _apply_replace_function_call_mutation(next_items, mutation) + return next_items + + +def _apply_replace_function_call_mutation( + items: list[TResponseInputItem], + mutation: ReplaceFunctionCallSessionHistoryMutation, +) -> list[TResponseInputItem]: + """Replace the first matching function call and drop later duplicates for the same call id.""" + replacement = copy.deepcopy(mutation["replacement"]) + next_items: list[TResponseInputItem] = [] + kept_replacement = False + + for item in items: + if _is_matching_function_call(item, mutation["call_id"]): + if not kept_replacement: + next_items.append(replacement) + kept_replacement = True + continue + next_items.append(item) + + return next_items + + +def _is_matching_function_call(item: TResponseInputItem, call_id: str) -> bool: + if isinstance(item, dict): + return item.get("type") == "function_call" and item.get("call_id") == call_id + item_type = getattr(item, "type", None) + item_call_id = getattr(item, "call_id", None) + return item_type == "function_call" and item_call_id == call_id + + class OpenAIResponsesCompactionArgs(TypedDict, total=False): """Arguments for the run_compaction method.""" diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 92c9630c9b..be1aa97e81 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -1,13 +1,14 @@ from __future__ import annotations import asyncio +import copy import json import sqlite3 import threading from pathlib import Path from ..items import TResponseInputItem -from .session import SessionABC +from .session import SessionABC, SessionHistoryRewriteArgs, apply_session_history_mutations from .session_settings import SessionSettings, resolve_session_limit @@ -273,6 +274,70 @@ def _clear_session_sync(): await asyncio.to_thread(_clear_session_sync) + async def apply_history_mutations(self, args: SessionHistoryRewriteArgs) -> None: + """Rewrite persisted session history using structured mutations.""" + mutations = list(args.get("mutations", [])) + if not mutations: + return + + def _apply_history_mutations_sync() -> None: + conn = self._get_connection() + with self._lock if self._is_memory_db else threading.Lock(): + cursor = conn.execute( + f""" + SELECT message_data FROM {self.messages_table} + WHERE session_id = ? + ORDER BY id ASC + """, + (self.session_id,), + ) + rows = cursor.fetchall() + + existing_items: list[TResponseInputItem] = [] + for (message_data,) in rows: + try: + item = json.loads(message_data) + except json.JSONDecodeError: + continue + existing_items.append(copy.deepcopy(item)) + + rewritten_items = apply_session_history_mutations(existing_items, mutations) + + conn.execute( + f""" + DELETE FROM {self.messages_table} + WHERE session_id = ? + """, + (self.session_id,), + ) + + if rewritten_items: + conn.execute( + f""" + INSERT OR IGNORE INTO {self.sessions_table} (session_id) VALUES (?) + """, + (self.session_id,), + ) + message_data = [(self.session_id, json.dumps(item)) for item in rewritten_items] + conn.executemany( + f""" + INSERT INTO {self.messages_table} (session_id, message_data) VALUES (?, ?) + """, + message_data, + ) + + conn.execute( + f""" + UPDATE {self.sessions_table} + SET updated_at = CURRENT_TIMESTAMP + WHERE session_id = ? + """, + (self.session_id,), + ) + conn.commit() + + await asyncio.to_thread(_apply_history_mutations_sync) + def close(self) -> None: """Close the database connection.""" if self._is_memory_db: diff --git a/src/agents/result.py b/src/agents/result.py index 774c90dc4e..355dab06e7 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -78,6 +78,7 @@ def _populate_state_from_result( auto_previous_response_id: bool = False, ) -> RunState[Any]: """Populate a RunState with common fields from a RunResult.""" + source_state = getattr(result, "_state", None) model_input_items = getattr(result, "_model_input_items", None) if isinstance(model_input_items, list): state._generated_items = list(model_input_items) @@ -106,6 +107,19 @@ def _populate_state_from_result( if trace_state is None: trace_state = TraceState.from_trace(getattr(result, "trace", None)) state._trace_state = copy.deepcopy(trace_state) if trace_state else None + state._trace_include_sensitive_data = getattr( + source_state, + "_trace_include_sensitive_data", + True, + ) + if isinstance(source_state, RunState): + state._session_history_mutations = source_state.get_session_history_mutations() + state._execution_only_approval_override_call_ids = list( + source_state._execution_only_approval_override_call_ids + ) + else: + state._session_history_mutations = [] + state._execution_only_approval_override_call_ids = [] return state @@ -316,6 +330,8 @@ class RunResult(RunResultBase): """The original input for the current run segment. This is updated when handoffs or resume logic replace the input history, and used by to_state() to preserve the correct originalInput when serializing state.""" + _state: Any = field(default=None, repr=False) + """Internal reference to the originating RunState when available.""" _conversation_id: str | None = field(default=None, repr=False) """Conversation identifier for server-managed runs.""" _previous_response_id: str | None = field(default=None, repr=False) diff --git a/src/agents/run.py b/src/agents/run.py index 047d454d35..d073a3ec21 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -2,6 +2,7 @@ import asyncio import contextlib +import dataclasses import warnings from typing import Union, cast @@ -28,7 +29,7 @@ ) from .lifecycle import RunHooks from .logger import logger -from .memory import Session +from .memory import Session, is_server_managed_conversation_session from .result import RunResult, RunResultStreaming from .run_config import ( DEFAULT_MAX_TURNS, @@ -57,6 +58,7 @@ save_turn_items_if_needed, should_cancel_parallel_model_task_on_input_guardrail_trip, update_run_state_for_interruption, + validate_override_history_persistence_support, validate_session_conversation_settings, ) from .run_internal.approvals import approvals_from_step @@ -420,6 +422,7 @@ async def run( session_input_items_for_persistence: list[TResponseInputItem] | None = ( [] if (session is not None and is_resumed_state) else None ) + server_manages_conversation = False # Track the most recent input batch we persisted so conversation-lock retries can rewind # exactly those items (and not the full history). last_saved_input_snapshot_for_rewind: list[TResponseInputItem] | None = None @@ -493,6 +496,26 @@ async def run( ) original_input_for_state = prepared_input + server_manages_conversation = ( + conversation_id is not None + or previous_response_id is not None + or auto_previous_response_id + ) + history_is_server_managed = ( + server_manages_conversation or is_server_managed_conversation_session(session) + ) + validate_override_history_persistence_support( + input=input, + session=session, + history_is_server_managed=history_is_server_managed, + ) + + if is_resumed_state and run_state is not None: + run_config = dataclasses.replace( + run_config, + trace_include_sensitive_data=run_state._trace_include_sensitive_data, + ) + resolved_reasoning_item_id_policy: ReasoningItemIdPolicy | None = ( run_config.reasoning_item_id_policy if run_config.reasoning_item_id_policy is not None @@ -554,6 +577,7 @@ async def run( reattach_resumed_trace=is_resumed_state, ): if is_resumed_state and run_state is not None: + run_state.set_trace_include_sensitive_data(run_config.trace_include_sensitive_data) run_state.set_trace(get_current_trace()) current_turn = run_state._current_turn raw_original_input = run_state._original_input @@ -581,6 +605,7 @@ async def run( auto_previous_response_id=auto_previous_response_id, ) run_state._reasoning_item_id_policy = resolved_reasoning_item_id_policy + run_state.set_trace_include_sensitive_data(run_config.trace_include_sensitive_data) run_state.set_trace(get_current_trace()) def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: @@ -1450,6 +1475,7 @@ def run_streamed( run_state: RunState[TContext] | None = None input_for_result: str | list[TResponseInputItem] starting_input = input if not is_resumed_state else None + server_manages_conversation = False if is_resumed_state: run_state = cast(RunState[TContext], input) @@ -1518,6 +1544,25 @@ def run_streamed( auto_previous_response_id=auto_previous_response_id, ) + server_manages_conversation = ( + conversation_id is not None + or previous_response_id is not None + or auto_previous_response_id + ) + history_is_server_managed = ( + server_manages_conversation or is_server_managed_conversation_session(session) + ) + validate_override_history_persistence_support( + input=input, + session=session, + history_is_server_managed=history_is_server_managed, + ) + if is_resumed_state and run_state is not None: + run_config = dataclasses.replace( + run_config, + trace_include_sensitive_data=run_state._trace_include_sensitive_data, + ) + resolved_reasoning_item_id_policy: ReasoningItemIdPolicy | None = ( run_config.reasoning_item_id_policy if run_config.reasoning_item_id_policy is not None @@ -1548,6 +1593,7 @@ def run_streamed( reattach_resumed_trace=is_resumed_state, ) if run_state is not None: + run_state.set_trace_include_sensitive_data(run_config.trace_include_sensitive_data) run_state.set_trace(new_trace or get_current_trace()) schema_agent = ( diff --git a/src/agents/run_internal/agent_runner_helpers.py b/src/agents/run_internal/agent_runner_helpers.py index 776e406703..c40850e4d9 100644 --- a/src/agents/run_internal/agent_runner_helpers.py +++ b/src/agents/run_internal/agent_runner_helpers.py @@ -9,7 +9,10 @@ from ..exceptions import UserError from ..guardrail import InputGuardrailResult from ..items import ModelResponse, RunItem, ToolApprovalItem, TResponseInputItem -from ..memory import Session +from ..memory import ( + Session, + is_session_history_rewrite_aware_session, +) from ..result import RunResult from ..run_config import RunConfig from ..run_context import RunContextWrapper, TContext @@ -46,6 +49,7 @@ "save_turn_items_if_needed", "should_cancel_parallel_model_task_on_input_guardrail_trip", "update_run_state_for_interruption", + "validate_override_history_persistence_support", ] _PARALLEL_INPUT_GUARDRAIL_CANCEL_PATCH_ID = ( @@ -104,6 +108,45 @@ def validate_session_conversation_settings( ) +def validate_override_history_persistence_support( + *, + input: str | list[TResponseInputItem] | RunState[Any], + session: Session | None, + history_is_server_managed: bool, +) -> None: + """Fail fast when approval override persistence requirements are not satisfied.""" + if not isinstance(input, RunState): + return + + if input.has_pending_execution_only_approval_overrides() and not history_is_server_managed: + raise UserError( + "save_override_arguments=False is only supported when using conversation_id, " + "previous_response_id, auto_previous_response_id, or a server-managed session." + ) + + mutations = input.get_session_history_mutations() + if not mutations: + return + + if history_is_server_managed: + raise UserError( + "save_override_arguments requires local canonical history. " + "Server-managed conversations cannot persist corrected function_call arguments. " + "Pass save_override_arguments=False to apply the override only to the current " + "execution." + ) + + if session is None or is_session_history_rewrite_aware_session(session): + return + + raise UserError( + "save_override_arguments requires a session that supports persisted-history rewrites. " + "Use SQLiteSession, OpenAIResponsesCompactionSession, or another " + "SessionHistoryRewriteAwareSession, or pass save_override_arguments=False to apply " + "the override only to the current execution." + ) + + def resolve_trace_settings( *, run_state: RunState[TContext] | None, @@ -273,6 +316,7 @@ def build_interruption_result( result._model_input_items = list(generated_items) result._replay_from_model_input_items = list(generated_items) != list(session_items) if run_state is not None: + result._state = run_state result._current_turn_persisted_item_count = run_state._current_turn_persisted_item_count result._trace_state = run_state._trace_state result._original_input = copy_input_items(original_input) diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 3d21d89fda..87524f0a6d 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -477,6 +477,7 @@ def _sync_conversation_tracking_from_tracker() -> None: previous_response_id=previous_response_id, auto_previous_response_id=auto_previous_response_id, ) + run_state.set_trace_include_sensitive_data(run_config.trace_include_sensitive_data) run_state._reasoning_item_id_policy = resolved_reasoning_item_id_policy streamed_result._state = run_state elif streamed_result._state is None: diff --git a/src/agents/run_internal/session_persistence.py b/src/agents/run_internal/session_persistence.py index d63c5f0526..3cc95d7ddd 100644 --- a/src/agents/run_internal/session_persistence.py +++ b/src/agents/run_internal/session_persistence.py @@ -18,9 +18,11 @@ from ..memory import ( OpenAIResponsesCompactionArgs, Session, + SessionHistoryMutation, SessionInputCallback, SessionSettings, is_openai_responses_compaction_aware_session, + is_session_history_rewrite_aware_session, ) from ..memory.openai_conversations_session import OpenAIConversationsSession from ..run_state import RunState @@ -222,6 +224,7 @@ def update_run_state_after_resume( if session_items is not None: run_state._session_items = list(session_items) run_state._current_step = turn_result.next_step # type: ignore[assignment] + run_state.clear_execution_only_approval_overrides() async def save_result_to_session( @@ -244,6 +247,8 @@ async def save_result_to_session( already_persisted = run_state._current_turn_persisted_item_count if run_state else 0 if session is None: + if run_state is not None: + run_state.clear_session_history_mutations() return 0 new_run_items: list[RunItem] @@ -312,53 +317,98 @@ async def save_result_to_session( saved_run_items_count += 1 if len(items_to_save) == 0: + await _apply_session_history_mutations_on_session(session, run_state) + await _run_compaction_on_session( + session=session, + response_id=response_id, + new_items=new_items, + store=store, + ) if run_state: run_state._current_turn_persisted_item_count = already_persisted + saved_run_items_count return saved_run_items_count await session.add_items(items_to_save) + await _apply_session_history_mutations_on_session(session, run_state) if run_state: run_state._current_turn_persisted_item_count = already_persisted + saved_run_items_count - if response_id and is_openai_responses_compaction_aware_session(session): - has_local_tool_outputs = any( - isinstance(item, (ToolCallOutputItem, HandoffOutputItem)) for item in new_items - ) - if has_local_tool_outputs: - defer_compaction = getattr(session, "_defer_compaction", None) - if callable(defer_compaction): - result = defer_compaction(response_id, store=store) - if inspect.isawaitable(result): - await result - logger.debug( - "skip: deferring compaction for response %s due to local tool outputs", - response_id, - ) - return saved_run_items_count - - deferred_response_id = None - get_deferred = getattr(session, "_get_deferred_compaction_response_id", None) - if callable(get_deferred): - deferred_response_id = get_deferred() - force_compaction = deferred_response_id is not None - if force_compaction: - logger.debug( - "compact: forcing for response %s after deferred %s", - response_id, - deferred_response_id, - ) - compaction_args: OpenAIResponsesCompactionArgs = { - "response_id": response_id, - "force": force_compaction, - } - if store is not None: - compaction_args["store"] = store - await session.run_compaction(compaction_args) + await _run_compaction_on_session( + session=session, + response_id=response_id, + new_items=new_items, + store=store, + ) return saved_run_items_count +async def _apply_session_history_mutations_on_session( + session: Session, + run_state: RunState | None, +) -> None: + """Apply pending history rewrites to the persisted session if supported.""" + if run_state is None: + return + + mutations = run_state.get_session_history_mutations() + if not mutations: + return + + normalized_mutations = _normalize_history_mutations_for_session_persistence(session, mutations) + + if is_session_history_rewrite_aware_session(session): + await session.apply_history_mutations({"mutations": normalized_mutations}) + + run_state.clear_session_history_mutations() + + +async def _run_compaction_on_session( + *, + session: Session, + response_id: str | None, + new_items: list[RunItem], + store: bool | None, +) -> None: + """Run session compaction hooks after persistence or approval-only mutation cycles.""" + if not response_id or not is_openai_responses_compaction_aware_session(session): + return + + has_local_tool_outputs = any( + isinstance(item, (ToolCallOutputItem, HandoffOutputItem)) for item in new_items + ) + if has_local_tool_outputs: + defer_compaction = getattr(session, "_defer_compaction", None) + if callable(defer_compaction): + result = defer_compaction(response_id, store=store) + if inspect.isawaitable(result): + await result + logger.debug( + "skip: deferring compaction for response %s due to local tool outputs", response_id + ) + return + + deferred_response_id = None + get_deferred = getattr(session, "_get_deferred_compaction_response_id", None) + if callable(get_deferred): + deferred_response_id = get_deferred() + force_compaction = deferred_response_id is not None + if force_compaction: + logger.debug( + "compact: forcing for response %s after deferred %s", + response_id, + deferred_response_id, + ) + compaction_args: OpenAIResponsesCompactionArgs = { + "response_id": response_id, + "force": force_compaction, + } + if store is not None: + compaction_args["store"] = store + await session.run_compaction(compaction_args) + + async def save_resumed_turn_items( *, session: Session | None, @@ -576,6 +626,28 @@ def _fingerprint_or_repr(item: TResponseInputItem, *, ignore_ids_for_matching: b ) +def _normalize_history_mutations_for_session_persistence( + session: Session, + mutations: list[SessionHistoryMutation], +) -> list[SessionHistoryMutation]: + """Normalize persisted-history mutations to the same session-safe item shape used on writes.""" + normalized: list[SessionHistoryMutation] = [] + for mutation in mutations: + if mutation["type"] != "replace_function_call": + continue + replacement = ensure_input_item_format(mutation["replacement"]) + if isinstance(session, OpenAIConversationsSession): + replacement = _sanitize_openai_conversation_item(replacement) + normalized.append( + { + "type": "replace_function_call", + "call_id": mutation["call_id"], + "replacement": replacement, + } + ) + return normalized + + def _session_item_key(item: Any) -> str: """Return a stable representation of a session item for comparison.""" try: diff --git a/src/agents/run_state.py b/src/agents/run_state.py index dcda9e073c..72bb10fbf2 100644 --- a/src/agents/run_state.py +++ b/src/agents/run_state.py @@ -72,6 +72,7 @@ coerce_tool_search_output_raw_item, ) from .logger import logger +from .memory import SessionHistoryMutation from .run_context import RunContextWrapper from .tool import ( ApplyPatchTool, @@ -91,7 +92,9 @@ ToolOutputGuardrail, ToolOutputGuardrailResult, ) -from .tracing.traces import Trace, TraceState +from .tracing import custom_span, get_current_span, get_current_trace +from .tracing.spans import Span +from .tracing.traces import Trace, TraceState, reattach_trace from .usage import deserialize_usage, serialize_usage from .util._json import _to_dump_compatible @@ -104,6 +107,8 @@ ProcessedResponse, ) +from .run_internal.items import ensure_input_item_format + TContext = TypeVar("TContext", default=Any) TAgent = TypeVar("TAgent", bound="Agent[Any]", default="Agent[Any]") ContextOverride = Union[Mapping[str, Any], RunContextWrapper[Any]] @@ -118,9 +123,9 @@ # 3. to_json() always emits CURRENT_SCHEMA_VERSION. # 4. Forward compatibility is intentionally fail-fast (older SDKs reject newer or unsupported # versions). -CURRENT_SCHEMA_VERSION = "1.6" +CURRENT_SCHEMA_VERSION = "1.7" SUPPORTED_SCHEMA_VERSIONS = frozenset( - {"1.0", "1.1", "1.2", "1.3", "1.4", "1.5", CURRENT_SCHEMA_VERSION} + {"1.0", "1.1", "1.2", "1.3", "1.4", "1.5", "1.6", CURRENT_SCHEMA_VERSION} ) _FUNCTION_OUTPUT_ADAPTER: TypeAdapter[FunctionCallOutput] = TypeAdapter(FunctionCallOutput) @@ -214,6 +219,15 @@ class RunState(Generic[TContext, TAgent]): _tool_use_tracker_snapshot: dict[str, list[str]] = field(default_factory=dict) """Serialized snapshot of the AgentToolUseTracker (agent name -> tools used).""" + _session_history_mutations: list[SessionHistoryMutation] = field(default_factory=list) + """Pending session history rewrites that must be applied after persistence.""" + + _execution_only_approval_override_call_ids: list[str] = field(default_factory=list) + """Function call ids whose approved argument overrides are execution-only.""" + + _trace_include_sensitive_data: bool = True + """Whether approval-override traces may include sensitive argument payloads.""" + _trace_state: TraceState | None = field(default=None, repr=False) """Serialized trace metadata for resuming tracing context.""" @@ -253,6 +267,9 @@ def __init__( self._generated_items_last_processed_marker = None self._current_turn_persisted_item_count = 0 self._tool_use_tracker_snapshot = {} + self._session_history_mutations = [] + self._execution_only_approval_override_call_ids = [] + self._trace_include_sensitive_data = True self._trace_state = None from .agent_tool_state import get_agent_tool_state_scope @@ -267,12 +284,246 @@ def get_interruptions(self) -> list[ToolApprovalItem]: return [] return self._current_step.interruptions - def approve(self, approval_item: ToolApprovalItem, always_approve: bool = False) -> None: + def approve( + self, + approval_item: ToolApprovalItem, + always_approve: bool = False, + *, + override_arguments: dict[str, Any] | None = None, + save_override_arguments: bool | None = None, + ) -> None: """Approve a tool call and rerun with this state to continue.""" if self._context is None: raise UserError("Cannot approve tool: RunState has no context") + + if save_override_arguments is not None and override_arguments is None: + raise UserError( + "save_override_arguments can only be used together with override_arguments." + ) + + if override_arguments is not None: + self._apply_approval_argument_override( + approval_item=approval_item, + override_arguments=override_arguments, + always_approve=always_approve, + save_override_arguments=save_override_arguments, + ) + self._context.approve_tool(approval_item, always_approve=always_approve) + def get_session_history_mutations(self) -> list[SessionHistoryMutation]: + """Return a defensive copy of pending persisted-history mutations.""" + return copy.deepcopy(self._session_history_mutations) + + def has_pending_execution_only_approval_overrides(self) -> bool: + """Return whether execution-only argument overrides still need a server-managed resume.""" + return len(self._execution_only_approval_override_call_ids) > 0 + + def clear_execution_only_approval_overrides(self) -> None: + """Clear execution-only override bookkeeping after the resumed execution completes.""" + self._execution_only_approval_override_call_ids = [] + + def clear_session_history_mutations(self) -> None: + """Clear pending persisted-history mutations after they are applied.""" + self._session_history_mutations = [] + + def _apply_approval_argument_override( + self, + *, + approval_item: ToolApprovalItem, + override_arguments: dict[str, Any], + always_approve: bool, + save_override_arguments: bool | None, + ) -> None: + if always_approve: + raise UserError("override_arguments cannot be used together with always_approve.") + + if not _is_function_call_raw_item(approval_item.raw_item): + raise UserError("override_arguments is only supported for function_call approvals.") + + if not isinstance(override_arguments, dict): + raise UserError("override_arguments must be a plain JSON object.") + + try: + serialized_arguments = json.dumps(override_arguments) + except Exception as exc: + raise UserError(f"override_arguments must be JSON serializable. {exc}") from exc + + if not isinstance(serialized_arguments, str): + raise UserError("override_arguments could not be serialized to JSON.") + + should_save_override_arguments = save_override_arguments is not False + has_server_managed_conversation = bool( + self._conversation_id or self._previous_response_id or self._auto_previous_response_id + ) + if should_save_override_arguments and has_server_managed_conversation: + raise UserError( + "save_override_arguments requires local canonical history. " + "Server-managed conversations cannot persist corrected function_call arguments. " + "Pass save_override_arguments=False to apply the override only to the current " + "execution." + ) + + original_arguments = _get_raw_item_arguments(approval_item.raw_item) + call_id = _get_raw_item_call_id(approval_item.raw_item) + if not isinstance(call_id, str): + raise UserError("override_arguments requires a function_call with a call_id.") + + updated_tool_call = _create_function_call_override( + approval_item.raw_item, + serialized_arguments, + ) + approval_item.raw_item = updated_tool_call + self._replace_function_call_in_interruptions(call_id, updated_tool_call) + if self._last_processed_response is not None: + for interruption in self._last_processed_response.interruptions: + if not _is_function_call_raw_item(interruption.raw_item): + continue + if _get_raw_item_call_id(interruption.raw_item) != call_id: + continue + interruption.raw_item = updated_tool_call + + if should_save_override_arguments: + self._clear_execution_only_approval_override_call_id(call_id) + else: + self._record_execution_only_approval_override(call_id) + + if self._last_processed_response is not None: + for function_run in self._last_processed_response.functions: + if _get_raw_item_call_id(function_run.tool_call) != call_id: + continue + function_run.tool_call = updated_tool_call + + if should_save_override_arguments: + self._replace_function_call_in_run_items( + self._last_processed_response.new_items, + call_id, + updated_tool_call, + ) + + if should_save_override_arguments: + self._replace_function_call_in_run_items( + self._generated_items, call_id, updated_tool_call + ) + self._replace_function_call_in_run_items( + self._session_items, call_id, updated_tool_call + ) + self._replace_function_call_in_model_responses(call_id, updated_tool_call) + self._record_session_history_mutation( + { + "type": "replace_function_call", + "call_id": call_id, + "replacement": ensure_input_item_format(updated_tool_call), + } + ) + self._mark_generated_items_merged_with_last_processed() + + self._record_approval_argument_override_trace( + tool_name=approval_item.name or _get_raw_item_name(updated_tool_call), + call_id=call_id, + original_arguments=original_arguments, + serialized_arguments=serialized_arguments, + ) + + def _replace_function_call_in_interruptions(self, call_id: str, tool_call: Any) -> None: + """Replace a function call inside pending approval interruptions.""" + for interruption in self.get_interruptions(): + if not _is_function_call_raw_item(interruption.raw_item): + continue + if _get_raw_item_call_id(interruption.raw_item) != call_id: + continue + interruption.raw_item = tool_call + + def _replace_function_call_in_run_items( + self, + items: list[RunItem], + call_id: str, + tool_call: Any, + ) -> None: + """Replace matching function-call raw items inside serialized run item history.""" + for item in items: + if not _is_function_call_raw_item(getattr(item, "raw_item", None)): + continue + if _get_raw_item_call_id(item.raw_item) != call_id: + continue + item.raw_item = tool_call + + def _replace_function_call_in_model_responses(self, call_id: str, tool_call: Any) -> None: + """Replace matching function calls in stored raw model responses.""" + for response in self._model_responses: + for index, item in enumerate(response.output): + if not _is_function_call_raw_item(item): + continue + if _get_raw_item_call_id(item) != call_id: + continue + response.output[index] = tool_call + + def _record_session_history_mutation(self, mutation: SessionHistoryMutation) -> None: + """Record or replace a pending persisted-history mutation for the same call id.""" + for index, existing in enumerate(self._session_history_mutations): + if existing["type"] != mutation["type"] or existing["call_id"] != mutation["call_id"]: + continue + self._session_history_mutations[index] = copy.deepcopy(mutation) + return + self._session_history_mutations.append(copy.deepcopy(mutation)) + + def _record_execution_only_approval_override(self, call_id: str) -> None: + """Remember that this call id was approved with execution-only argument overrides.""" + if call_id in self._execution_only_approval_override_call_ids: + return + self._execution_only_approval_override_call_ids.append(call_id) + + def _clear_execution_only_approval_override_call_id(self, call_id: str) -> None: + """Forget execution-only bookkeeping for a call id once durable history wins.""" + self._execution_only_approval_override_call_ids = [ + existing_call_id + for existing_call_id in self._execution_only_approval_override_call_ids + if existing_call_id != call_id + ] + + def _record_approval_argument_override_trace( + self, + *, + tool_name: str | None, + call_id: str, + original_arguments: Any, + serialized_arguments: str, + ) -> None: + """Emit an audit span describing the approval-time argument override.""" + if not self._trace_include_sensitive_data: + return + + parent: Trace | Span[Any] | None = get_current_span() or get_current_trace() + if parent is None and self._trace_state is not None: + parent = reattach_trace(self._trace_state) + if parent is None: + return + + resolved_tool_name = tool_name or "function_call" + try: + span = custom_span( + name=f"approval override: {resolved_tool_name}", + data={ + "tool_name": resolved_tool_name, + "call_id": call_id, + "original_arguments": _deserialize_function_call_arguments_for_trace( + original_arguments + ), + "override_arguments": _deserialize_function_call_arguments_for_trace( + serialized_arguments + ), + }, + parent=parent, + ) + span.start() + span.finish() + except Exception as exc: + logger.warning( + "Failed to record approval override trace for %s: %s", + resolved_tool_name, + exc, + ) + def reject( self, approval_item: ToolApprovalItem, @@ -657,6 +908,14 @@ def to_json( "previous_response_id": self._previous_response_id, "auto_previous_response_id": self._auto_previous_response_id, "reasoning_item_id_policy": self._reasoning_item_id_policy, + "execution_only_approval_override_call_ids": list( + self._execution_only_approval_override_call_ids + ), + "session_history_mutations": [ + _serialize_session_history_mutation(mutation) + for mutation in self._session_history_mutations + ], + "trace_include_sensitive_data": self._trace_include_sensitive_data, } generated_items = self._merge_generated_items_with_processed() @@ -864,6 +1123,10 @@ def set_trace(self, trace: Trace | None) -> None: """Capture trace metadata for serialization/resumption.""" self._trace_state = TraceState.from_trace(trace) + def set_trace_include_sensitive_data(self, include_sensitive_data: bool) -> None: + """Store whether resumed traces may include sensitive approval override data.""" + self._trace_include_sensitive_data = include_sensitive_data + def _serialize_trace_data(self, *, include_tracing_api_key: bool) -> dict[str, Any] | None: if not self._trace_state: return None @@ -1090,6 +1353,118 @@ def _serialize_raw_item_value(raw_item: Any) -> Any: return raw_item +def _serialize_session_history_mutation(mutation: SessionHistoryMutation) -> dict[str, Any]: + """Serialize a session history mutation into a JSON-compatible dictionary.""" + return { + "type": mutation["type"], + "call_id": mutation["call_id"], + "replacement": _serialize_raw_item_value(mutation["replacement"]), + } + + +def _deserialize_session_history_mutations( + serialized_mutations: Any, +) -> list[SessionHistoryMutation]: + """Deserialize persisted session history mutations from JSON data.""" + if not isinstance(serialized_mutations, Sequence) or isinstance( + serialized_mutations, (str, bytes) + ): + return [] + + mutations: list[SessionHistoryMutation] = [] + for mutation in serialized_mutations: + if not isinstance(mutation, Mapping): + continue + if mutation.get("type") != "replace_function_call": + continue + call_id = mutation.get("call_id") + replacement = mutation.get("replacement") + if not isinstance(call_id, str): + continue + normalized_replacement: Any + if isinstance(replacement, Mapping): + normalized_replacement = dict(replacement) + else: + normalized_replacement = replacement + mutations.append( + { + "type": "replace_function_call", + "call_id": call_id, + "replacement": cast(TResponseInputItem, normalized_replacement), + } + ) + return mutations + + +def _is_function_call_raw_item(raw_item: Any) -> bool: + """Return whether the raw item represents a function call.""" + if isinstance(raw_item, dict): + return raw_item.get("type") == "function_call" + return getattr(raw_item, "type", None) == "function_call" + + +def _get_raw_item_call_id(raw_item: Any) -> str | None: + """Return the call_id for a raw tool item when available.""" + if isinstance(raw_item, dict): + call_id = raw_item.get("call_id") or raw_item.get("callId") or raw_item.get("id") + return call_id if isinstance(call_id, str) else None + for attr in ("call_id", "callId", "id"): + value = getattr(raw_item, attr, None) + if isinstance(value, str): + return value + return None + + +def _get_raw_item_name(raw_item: Any) -> str | None: + """Return the tool name for a raw tool item when available.""" + if isinstance(raw_item, dict): + name = raw_item.get("name") + return name if isinstance(name, str) else None + name = getattr(raw_item, "name", None) + return name if isinstance(name, str) else None + + +def _get_raw_item_arguments(raw_item: Any) -> Any: + """Return the serialized arguments field for a raw function call when available.""" + if isinstance(raw_item, dict): + return raw_item.get("arguments") + return getattr(raw_item, "arguments", None) + + +def _create_function_call_override(raw_item: Any, serialized_arguments: str) -> Any: + """Return a copy of a function call raw item with corrected arguments.""" + if isinstance(raw_item, ResponseFunctionToolCall): + return raw_item.model_copy(update={"arguments": serialized_arguments}) + if isinstance(raw_item, dict): + updated = dict(raw_item) + updated["arguments"] = serialized_arguments + return updated + if hasattr(raw_item, "model_dump"): + try: + payload = raw_item.model_dump(exclude_unset=True) + except TypeError: + payload = raw_item.model_dump() + payload = dict(payload) + payload["arguments"] = serialized_arguments + try: + return ResponseFunctionToolCall(**payload) + except Exception: + return payload + raise UserError("override_arguments is only supported for function_call approvals.") + + +def _deserialize_function_call_arguments_for_trace(arguments: Any) -> Any: + """Decode serialized function-call arguments for trace payloads when possible.""" + if arguments is None: + return None + if not isinstance(arguments, str): + return arguments + try: + return json.loads(arguments) + except Exception: + return arguments + + def _ensure_json_compatible(value: Any) -> Any: try: return json.loads(json.dumps(value, default=str)) @@ -2288,6 +2663,21 @@ async def _build_run_state_from_json( state._reasoning_item_id_policy = cast(Literal["preserve", "omit"], serialized_policy) else: state._reasoning_item_id_policy = None + serialized_execution_only_call_ids = state_json.get( + "execution_only_approval_override_call_ids", [] + ) + if isinstance(serialized_execution_only_call_ids, Sequence) and not isinstance( + serialized_execution_only_call_ids, (str, bytes) + ): + state._execution_only_approval_override_call_ids = [ + call_id for call_id in serialized_execution_only_call_ids if isinstance(call_id, str) + ] + else: + state._execution_only_approval_override_call_ids = [] + state._session_history_mutations = _deserialize_session_history_mutations( + state_json.get("session_history_mutations", []) + ) + state._trace_include_sensitive_data = bool(state_json.get("trace_include_sensitive_data", True)) state.set_tool_use_tracker_snapshot(state_json.get("tool_use_tracker", {})) trace_data = state_json.get("trace") if isinstance(trace_data, Mapping): diff --git a/tests/memory/test_openai_responses_compaction_session.py b/tests/memory/test_openai_responses_compaction_session.py index 7af406a602..4e8a4c4ade 100644 --- a/tests/memory/test_openai_responses_compaction_session.py +++ b/tests/memory/test_openai_responses_compaction_session.py @@ -21,7 +21,7 @@ ) from tests.fake_model import FakeModel from tests.test_responses import get_function_tool, get_function_tool_call, get_text_message -from tests.utils.simple_session import SimpleListSession +from tests.utils.simple_session import RewriteAwareSimpleSession, SimpleListSession class TestIsOpenAIModelName: @@ -136,6 +136,107 @@ async def test_get_items_delegates(self) -> None: assert len(result) == 1 mock_session.get_items.assert_called_once() + @pytest.mark.asyncio + async def test_apply_history_mutations_rewrites_underlying_history(self) -> None: + underlying = SimpleListSession( + history=[ + cast(TResponseInputItem, {"type": "message", "role": "user", "content": "hello"}), + cast( + TResponseInputItem, + { + "type": "function_call", + "call_id": "call-1", + "id": "fc_1", + "name": "test_tool", + "arguments": '{"value":"foo"}', + }, + ), + cast( + TResponseInputItem, + { + "type": "function_call_output", + "call_id": "call-1", + "output": "ok", + }, + ), + ] + ) + session = OpenAIResponsesCompactionSession( + session_id="test", + underlying_session=underlying, + ) + + await session.apply_history_mutations( + { + "mutations": [ + { + "type": "replace_function_call", + "call_id": "call-1", + "replacement": cast( + TResponseInputItem, + { + "type": "function_call", + "call_id": "call-1", + "id": "fc_1", + "name": "test_tool", + "arguments": '{"value":"bar"}', + }, + ), + } + ] + } + ) + + saved_items = await underlying.get_items() + assert cast(dict[str, Any], saved_items[1])["arguments"] == '{"value":"bar"}' + + @pytest.mark.asyncio + async def test_apply_history_mutations_delegates_to_rewrite_aware_underlying_session( + self, + ) -> None: + underlying = RewriteAwareSimpleSession( + history=[ + cast( + TResponseInputItem, + { + "type": "function_call", + "call_id": "call-1", + "id": "fc_1", + "name": "test_tool", + "arguments": '{"value":"foo"}', + }, + ) + ] + ) + session = OpenAIResponsesCompactionSession( + session_id="test", + underlying_session=underlying, + ) + + await session.apply_history_mutations( + { + "mutations": [ + { + "type": "replace_function_call", + "call_id": "call-1", + "replacement": cast( + TResponseInputItem, + { + "type": "function_call", + "call_id": "call-1", + "id": "fc_1", + "name": "test_tool", + "arguments": '{"value":"bar"}', + }, + ), + } + ] + } + ) + + saved_items = await underlying.get_items() + assert cast(dict[str, Any], saved_items[0])["arguments"] == '{"value":"bar"}' + @pytest.mark.asyncio async def test_run_compaction_requires_response_id(self) -> None: mock_session = self.create_mock_session() @@ -247,6 +348,74 @@ async def test_run_compaction_auto_uses_input_when_store_false(self) -> None: assert "previous_response_id" not in call_kwargs assert call_kwargs.get("input") == items + @pytest.mark.asyncio + async def test_run_compaction_forces_input_mode_after_local_history_rewrite(self) -> None: + underlying = RewriteAwareSimpleSession( + history=[ + cast(TResponseInputItem, {"type": "message", "role": "user", "content": "hello"}), + cast( + TResponseInputItem, + { + "type": "function_call", + "call_id": "call-1", + "id": "fc_1", + "name": "test_tool", + "arguments": '{"value":"foo"}', + }, + ), + cast( + TResponseInputItem, + { + "type": "function_call_output", + "call_id": "call-1", + "output": "ok", + }, + ), + ] + ) + mock_compact_response = MagicMock() + mock_compact_response.output = [] + mock_client = MagicMock() + mock_client.responses.compact = AsyncMock(return_value=mock_compact_response) + session = OpenAIResponsesCompactionSession( + session_id="test", + underlying_session=underlying, + client=mock_client, + compaction_mode="auto", + ) + + await session.apply_history_mutations( + { + "mutations": [ + { + "type": "replace_function_call", + "call_id": "call-1", + "replacement": cast( + TResponseInputItem, + { + "type": "function_call", + "call_id": "call-1", + "id": "fc_1", + "name": "test_tool", + "arguments": '{"value":"bar"}', + }, + ), + } + ] + } + ) + await session.run_compaction({"response_id": "resp-1", "force": True}) + + first_call_kwargs = mock_client.responses.compact.call_args.kwargs + assert "previous_response_id" not in first_call_kwargs + assert isinstance(first_call_kwargs.get("input"), list) + + mock_client.responses.compact.reset_mock() + await session.run_compaction({"response_id": "resp-2", "force": True}) + + second_call_kwargs = mock_client.responses.compact.call_args.kwargs + assert second_call_kwargs.get("previous_response_id") == "resp-2" + @pytest.mark.asyncio async def test_run_compaction_auto_uses_default_store_when_unset(self) -> None: mock_session = self.create_mock_session() @@ -279,8 +448,8 @@ async def test_run_compaction_auto_uses_default_store_when_unset(self) -> None: first_kwargs = mock_client.responses.compact.call_args_list[0].kwargs second_kwargs = mock_client.responses.compact.call_args_list[1].kwargs assert "previous_response_id" not in first_kwargs - assert second_kwargs.get("previous_response_id") == "resp-stored" - assert "input" not in second_kwargs + assert "previous_response_id" not in second_kwargs + assert second_kwargs.get("input") == [] @pytest.mark.asyncio async def test_run_compaction_auto_uses_input_when_last_response_unstored(self) -> None: diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index 8b07297167..d92bcfd025 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -89,7 +89,13 @@ ) from .utils.factories import make_run_state from .utils.hitl import make_context_wrapper, make_model_and_agent, make_shell_call -from .utils.simple_session import CountingSession, IdStrippingSession, SimpleListSession +from .utils.simple_session import ( + CountingSession, + IdStrippingSession, + RewriteAwareSimpleSession, + ServerManagedSimpleSession, + SimpleListSession, +) class _DummyRunItem: @@ -1094,6 +1100,152 @@ def delegate_tool() -> str: assert state._current_agent is delegate +@pytest.mark.asyncio +async def test_resume_with_durable_override_rewrites_local_session_history() -> None: + model = FakeModel() + + @function_tool(name_override="approval_tool", needs_approval=True) + def approval_tool(test: str) -> str: + return f"result:{test}" + + agent = Agent(name="approval_agent", model=model, tools=[approval_tool]) + session = RewriteAwareSimpleSession() + + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "approval_tool", json.dumps({"test": "foo"}), call_id="call-1" + ) + ], + [get_text_message("done")], + ] + ) + + first = await Runner.run(agent, input="user_message", session=session) + assert first.interruptions + + state = first.to_state() + state.approve(first.interruptions[0], override_arguments={"test": "bar"}) + + resumed = await Runner.run(agent, state, session=session) + + assert resumed.final_output == "done" + saved_items = await session.get_items() + assert saved_items[1]["type"] == "function_call" + assert cast(dict[str, Any], saved_items[1])["arguments"] == json.dumps({"test": "bar"}) + assert saved_items[2]["type"] == "function_call_output" + assert saved_items[2]["call_id"] == "call-1" + + +@pytest.mark.asyncio +async def test_resume_rejects_execution_only_override_without_server_managed_history() -> None: + model = FakeModel() + + @function_tool(name_override="approval_tool", needs_approval=True) + def approval_tool(test: str) -> str: + return f"result:{test}" + + agent = Agent(name="approval_agent", model=model, tools=[approval_tool]) + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "approval_tool", json.dumps({"test": "foo"}), call_id="call-1" + ) + ], + [get_text_message("done")], + ] + ) + + first = await Runner.run(agent, input="user_message") + assert first.interruptions + + state = first.to_state() + state.approve( + first.interruptions[0], + override_arguments={"test": "bar"}, + save_override_arguments=False, + ) + + with pytest.raises(UserError, match="save_override_arguments=False is only supported"): + await Runner.run(agent, state) + + +@pytest.mark.asyncio +async def test_resume_supports_execution_only_override_with_server_managed_session() -> None: + model = FakeModel() + + @function_tool(name_override="approval_tool", needs_approval=True) + def approval_tool(test: str) -> str: + return f"result:{test}" + + agent = Agent( + name="approval_agent", + model=model, + tools=[approval_tool], + tool_use_behavior="stop_on_first_tool", + ) + session = ServerManagedSimpleSession() + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "approval_tool", json.dumps({"test": "foo"}), call_id="call-1" + ) + ], + ] + ) + + first = await Runner.run(agent, input="user_message", session=session) + assert first.interruptions + + state = first.to_state() + state.approve( + first.interruptions[0], + override_arguments={"test": "bar"}, + save_override_arguments=False, + ) + + resumed = await Runner.run(agent, state, session=session) + + assert resumed.final_output == "result:bar" + saved_items = await session.get_items() + assert cast(dict[str, Any], saved_items[1])["arguments"] == json.dumps({"test": "foo"}) + assert saved_items[2]["type"] == "function_call_output" + + +@pytest.mark.asyncio +async def test_resume_rejects_durable_override_for_non_rewrite_aware_session() -> None: + model = FakeModel() + + @function_tool(name_override="approval_tool", needs_approval=True) + def approval_tool(test: str) -> str: + return f"result:{test}" + + agent = Agent(name="approval_agent", model=model, tools=[approval_tool]) + session = SimpleListSession() + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "approval_tool", json.dumps({"test": "foo"}), call_id="call-1" + ) + ], + [get_text_message("done")], + ] + ) + + first = await Runner.run(agent, input="user_message", session=session) + assert first.interruptions + + state = first.to_state() + state.approve(first.interruptions[0], override_arguments={"test": "bar"}) + + with pytest.raises(UserError, match="supports persisted-history rewrites"): + await Runner.run(agent, state, session=session) + + class Foo(TypedDict): bar: str diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py index 0e729fed37..1abfb28681 100644 --- a/tests/test_agent_runner_streamed.py +++ b/tests/test_agent_runner_streamed.py @@ -59,7 +59,11 @@ queue_function_call_and_text, resume_streamed_after_first_approval, ) -from .utils.simple_session import SimpleListSession +from .utils.simple_session import ( + RewriteAwareSimpleSession, + ServerManagedSimpleSession, + SimpleListSession, +) def _conversation_locked_error() -> BadRequestError: @@ -1573,6 +1577,74 @@ async def test_tool() -> str: assert output_count == 1 +@pytest.mark.asyncio +async def test_streaming_resume_with_durable_override_rewrites_session_history() -> None: + async def test_tool(test: str) -> str: + return f"result:{test}" + + tool = function_tool(test_tool, name_override="test_tool", needs_approval=True) + model, agent = make_model_and_agent(name="test", tools=[tool]) + session = RewriteAwareSimpleSession() + + queue_function_call_and_text( + model, + get_function_tool_call("test_tool", json.dumps({"test": "foo"}), call_id="call-resume"), + followup=[get_text_message("done")], + ) + + first = Runner.run_streamed(agent, input="Use test_tool", session=session) + await consume_stream(first) + assert first.interruptions + + state = first.to_state() + state.approve(first.interruptions[0], override_arguments={"test": "bar"}) + + resumed = Runner.run_streamed(agent, state, session=session) + await consume_stream(resumed) + + assert resumed.final_output == "done" + saved_items = await session.get_items() + assert cast(dict[str, Any], saved_items[1])["arguments"] == json.dumps({"test": "bar"}) + + +@pytest.mark.asyncio +async def test_streaming_resume_supports_execution_only_override_with_server_managed_session() -> ( + None +): + async def test_tool(test: str) -> str: + return f"result:{test}" + + tool = function_tool(test_tool, name_override="test_tool", needs_approval=True) + model = FakeModel() + agent = Agent( + name="test", + model=model, + tools=[tool], + tool_use_behavior="stop_on_first_tool", + ) + session = ServerManagedSimpleSession() + + model.add_multiple_turn_outputs( + [[get_function_tool_call("test_tool", json.dumps({"test": "foo"}), call_id="call-resume")]] + ) + + first = Runner.run_streamed(agent, input="Use test_tool", session=session) + await consume_stream(first) + assert first.interruptions + + state = first.to_state() + state.approve( + first.interruptions[0], + override_arguments={"test": "bar"}, + save_override_arguments=False, + ) + + resumed = Runner.run_streamed(agent, state, session=session) + await consume_stream(resumed) + + assert resumed.final_output == "result:bar" + + @pytest.mark.asyncio async def test_streaming_resume_preserves_filtered_model_input_after_handoff(): model = FakeModel() diff --git a/tests/test_agent_tracing.py b/tests/test_agent_tracing.py index 14ab62b2b2..e954d74253 100644 --- a/tests/test_agent_tracing.py +++ b/tests/test_agent_tracing.py @@ -245,6 +245,102 @@ async def test_resumed_run_with_workflow_override_starts_new_trace() -> None: assert [trace.name for trace in traces] == ["original_workflow", "override_workflow"] +@pytest.mark.asyncio +async def test_approval_override_records_custom_trace_span() -> None: + model = FakeModel() + + @function_tool(name_override="send_email", needs_approval=True) + def send_email(recipient: str) -> str: + return recipient + + agent = Agent(name="trace_agent", model=model, tools=[send_email]) + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "send_email", '{"recipient":"alice@example.com"}', call_id="call-1" + ) + ] + ] + ) + + first = await Runner.run(agent, input="first_test") + assert first.interruptions + + state = first.to_state() + before_custom_spans = [ + span + for span in fetch_ordered_spans() + if getattr(getattr(span, "span_data", None), "type", None) == "custom" + ] + + state.approve(first.interruptions[0], override_arguments={"recipient": "bob@example.com"}) + + after_custom_spans = [ + span + for span in fetch_ordered_spans() + if getattr(getattr(span, "span_data", None), "type", None) == "custom" + ] + + assert len(after_custom_spans) == len(before_custom_spans) + 1 + override_span = after_custom_spans[-1].export() + assert override_span is not None + assert override_span["span_data"]["name"] == "approval override: send_email" + assert override_span["span_data"]["data"] == { + "tool_name": "send_email", + "call_id": "call-1", + "original_arguments": {"recipient": "alice@example.com"}, + "override_arguments": {"recipient": "bob@example.com"}, + } + + +@pytest.mark.asyncio +async def test_approval_override_respects_restored_sensitive_trace_flag() -> None: + model = FakeModel() + + @function_tool(name_override="send_email", needs_approval=True) + def send_email(recipient: str) -> str: + return recipient + + agent = Agent(name="trace_agent", model=model, tools=[send_email]) + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "send_email", '{"recipient":"alice@example.com"}', call_id="call-1" + ) + ] + ] + ) + + first = await Runner.run(agent, input="first_test") + assert first.interruptions + + state = first.to_state() + state.set_trace_include_sensitive_data(False) + restored_state = await RunState.from_string(agent, state.to_string()) + restored_interruptions = restored_state.get_interruptions() + assert restored_interruptions + + before_custom_spans = [ + span + for span in fetch_ordered_spans() + if getattr(getattr(span, "span_data", None), "type", None) == "custom" + ] + + restored_state.approve( + restored_interruptions[0], + override_arguments={"recipient": "bob@example.com"}, + ) + + after_custom_spans = [ + span + for span in fetch_ordered_spans() + if getattr(getattr(span, "span_data", None), "type", None) == "custom" + ] + assert len(after_custom_spans) == len(before_custom_spans) + + @pytest.mark.asyncio async def test_wrapped_trace_is_single_trace(): model = FakeModel() diff --git a/tests/test_run_state.py b/tests/test_run_state.py index 56cd61fab2..254665deab 100644 --- a/tests/test_run_state.py +++ b/tests/test_run_state.py @@ -177,6 +177,56 @@ def set_last_processed_response( state._last_processed_response = make_processed_response(new_items=new_items) +def build_overrideable_approval_state( + *, + conversation_id: str | None = None, + previous_response_id: str | None = None, + auto_previous_response_id: bool = False, +) -> tuple[RunState[Any, Agent[Any]], ToolApprovalItem, ResponseFunctionToolCall]: + """Build a RunState whose interruption can override function-call arguments.""" + + @function_tool(name_override="send_email") + def send_email(recipient: str) -> str: + return f"sent:{recipient}" + + context: RunContextWrapper[dict[str, Any]] = RunContextWrapper(context={}) + agent = Agent(name="OverrideAgent", tools=[send_email]) + raw_item = ResponseFunctionToolCall( + type="function_call", + id="fc_override", + call_id="call-override", + name="send_email", + arguments=json.dumps({"recipient": "alice@example.com"}), + ) + tool_call_item = ToolCallItem(agent=agent, raw_item=raw_item) + approval_item = ToolApprovalItem( + agent=agent, + raw_item=raw_item, + tool_name="send_email", + ) + state = make_state(agent, context=context, original_input="input", max_turns=2) + state._conversation_id = conversation_id + state._previous_response_id = previous_response_id + state._auto_previous_response_id = auto_previous_response_id + state._generated_items = [tool_call_item] + state._session_items = [ToolCallItem(agent=agent, raw_item=raw_item)] + state._current_step = NextStepInterruption(interruptions=[approval_item]) + state._model_responses = [ + ModelResponse( + output=[raw_item], + usage=Usage(), + response_id="resp-override", + ) + ] + state._last_processed_response = make_processed_response( + new_items=[ToolCallItem(agent=agent, raw_item=raw_item)], + functions=[ToolRunFunction(tool_call=raw_item, function_tool=send_email)], + interruptions=[approval_item], + ) + state._mark_generated_items_merged_with_last_processed() + return state, approval_item, raw_item + + class TestRunState: """Test RunState initialization, serialization, and core functionality.""" @@ -364,6 +414,107 @@ def test_approve_updates_context_approvals_correctly(self): assert state._context is not None assert state._context.is_tool_approved(tool_name="toolX", call_id="cid123") is True + def test_approve_with_override_arguments_updates_durable_replay_state(self): + """approve() should update replay history and record a session history mutation.""" + state, approval_item, _ = build_overrideable_approval_state() + + state.approve(approval_item, override_arguments={"recipient": "bob@example.com"}) + + assert approval_item.arguments == json.dumps({"recipient": "bob@example.com"}) + assert cast(Any, state._generated_items[0].raw_item).arguments == json.dumps( + {"recipient": "bob@example.com"} + ) + assert cast(Any, state._session_items[0].raw_item).arguments == json.dumps( + {"recipient": "bob@example.com"} + ) + assert ( + state._last_processed_response is not None + and state._last_processed_response.functions[0].tool_call.arguments + == json.dumps({"recipient": "bob@example.com"}) + ) + assert cast(Any, state._model_responses[0].output[0]).arguments == json.dumps( + {"recipient": "bob@example.com"} + ) + assert state.get_session_history_mutations() == [ + { + "type": "replace_function_call", + "call_id": "call-override", + "replacement": { + "type": "function_call", + "id": "fc_override", + "call_id": "call-override", + "name": "send_email", + "arguments": json.dumps({"recipient": "bob@example.com"}), + }, + } + ] + assert state.has_pending_execution_only_approval_overrides() is False + + def test_approve_with_execution_only_override_keeps_replay_history_unchanged(self): + """Execution-only overrides should only affect the pending execution surface.""" + state, approval_item, raw_item = build_overrideable_approval_state() + + state.approve( + approval_item, + override_arguments={"recipient": "bob@example.com"}, + save_override_arguments=False, + ) + + assert approval_item.arguments == json.dumps({"recipient": "bob@example.com"}) + assert cast(Any, state._generated_items[0].raw_item).arguments == json.dumps( + {"recipient": "alice@example.com"} + ) + assert cast(Any, state._session_items[0].raw_item).arguments == json.dumps( + {"recipient": "alice@example.com"} + ) + assert state._last_processed_response is not None + assert state._last_processed_response.functions[0].tool_call.arguments == json.dumps( + {"recipient": "bob@example.com"} + ) + assert state._model_responses[0].output[0] == raw_item + assert state.get_session_history_mutations() == [] + assert state.has_pending_execution_only_approval_overrides() is True + + def test_approve_with_override_arguments_rejects_server_managed_conversation_defaults(self): + """Durable overrides should fail fast for server-managed conversations.""" + state, approval_item, _ = build_overrideable_approval_state(previous_response_id="resp-1") + + with pytest.raises( + UserError, match="save_override_arguments requires local canonical history" + ): + state.approve(approval_item, override_arguments={"recipient": "bob@example.com"}) + + def test_approve_with_override_arguments_validates_options(self): + """approve() should reject invalid override option combinations.""" + state, approval_item, _ = build_overrideable_approval_state() + + with pytest.raises(UserError, match="save_override_arguments can only be used"): + state.approve(approval_item, save_override_arguments=False) + with pytest.raises(UserError, match="cannot be used together with always_approve"): + state.approve( + approval_item, + always_approve=True, + override_arguments={"recipient": "bob@example.com"}, + ) + with pytest.raises(UserError, match="plain JSON object"): + state.approve(approval_item, override_arguments=cast(Any, ["not", "a", "dict"])) + + async def test_override_state_roundtrips_through_serialization(self): + """Override-specific bookkeeping should survive RunState serialization.""" + state, approval_item, _ = build_overrideable_approval_state() + state.set_trace_include_sensitive_data(False) + state.approve( + approval_item, + override_arguments={"recipient": "bob@example.com"}, + save_override_arguments=False, + ) + + restored = await RunState.from_string(state._current_agent, state.to_string()) # type: ignore[arg-type] + + assert restored._trace_include_sensitive_data is False + assert restored.has_pending_execution_only_approval_overrides() is True + assert restored.get_session_history_mutations() == [] + def test_returns_undefined_when_approval_status_is_unknown(self): """Test that isToolApproved returns None for unknown tools.""" context: RunContextWrapper[dict[str, str]] = RunContextWrapper(context={}) @@ -3969,7 +4120,7 @@ async def test_from_json_missing_schema_version(self): await RunState.from_json(agent, state_json) @pytest.mark.asyncio - @pytest.mark.parametrize("schema_version", ["1.7", "2.0"]) + @pytest.mark.parametrize("schema_version", ["1.8", "2.0"]) async def test_from_json_unsupported_schema_version(self, schema_version: str): """Test that from_json raises error when schema version is unsupported.""" agent = Agent(name="TestAgent") @@ -4021,7 +4172,7 @@ async def test_from_json_accepts_previous_schema_version(self): def test_supported_schema_versions_match_released_boundary(self): """The support set should include released versions plus the current unreleased writer.""" assert SUPPORTED_SCHEMA_VERSIONS == frozenset( - {"1.0", "1.1", "1.2", "1.3", "1.4", "1.5", CURRENT_SCHEMA_VERSION} + {"1.0", "1.1", "1.2", "1.3", "1.4", "1.5", "1.6", CURRENT_SCHEMA_VERSION} ) @pytest.mark.asyncio diff --git a/tests/utils/simple_session.py b/tests/utils/simple_session.py index 94bcc97e9e..75de503322 100644 --- a/tests/utils/simple_session.py +++ b/tests/utils/simple_session.py @@ -3,7 +3,11 @@ from typing import cast from agents.items import TResponseInputItem -from agents.memory.session import Session +from agents.memory.session import ( + Session, + SessionHistoryRewriteArgs, + apply_session_history_mutations, +) from agents.memory.session_settings import SessionSettings @@ -80,3 +84,17 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: else: sanitized.append(item) await super().add_items(sanitized) + + +class RewriteAwareSimpleSession(SimpleListSession): + """In-memory test session that supports persisted-history rewrites.""" + + async def apply_history_mutations(self, args: SessionHistoryRewriteArgs) -> None: + self._items = apply_session_history_mutations(self._items, args.get("mutations", [])) + self.saved_items = self._items + + +class ServerManagedSimpleSession(SimpleListSession): + """In-memory test session that advertises server-managed history semantics.""" + + _server_managed_conversation_session = True From 2bb98d889c90d2e7c1ad87b972c0f42b7f750975 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Tue, 17 Mar 2026 15:43:51 +0900 Subject: [PATCH 2/5] fix review comments --- .../openai_responses_compaction_session.py | 53 ++++++++++++++++--- src/agents/run.py | 12 ++--- .../run_internal/agent_runner_helpers.py | 17 ++++-- ...est_openai_responses_compaction_session.py | 4 +- tests/test_agent_tracing.py | 17 ++++++ 5 files changed, 83 insertions(+), 20 deletions(-) diff --git a/src/agents/memory/openai_responses_compaction_session.py b/src/agents/memory/openai_responses_compaction_session.py index a41f78a8f6..24959de0f9 100644 --- a/src/agents/memory/openai_responses_compaction_session.py +++ b/src/agents/memory/openai_responses_compaction_session.py @@ -185,17 +185,58 @@ def _resolve_compaction_mode( return resolved_mode + def _resolve_store_tracking( + self, + *, + response_id: str | None, + previous_response_id: str | None, + store: bool | None, + store_was_provided: bool, + ) -> bool | None: + """Resolve the effective store setting for the current response id. + + Reuse `_last_store` only while compaction still refers to the same response. A new + response id with no explicit `store` falls back to the Responses API default behavior. + """ + if store_was_provided: + self._last_store = store + return store + + if response_id is not None and response_id != previous_response_id: + self._last_store = None + return None + + return self._last_store + + def _get_effective_store_for_response_id( + self, + *, + response_id: str | None, + store: bool | None, + ) -> bool | None: + """Return the effective store setting without mutating response tracking.""" + if store is not None: + return store + if response_id is not None and response_id != self._response_id: + return None + return self._last_store + async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None) -> None: """Run compaction using responses.compact API.""" previous_response_id = self._response_id if args and args.get("response_id"): self._response_id = args["response_id"] requested_mode = args.get("compaction_mode") if args else None - if args and "store" in args: - store: bool | None = args["store"] - self._last_store = store - else: - store = self._last_store + store_was_provided = bool(args and "store" in args) + requested_store: bool | None = ( + args["store"] if args is not None and "store" in args else None + ) + store = self._resolve_store_tracking( + response_id=self._response_id, + previous_response_id=previous_response_id, + store=requested_store, + store_was_provided=store_was_provided, + ) turn_has_local_adds_without_new_response_id = ( self._has_unacknowledged_local_session_adds and (args is None or args.get("response_id") in {None, previous_response_id}) @@ -320,7 +361,7 @@ async def _defer_compaction(self, response_id: str, store: bool | None = None) - resolved_mode = _resolve_compaction_mode( self.compaction_mode, response_id=response_id, - store=store if store is not None else self._last_store, + store=self._get_effective_store_for_response_id(response_id=response_id, store=store), ) should_compact = self.should_trigger_compaction( { diff --git a/src/agents/run.py b/src/agents/run.py index d073a3ec21..36656e5ed8 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -47,6 +47,7 @@ from .run_internal.agent_runner_helpers import ( append_model_response_if_new, apply_resumed_conversation_settings, + attach_run_state_metadata, build_interruption_result, build_resumed_stream_debug_extra, ensure_context_wrapper, @@ -828,8 +829,6 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: result._replay_from_model_input_items = list( generated_items ) != list(session_items) - if run_state is not None: - result._trace_state = run_state._trace_state if session_persistence_enabled: input_items_for_save_1: list[TResponseInputItem] = ( session_input_items_for_persistence @@ -844,6 +843,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: response_id=turn_result.model_response.response_id, store=store_setting, ) + attach_run_state_metadata(result, run_state=run_state) result._original_input = copy_input_items(original_input) return finalize_conversation_tracking( _with_reasoning_item_id_policy(result), @@ -965,8 +965,6 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: result._replay_from_model_input_items = list(generated_items) != list( session_items ) - if run_state is not None: - result._trace_state = run_state._trace_state if session_persistence_enabled and include_in_history: handler_input_items_for_save: list[TResponseInputItem] = ( session_input_items_for_persistence @@ -981,6 +979,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: response_id=None, store=store_setting, ) + attach_run_state_metadata(result, run_state=run_state) result._original_input = copy_input_items(original_input) return finalize_conversation_tracking( _with_reasoning_item_id_policy(result), @@ -1236,10 +1235,6 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: result._replay_from_model_input_items = list(generated_items) != list( session_items ) - if run_state is not None: - result._current_turn_persisted_item_count = ( - run_state._current_turn_persisted_item_count - ) await save_turn_items_if_needed( session=session, run_state=run_state, @@ -1249,6 +1244,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: response_id=turn_result.model_response.response_id, store=store_setting, ) + attach_run_state_metadata(result, run_state=run_state) result._original_input = copy_input_items(original_input) return finalize_conversation_tracking( _with_reasoning_item_id_policy(result), diff --git a/src/agents/run_internal/agent_runner_helpers.py b/src/agents/run_internal/agent_runner_helpers.py index c40850e4d9..9784bee728 100644 --- a/src/agents/run_internal/agent_runner_helpers.py +++ b/src/agents/run_internal/agent_runner_helpers.py @@ -33,6 +33,7 @@ from .tool_use_tracker import AgentToolUseTracker, serialize_tool_use_tracker __all__ = [ + "attach_run_state_metadata", "apply_resumed_conversation_settings", "append_model_response_if_new", "build_generated_items_details", @@ -276,6 +277,17 @@ def finalize_conversation_tracking( return result +def attach_run_state_metadata(result: RunResult, *, run_state: RunState | None) -> RunResult: + """Copy resumable state metadata from the current RunState onto a RunResult.""" + if run_state is None: + return result + + result._state = run_state + result._current_turn_persisted_item_count = run_state._current_turn_persisted_item_count + result._trace_state = run_state._trace_state + return result + + def build_interruption_result( *, result_input: str | list[TResponseInputItem], @@ -315,10 +327,7 @@ def build_interruption_result( result._current_turn = current_turn result._model_input_items = list(generated_items) result._replay_from_model_input_items = list(generated_items) != list(session_items) - if run_state is not None: - result._state = run_state - result._current_turn_persisted_item_count = run_state._current_turn_persisted_item_count - result._trace_state = run_state._trace_state + attach_run_state_metadata(result, run_state=run_state) result._original_input = copy_input_items(original_input) return result diff --git a/tests/memory/test_openai_responses_compaction_session.py b/tests/memory/test_openai_responses_compaction_session.py index 4e8a4c4ade..904127f356 100644 --- a/tests/memory/test_openai_responses_compaction_session.py +++ b/tests/memory/test_openai_responses_compaction_session.py @@ -448,8 +448,8 @@ async def test_run_compaction_auto_uses_default_store_when_unset(self) -> None: first_kwargs = mock_client.responses.compact.call_args_list[0].kwargs second_kwargs = mock_client.responses.compact.call_args_list[1].kwargs assert "previous_response_id" not in first_kwargs - assert "previous_response_id" not in second_kwargs - assert second_kwargs.get("input") == [] + assert second_kwargs.get("previous_response_id") == "resp-stored" + assert "input" not in second_kwargs @pytest.mark.asyncio async def test_run_compaction_auto_uses_input_when_last_response_unstored(self) -> None: diff --git a/tests/test_agent_tracing.py b/tests/test_agent_tracing.py index e954d74253..e6b1127cc7 100644 --- a/tests/test_agent_tracing.py +++ b/tests/test_agent_tracing.py @@ -341,6 +341,23 @@ def send_email(recipient: str) -> str: assert len(after_custom_spans) == len(before_custom_spans) +@pytest.mark.asyncio +async def test_completed_result_to_state_preserves_sensitive_trace_flag() -> None: + model = FakeModel() + model.add_multiple_turn_outputs([[get_text_message("done")]]) + agent = Agent(name="trace_agent", model=model) + + result = await Runner.run( + agent, + input="first_test", + run_config=RunConfig(trace_include_sensitive_data=False), + ) + + state = result.to_state() + + assert state._trace_include_sensitive_data is False + + @pytest.mark.asyncio async def test_wrapped_trace_is_single_trace(): model = FakeModel() From c2ab16acfc3c11f5440a7abd4d4197c1b085f90c Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Tue, 17 Mar 2026 15:52:08 +0900 Subject: [PATCH 3/5] fix review comments --- .../openai_responses_compaction_session.py | 15 ++- src/agents/run.py | 21 +++- .../run_internal/agent_runner_helpers.py | 13 +++ ...est_openai_responses_compaction_session.py | 69 ++++++++++++++ tests/test_agent_tracing.py | 95 +++++++++++++++++++ 5 files changed, 200 insertions(+), 13 deletions(-) diff --git a/src/agents/memory/openai_responses_compaction_session.py b/src/agents/memory/openai_responses_compaction_session.py index 24959de0f9..5661f69170 100644 --- a/src/agents/memory/openai_responses_compaction_session.py +++ b/src/agents/memory/openai_responses_compaction_session.py @@ -168,15 +168,6 @@ def _resolve_compaction_mode( if not self._has_pending_local_history_rewrite: return resolved_mode - if ( - self._local_history_rewrite_response_id is not None - and response_id is not None - and response_id != self._local_history_rewrite_response_id - ): - self._has_pending_local_history_rewrite = False - self._local_history_rewrite_response_id = None - return resolved_mode - if resolved_mode == "previous_response_id": if self._local_history_rewrite_response_id is None and response_id is not None: self._local_history_rewrite_response_id = response_id @@ -321,6 +312,8 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None self._compaction_candidate_items = select_compaction_candidate_items(output_items) self._session_items = output_items + if resolved_mode == "input": + self._clear_pending_local_history_rewrite() logger.debug( f"compact: done for {self._response_id} " @@ -435,6 +428,10 @@ def _mark_local_history_rewrite(self) -> None: self._has_pending_local_history_rewrite = True self._local_history_rewrite_response_id = self._response_id + def _clear_pending_local_history_rewrite(self) -> None: + self._has_pending_local_history_rewrite = False + self._local_history_rewrite_response_id = None + _ResolvedCompactionMode = Literal["previous_response_id", "input"] diff --git a/src/agents/run.py b/src/agents/run.py index 36656e5ed8..ce33aa6c3d 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -55,6 +55,7 @@ input_guardrails_triggered, resolve_processed_response, resolve_resumed_context, + resolve_trace_include_sensitive_data, resolve_trace_settings, save_turn_items_if_needed, should_cancel_parallel_model_task_on_input_guardrail_trip, @@ -412,6 +413,7 @@ async def run( auto_previous_response_id = kwargs.get("auto_previous_response_id", False) conversation_id = kwargs.get("conversation_id") session = kwargs.get("session") + run_config_was_supplied = run_config is not None if run_config is None: run_config = RunConfig() @@ -511,10 +513,15 @@ async def run( history_is_server_managed=history_is_server_managed, ) - if is_resumed_state and run_state is not None: + resolved_trace_include_sensitive_data = resolve_trace_include_sensitive_data( + run_state=run_state, + run_config=run_config, + run_config_was_supplied=run_config_was_supplied, + ) + if resolved_trace_include_sensitive_data != run_config.trace_include_sensitive_data: run_config = dataclasses.replace( run_config, - trace_include_sensitive_data=run_state._trace_include_sensitive_data, + trace_include_sensitive_data=resolved_trace_include_sensitive_data, ) resolved_reasoning_item_id_policy: ReasoningItemIdPolicy | None = ( @@ -1462,6 +1469,7 @@ def run_streamed( auto_previous_response_id = kwargs.get("auto_previous_response_id", False) conversation_id = kwargs.get("conversation_id") session = kwargs.get("session") + run_config_was_supplied = run_config is not None if run_config is None: run_config = RunConfig() @@ -1553,10 +1561,15 @@ def run_streamed( session=session, history_is_server_managed=history_is_server_managed, ) - if is_resumed_state and run_state is not None: + resolved_trace_include_sensitive_data = resolve_trace_include_sensitive_data( + run_state=run_state, + run_config=run_config, + run_config_was_supplied=run_config_was_supplied, + ) + if resolved_trace_include_sensitive_data != run_config.trace_include_sensitive_data: run_config = dataclasses.replace( run_config, - trace_include_sensitive_data=run_state._trace_include_sensitive_data, + trace_include_sensitive_data=resolved_trace_include_sensitive_data, ) resolved_reasoning_item_id_policy: ReasoningItemIdPolicy | None = ( diff --git a/src/agents/run_internal/agent_runner_helpers.py b/src/agents/run_internal/agent_runner_helpers.py index 9784bee728..1dcf7d5a42 100644 --- a/src/agents/run_internal/agent_runner_helpers.py +++ b/src/agents/run_internal/agent_runner_helpers.py @@ -43,6 +43,7 @@ "ensure_context_wrapper", "finalize_conversation_tracking", "input_guardrails_triggered", + "resolve_trace_include_sensitive_data", "validate_session_conversation_settings", "resolve_trace_settings", "resolve_processed_response", @@ -178,6 +179,18 @@ def resolve_trace_settings( return workflow_name, trace_id, group_id, metadata, tracing +def resolve_trace_include_sensitive_data( + *, + run_state: RunState[TContext] | None, + run_config: RunConfig, + run_config_was_supplied: bool, +) -> bool: + """Resolve whether traces may include sensitive data for this run.""" + if run_state is None or run_config_was_supplied: + return run_config.trace_include_sensitive_data + return run_state._trace_include_sensitive_data + + def resolve_resumed_context( *, run_state: RunState[TContext], diff --git a/tests/memory/test_openai_responses_compaction_session.py b/tests/memory/test_openai_responses_compaction_session.py index 904127f356..5b312eb72d 100644 --- a/tests/memory/test_openai_responses_compaction_session.py +++ b/tests/memory/test_openai_responses_compaction_session.py @@ -416,6 +416,75 @@ async def test_run_compaction_forces_input_mode_after_local_history_rewrite(self second_call_kwargs = mock_client.responses.compact.call_args.kwargs assert second_call_kwargs.get("previous_response_id") == "resp-2" + @pytest.mark.asyncio + async def test_run_compaction_keeps_local_rewrite_pending_until_input_compaction_succeeds( + self, + ) -> None: + underlying = RewriteAwareSimpleSession( + history=[ + cast(TResponseInputItem, {"type": "message", "role": "user", "content": "hello"}), + cast( + TResponseInputItem, + { + "type": "function_call", + "call_id": "call-1", + "id": "fc_1", + "name": "test_tool", + "arguments": '{"value":"foo"}', + }, + ), + cast( + TResponseInputItem, + { + "type": "function_call_output", + "call_id": "call-1", + "output": "ok", + }, + ), + ] + ) + mock_compact_response = MagicMock() + mock_compact_response.output = [] + mock_client = MagicMock() + mock_client.responses.compact = AsyncMock(return_value=mock_compact_response) + session = OpenAIResponsesCompactionSession( + session_id="test", + underlying_session=underlying, + client=mock_client, + compaction_mode="auto", + ) + + await session.apply_history_mutations( + { + "mutations": [ + { + "type": "replace_function_call", + "call_id": "call-1", + "replacement": cast( + TResponseInputItem, + { + "type": "function_call", + "call_id": "call-1", + "id": "fc_1", + "name": "test_tool", + "arguments": '{"value":"bar"}', + }, + ), + } + ] + } + ) + + await session.run_compaction({"response_id": "resp-1"}) + mock_client.responses.compact.assert_not_called() + + await session.run_compaction({"response_id": "resp-2", "force": True}) + + call_kwargs = mock_client.responses.compact.call_args.kwargs + assert "previous_response_id" not in call_kwargs + assert isinstance(call_kwargs.get("input"), list) + assert cast(dict[str, Any], call_kwargs["input"][1])["arguments"] == '{"value":"bar"}' + @pytest.mark.asyncio async def test_run_compaction_auto_uses_default_store_when_unset(self) -> None: mock_session = self.create_mock_session() diff --git a/tests/test_agent_tracing.py b/tests/test_agent_tracing.py index e6b1127cc7..920896b90c 100644 --- a/tests/test_agent_tracing.py +++ b/tests/test_agent_tracing.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +from typing import Any from uuid import uuid4 import pytest @@ -27,6 +28,18 @@ def approval_tool() -> str: return Agent(name="test_agent", model=model, tools=[approval_tool]) +def _get_last_function_span_export(name: str) -> dict[str, Any]: + matching_spans = [ + exported + for span in fetch_ordered_spans() + if (exported := span.export()) is not None + and exported["span_data"]["type"] == "function" + and exported["span_data"]["name"] == name + ] + assert matching_spans + return matching_spans[-1] + + @pytest.mark.asyncio async def test_single_run_is_single_trace(): agent = Agent( @@ -358,6 +371,45 @@ async def test_completed_result_to_state_preserves_sensitive_trace_flag() -> Non assert state._trace_include_sensitive_data is False +@pytest.mark.asyncio +async def test_resumed_run_honors_explicit_trace_include_sensitive_data() -> None: + model = FakeModel() + + @function_tool(name_override="send_email", needs_approval=True) + def send_email(recipient: str) -> str: + return recipient + + agent = Agent(name="trace_agent", model=model, tools=[send_email]) + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "send_email", '{"recipient":"alice@example.com"}', call_id="call-1" + ) + ], + [get_text_message("done")], + ] + ) + + first = await Runner.run(agent, input="first_test") + assert first.interruptions + + state = first.to_state() + state.approve(first.interruptions[0], override_arguments={"recipient": "bob@example.com"}) + + resumed = await Runner.run( + agent, + state, + run_config=RunConfig(trace_include_sensitive_data=False), + ) + + assert resumed.final_output == "done" + assert state._trace_include_sensitive_data is False + function_span = _get_last_function_span_export("send_email") + assert function_span["span_data"]["input"] is None + assert function_span["span_data"]["output"] is None + + @pytest.mark.asyncio async def test_wrapped_trace_is_single_trace(): model = FakeModel() @@ -643,6 +695,49 @@ async def test_resumed_streaming_run_reuses_original_trace_without_duplicate_tra assert all(span.trace_id == traces[0].trace_id for span in fetch_ordered_spans()) +@pytest.mark.asyncio +async def test_resumed_streaming_run_honors_explicit_trace_include_sensitive_data() -> None: + model = FakeModel() + + @function_tool(name_override="send_email", needs_approval=True) + def send_email(recipient: str) -> str: + return recipient + + agent = Agent(name="trace_agent", model=model, tools=[send_email]) + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "send_email", '{"recipient":"alice@example.com"}', call_id="call-1" + ) + ], + [get_text_message("done")], + ] + ) + + first = Runner.run_streamed(agent, input="first_test") + async for _ in first.stream_events(): + pass + assert first.interruptions + + state = first.to_state() + state.approve(first.interruptions[0], override_arguments={"recipient": "bob@example.com"}) + + resumed = Runner.run_streamed( + agent, + state, + run_config=RunConfig(trace_include_sensitive_data=False), + ) + async for _ in resumed.stream_events(): + pass + + assert resumed.final_output == "done" + assert state._trace_include_sensitive_data is False + function_span = _get_last_function_span_export("send_email") + assert function_span["span_data"]["input"] is None + assert function_span["span_data"]["output"] is None + + @pytest.mark.asyncio async def test_wrapped_streaming_trace_is_single_trace(): model = FakeModel() From aa02f12ea8ac114f3057246bed8b031464d03c7e Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Tue, 17 Mar 2026 16:07:29 +0900 Subject: [PATCH 4/5] fix review comments --- .../openai_responses_compaction_session.py | 47 +++++++----- src/agents/result.py | 43 +++++++++-- .../run_internal/agent_runner_helpers.py | 20 ++++- ...est_openai_responses_compaction_session.py | 75 +++++++++++++++++++ tests/test_agent_tracing.py | 39 ++++++++++ tests/test_result_cast.py | 23 ++++++ 6 files changed, 222 insertions(+), 25 deletions(-) diff --git a/src/agents/memory/openai_responses_compaction_session.py b/src/agents/memory/openai_responses_compaction_session.py index 5661f69170..7e403b0a40 100644 --- a/src/agents/memory/openai_responses_compaction_session.py +++ b/src/agents/memory/openai_responses_compaction_session.py @@ -27,6 +27,14 @@ OpenAIResponsesCompactionMode = Literal["previous_response_id", "input", "auto"] +def _is_user_message_item(item: TResponseInputItem) -> bool: + if not isinstance(item, dict): + return False + if item.get("type") == "message": + return item.get("role") == "user" + return item.get("role") == "user" and "content" in item + + def select_compaction_candidate_items( items: list[TResponseInputItem], ) -> list[TResponseInputItem]: @@ -35,18 +43,12 @@ def select_compaction_candidate_items( Excludes user messages and compaction items. """ - def _is_user_message(item: TResponseInputItem) -> bool: - if not isinstance(item, dict): - return False - if item.get("type") == "message": - return item.get("role") == "user" - return item.get("role") == "user" and "content" in item - return [ item for item in items if not ( - _is_user_message(item) or (isinstance(item, dict) and item.get("type") == "compaction") + _is_user_message_item(item) + or (isinstance(item, dict) and item.get("type") == "compaction") ) ] @@ -272,12 +274,12 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None ) return - unresolved_function_calls = _find_unresolved_function_calls_without_results(session_items) - if unresolved_function_calls: + frontier_unresolved_function_calls = _find_frontier_unresolved_function_calls(session_items) + if frontier_unresolved_function_calls: logger.debug( "compact: blocked unresolved function calls for %s: %s", self._response_id, - unresolved_function_calls, + frontier_unresolved_function_calls, ) return @@ -436,12 +438,19 @@ def _clear_pending_local_history_rewrite(self) -> None: _ResolvedCompactionMode = Literal["previous_response_id", "input"] -def _find_unresolved_function_calls_without_results(items: list[TResponseInputItem]) -> list[str]: - """Return function-call ids that do not yet have matching outputs.""" - function_calls: dict[str, TResponseInputItem] = {} +def _find_frontier_unresolved_function_calls(items: list[TResponseInputItem]) -> list[str]: + """Return unresolved function-call ids that remain in the active conversation frontier. + + Once a later user message appears, earlier unresolved tool calls are considered abandoned and + should no longer block future compaction for the session. + """ + function_call_indices: dict[str, int] = {} resolved_call_ids: set[str] = set() + last_user_message_index = -1 - for item in items: + for index, item in enumerate(items): + if _is_user_message_item(item): + last_user_message_index = index if isinstance(item, dict): item_type = item.get("type") call_id = item.get("call_id") @@ -452,11 +461,15 @@ def _find_unresolved_function_calls_without_results(items: list[TResponseInputIt if not isinstance(call_id, str): continue if item_type == "function_call": - function_calls[call_id] = item + function_call_indices[call_id] = index elif item_type == "function_call_output": resolved_call_ids.add(call_id) - return [call_id for call_id in function_calls if call_id not in resolved_call_ids] + return [ + call_id + for call_id, index in function_call_indices.items() + if call_id not in resolved_call_ids and index > last_user_message_index + ] def _resolve_compaction_mode( diff --git a/src/agents/result.py b/src/agents/result.py index 355dab06e7..4e5e26f095 100644 --- a/src/agents/result.py +++ b/src/agents/result.py @@ -107,12 +107,36 @@ def _populate_state_from_result( if trace_state is None: trace_state = TraceState.from_trace(getattr(result, "trace", None)) state._trace_state = copy.deepcopy(trace_state) if trace_state else None - state._trace_include_sensitive_data = getattr( - source_state, - "_trace_include_sensitive_data", - True, + trace_include_sensitive_data_snapshot = getattr( + result, + "_trace_include_sensitive_data_snapshot", + None, ) - if isinstance(source_state, RunState): + if trace_include_sensitive_data_snapshot is not None: + state._trace_include_sensitive_data = trace_include_sensitive_data_snapshot + else: + state._trace_include_sensitive_data = getattr( + source_state, + "_trace_include_sensitive_data", + True, + ) + + session_history_mutations_snapshot = getattr( + result, + "_session_history_mutations_snapshot", + None, + ) + execution_only_approval_override_call_ids_snapshot = getattr( + result, + "_execution_only_approval_override_call_ids_snapshot", + None, + ) + if session_history_mutations_snapshot is not None: + state._session_history_mutations = copy.deepcopy(session_history_mutations_snapshot) + state._execution_only_approval_override_call_ids = list( + execution_only_approval_override_call_ids_snapshot or [] + ) + elif isinstance(source_state, RunState): state._session_history_mutations = source_state.get_session_history_mutations() state._execution_only_approval_override_call_ids = list( source_state._execution_only_approval_override_call_ids @@ -332,6 +356,15 @@ class RunResult(RunResultBase): to preserve the correct originalInput when serializing state.""" _state: Any = field(default=None, repr=False) """Internal reference to the originating RunState when available.""" + _trace_include_sensitive_data_snapshot: bool | None = field(default=None, repr=False) + """Snapshot of the trace redaction setting used when rebuilding state from a completed + result.""" + _session_history_mutations_snapshot: list[Any] | None = field(default=None, repr=False) + """Snapshot of pending session-history rewrites needed by `to_state()`.""" + _execution_only_approval_override_call_ids_snapshot: list[str] | None = field( + default=None, repr=False + ) + """Snapshot of execution-only approval overrides needed by `to_state()`.""" _conversation_id: str | None = field(default=None, repr=False) """Conversation identifier for server-managed runs.""" _previous_response_id: str | None = field(default=None, repr=False) diff --git a/src/agents/run_internal/agent_runner_helpers.py b/src/agents/run_internal/agent_runner_helpers.py index 1dcf7d5a42..0c4694ae6d 100644 --- a/src/agents/run_internal/agent_runner_helpers.py +++ b/src/agents/run_internal/agent_runner_helpers.py @@ -2,6 +2,7 @@ from __future__ import annotations +import copy from typing import Any, cast from ..agent import Agent @@ -185,9 +186,16 @@ def resolve_trace_include_sensitive_data( run_config: RunConfig, run_config_was_supplied: bool, ) -> bool: - """Resolve whether traces may include sensitive data for this run.""" - if run_state is None or run_config_was_supplied: + """Resolve whether traces may include sensitive data for this run. + + Resumed runs preserve the stored setting unless the new RunConfig explicitly narrows it by + setting `trace_include_sensitive_data=False`. + """ + del run_config_was_supplied + if run_state is None: return run_config.trace_include_sensitive_data + if run_config.trace_include_sensitive_data is False: + return False return run_state._trace_include_sensitive_data @@ -295,9 +303,15 @@ def attach_run_state_metadata(result: RunResult, *, run_state: RunState | None) if run_state is None: return result - result._state = run_state result._current_turn_persisted_item_count = run_state._current_turn_persisted_item_count result._trace_state = run_state._trace_state + result._trace_include_sensitive_data_snapshot = run_state._trace_include_sensitive_data + result._session_history_mutations_snapshot = copy.deepcopy( + run_state.get_session_history_mutations() + ) + result._execution_only_approval_override_call_ids_snapshot = list( + run_state._execution_only_approval_override_call_ids + ) return result diff --git a/tests/memory/test_openai_responses_compaction_session.py b/tests/memory/test_openai_responses_compaction_session.py index 5b312eb72d..69bc1a3952 100644 --- a/tests/memory/test_openai_responses_compaction_session.py +++ b/tests/memory/test_openai_responses_compaction_session.py @@ -520,6 +520,81 @@ async def test_run_compaction_auto_uses_default_store_when_unset(self) -> None: assert second_kwargs.get("previous_response_id") == "resp-stored" assert "input" not in second_kwargs + @pytest.mark.asyncio + async def test_run_compaction_ignores_abandoned_unresolved_function_calls(self) -> None: + mock_session = self.create_mock_session() + items: list[TResponseInputItem] = [ + cast(TResponseInputItem, {"type": "message", "role": "user", "content": "first"}), + cast( + TResponseInputItem, + { + "type": "function_call", + "call_id": "call-abandoned", + "id": "fc_1", + "name": "test_tool", + "arguments": "{}", + }, + ), + cast(TResponseInputItem, {"type": "message", "role": "user", "content": "followup"}), + cast( + TResponseInputItem, + {"type": "message", "role": "assistant", "content": "latest response"}, + ), + ] + mock_session.get_items.return_value = items + + mock_compact_response = MagicMock() + mock_compact_response.output = [] + + mock_client = MagicMock() + mock_client.responses.compact = AsyncMock(return_value=mock_compact_response) + + session = OpenAIResponsesCompactionSession( + session_id="test", + underlying_session=mock_session, + client=mock_client, + compaction_mode="auto", + ) + + await session.run_compaction({"response_id": "resp-latest", "force": True}) + + mock_client.responses.compact.assert_called_once_with( + previous_response_id="resp-latest", + model="gpt-4.1", + ) + + @pytest.mark.asyncio + async def test_run_compaction_still_blocks_active_unresolved_function_calls(self) -> None: + mock_session = self.create_mock_session() + items: list[TResponseInputItem] = [ + cast(TResponseInputItem, {"type": "message", "role": "user", "content": "hello"}), + cast( + TResponseInputItem, + { + "type": "function_call", + "call_id": "call-pending", + "id": "fc_1", + "name": "test_tool", + "arguments": "{}", + }, + ), + ] + mock_session.get_items.return_value = items + + mock_client = MagicMock() + mock_client.responses.compact = AsyncMock() + + session = OpenAIResponsesCompactionSession( + session_id="test", + underlying_session=mock_session, + client=mock_client, + compaction_mode="auto", + ) + + await session.run_compaction({"response_id": "resp-pending", "force": True}) + + mock_client.responses.compact.assert_not_called() + @pytest.mark.asyncio async def test_run_compaction_auto_uses_input_when_last_response_unstored(self) -> None: mock_session = self.create_mock_session() diff --git a/tests/test_agent_tracing.py b/tests/test_agent_tracing.py index 920896b90c..bdc75044fd 100644 --- a/tests/test_agent_tracing.py +++ b/tests/test_agent_tracing.py @@ -410,6 +410,45 @@ def send_email(recipient: str) -> str: assert function_span["span_data"]["output"] is None +@pytest.mark.asyncio +async def test_resumed_run_preserves_sensitive_trace_flag_for_unrelated_run_config() -> None: + model = FakeModel() + + @function_tool(name_override="send_email", needs_approval=True) + def send_email(recipient: str) -> str: + return recipient + + agent = Agent(name="trace_agent", model=model, tools=[send_email]) + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "send_email", '{"recipient":"alice@example.com"}', call_id="call-1" + ) + ], + [get_text_message("done")], + ] + ) + + first = await Runner.run(agent, input="first_test") + assert first.interruptions + + state = first.to_state() + state.set_trace_include_sensitive_data(False) + state.approve(first.interruptions[0], override_arguments={"recipient": "bob@example.com"}) + + resumed = await Runner.run( + agent, + state, + run_config=RunConfig(workflow_name="override_workflow"), + ) + + assert resumed.final_output == "done" + function_span = _get_last_function_span_export("send_email") + assert function_span["span_data"]["input"] is None + assert function_span["span_data"]["output"] is None + + @pytest.mark.asyncio async def test_wrapped_trace_is_single_trace(): model = FakeModel() diff --git a/tests/test_result_cast.py b/tests/test_result_cast.py index a97bb3eb24..9628b2d2a7 100644 --- a/tests/test_result_cast.py +++ b/tests/test_result_cast.py @@ -15,12 +15,16 @@ MessageOutputItem, RunContextWrapper, RunItem, + Runner, RunResult, RunResultStreaming, ) from agents.exceptions import AgentsException from agents.tool_context import ToolContext +from .fake_model import FakeModel +from .test_responses import get_text_message + def create_run_result( final_output: Any | None, @@ -261,6 +265,25 @@ def test_run_result_streaming_release_agents_releases_current_agent() -> None: _ = streaming_result.last_agent +@pytest.mark.asyncio +async def test_runner_result_does_not_retain_live_run_state() -> None: + agent = Agent( + name="runner-result-agent", + model=FakeModel(initial_output=[get_text_message("done")]), + ) + + result = await Runner.run(agent, "hello") + + assert result._state is None + + agent_ref = weakref.ref(agent) + result.release_agents() + del agent + gc.collect() + + assert agent_ref() is None + + def test_run_result_agent_tool_invocation_returns_none_for_plain_context() -> None: result = create_run_result("ok") From a426755cb8c58afb33044eaeeb946f157810b036 Mon Sep 17 00:00:00 2001 From: Kazuhiro Sera Date: Tue, 17 Mar 2026 16:18:14 +0900 Subject: [PATCH 5/5] fix review comments --- src/agents/run.py | 16 ++----- src/agents/run_config.py | 29 ++++++++++-- .../run_internal/agent_runner_helpers.py | 25 +++++------ tests/test_agent_runner.py | 45 ++++++++++++++++--- tests/test_agent_runner_streamed.py | 41 +++++++++++++++-- tests/test_agent_tracing.py | 45 +++++++++++++++++++ tests/test_run_config.py | 15 +++++++ 7 files changed, 177 insertions(+), 39 deletions(-) diff --git a/src/agents/run.py b/src/agents/run.py index ce33aa6c3d..24ca44cf22 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -29,7 +29,7 @@ ) from .lifecycle import RunHooks from .logger import logger -from .memory import Session, is_server_managed_conversation_session +from .memory import Session from .result import RunResult, RunResultStreaming from .run_config import ( DEFAULT_MAX_TURNS, @@ -413,7 +413,6 @@ async def run( auto_previous_response_id = kwargs.get("auto_previous_response_id", False) conversation_id = kwargs.get("conversation_id") session = kwargs.get("session") - run_config_was_supplied = run_config is not None if run_config is None: run_config = RunConfig() @@ -504,19 +503,15 @@ async def run( or previous_response_id is not None or auto_previous_response_id ) - history_is_server_managed = ( - server_manages_conversation or is_server_managed_conversation_session(session) - ) validate_override_history_persistence_support( input=input, session=session, - history_is_server_managed=history_is_server_managed, + response_history_is_server_managed=server_manages_conversation, ) resolved_trace_include_sensitive_data = resolve_trace_include_sensitive_data( run_state=run_state, run_config=run_config, - run_config_was_supplied=run_config_was_supplied, ) if resolved_trace_include_sensitive_data != run_config.trace_include_sensitive_data: run_config = dataclasses.replace( @@ -1469,7 +1464,6 @@ def run_streamed( auto_previous_response_id = kwargs.get("auto_previous_response_id", False) conversation_id = kwargs.get("conversation_id") session = kwargs.get("session") - run_config_was_supplied = run_config is not None if run_config is None: run_config = RunConfig() @@ -1553,18 +1547,14 @@ def run_streamed( or previous_response_id is not None or auto_previous_response_id ) - history_is_server_managed = ( - server_manages_conversation or is_server_managed_conversation_session(session) - ) validate_override_history_persistence_support( input=input, session=session, - history_is_server_managed=history_is_server_managed, + response_history_is_server_managed=server_manages_conversation, ) resolved_trace_include_sensitive_data = resolve_trace_include_sensitive_data( run_state=run_state, run_config=run_config, - run_config_was_supplied=run_config_was_supplied, ) if resolved_trace_include_sensitive_data != run_config.trace_include_sensitive_data: run_config = dataclasses.replace( diff --git a/src/agents/run_config.py b/src/agents/run_config.py index ad21f6c3b9..2e01b1a015 100644 --- a/src/agents/run_config.py +++ b/src/agents/run_config.py @@ -2,7 +2,7 @@ import os from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Optional +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Optional, cast from typing_extensions import NotRequired, TypedDict @@ -33,6 +33,14 @@ def _default_trace_include_sensitive_data() -> bool: return val.strip().lower() in ("1", "true", "yes", "on") +_TRACE_INCLUDE_SENSITIVE_DATA_UNSET = cast(bool, object()) + + +def _unset_trace_include_sensitive_data() -> bool: + """Return a sentinel so RunConfig can detect explicit trace flag overrides.""" + return _TRACE_INCLUDE_SENSITIVE_DATA_UNSET + + @dataclass class ModelInputData: """Container for the data that will be sent to the model.""" @@ -129,9 +137,7 @@ class RunConfig: tracing: TracingConfig | None = None """Tracing configuration for this run.""" - trace_include_sensitive_data: bool = field( - default_factory=_default_trace_include_sensitive_data - ) + trace_include_sensitive_data: bool = field(default_factory=_unset_trace_include_sensitive_data) """Whether we include potentially sensitive data (for example: inputs/outputs of tool calls or LLM generations) in traces. If False, we'll still create spans for these events, but the sensitive data will not be included. @@ -191,6 +197,21 @@ class RunConfig: - ``"omit"`` strips reasoning item IDs from model input built by the runner. """ + _trace_include_sensitive_data_was_explicit: bool = field( + init=False, + repr=False, + compare=False, + default=False, + ) + + def __post_init__(self) -> None: + if self.trace_include_sensitive_data is _TRACE_INCLUDE_SENSITIVE_DATA_UNSET: + self.trace_include_sensitive_data = _default_trace_include_sensitive_data() + self._trace_include_sensitive_data_was_explicit = False + return + + self._trace_include_sensitive_data_was_explicit = True + class RunOptions(TypedDict, Generic[TContext]): """Arguments for ``AgentRunner`` methods.""" diff --git a/src/agents/run_internal/agent_runner_helpers.py b/src/agents/run_internal/agent_runner_helpers.py index 0c4694ae6d..91f5073116 100644 --- a/src/agents/run_internal/agent_runner_helpers.py +++ b/src/agents/run_internal/agent_runner_helpers.py @@ -115,23 +115,26 @@ def validate_override_history_persistence_support( *, input: str | list[TResponseInputItem] | RunState[Any], session: Session | None, - history_is_server_managed: bool, + response_history_is_server_managed: bool, ) -> None: """Fail fast when approval override persistence requirements are not satisfied.""" if not isinstance(input, RunState): return - if input.has_pending_execution_only_approval_overrides() and not history_is_server_managed: + if ( + input.has_pending_execution_only_approval_overrides() + and not response_history_is_server_managed + ): raise UserError( "save_override_arguments=False is only supported when using conversation_id, " - "previous_response_id, auto_previous_response_id, or a server-managed session." + "previous_response_id, or auto_previous_response_id." ) mutations = input.get_session_history_mutations() if not mutations: return - if history_is_server_managed: + if response_history_is_server_managed: raise UserError( "save_override_arguments requires local canonical history. " "Server-managed conversations cannot persist corrected function_call arguments. " @@ -184,18 +187,14 @@ def resolve_trace_include_sensitive_data( *, run_state: RunState[TContext] | None, run_config: RunConfig, - run_config_was_supplied: bool, ) -> bool: - """Resolve whether traces may include sensitive data for this run. - - Resumed runs preserve the stored setting unless the new RunConfig explicitly narrows it by - setting `trace_include_sensitive_data=False`. - """ - del run_config_was_supplied + """Resolve whether traces may include sensitive data for this run.""" if run_state is None: return run_config.trace_include_sensitive_data - if run_config.trace_include_sensitive_data is False: - return False + + if getattr(run_config, "_trace_include_sensitive_data_was_explicit", True): + return run_config.trace_include_sensitive_data + return run_state._trace_include_sensitive_data diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index d92bcfd025..2461db9b50 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -1173,7 +1173,7 @@ def approval_tool(test: str) -> str: @pytest.mark.asyncio -async def test_resume_supports_execution_only_override_with_server_managed_session() -> None: +async def test_resume_rejects_execution_only_override_with_marker_session() -> None: model = FakeModel() @function_tool(name_override="approval_tool", needs_approval=True) @@ -1207,12 +1207,47 @@ def approval_tool(test: str) -> str: save_override_arguments=False, ) - resumed = await Runner.run(agent, state, session=session) + with pytest.raises(UserError, match="save_override_arguments=False is only supported"): + await Runner.run(agent, state, session=session) + + +@pytest.mark.asyncio +async def test_resume_supports_execution_only_override_with_previous_response_id() -> None: + model = FakeModel() + + @function_tool(name_override="approval_tool", needs_approval=True) + def approval_tool(test: str) -> str: + return f"result:{test}" + + agent = Agent( + name="approval_agent", + model=model, + tools=[approval_tool], + tool_use_behavior="stop_on_first_tool", + ) + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "approval_tool", json.dumps({"test": "foo"}), call_id="call-1" + ) + ], + ] + ) + + first = await Runner.run(agent, input="user_message", previous_response_id="resp-root") + assert first.interruptions + + state = first.to_state() + state.approve( + first.interruptions[0], + override_arguments={"test": "bar"}, + save_override_arguments=False, + ) + + resumed = await Runner.run(agent, state) assert resumed.final_output == "result:bar" - saved_items = await session.get_items() - assert cast(dict[str, Any], saved_items[1])["arguments"] == json.dumps({"test": "foo"}) - assert saved_items[2]["type"] == "function_call_output" @pytest.mark.asyncio diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py index 1abfb28681..bf91699f08 100644 --- a/tests/test_agent_runner_streamed.py +++ b/tests/test_agent_runner_streamed.py @@ -1608,9 +1608,7 @@ async def test_tool(test: str) -> str: @pytest.mark.asyncio -async def test_streaming_resume_supports_execution_only_override_with_server_managed_session() -> ( - None -): +async def test_streaming_resume_rejects_execution_only_override_with_marker_session() -> None: async def test_tool(test: str) -> str: return f"result:{test}" @@ -1639,7 +1637,42 @@ async def test_tool(test: str) -> str: save_override_arguments=False, ) - resumed = Runner.run_streamed(agent, state, session=session) + with pytest.raises(UserError, match="save_override_arguments=False is only supported"): + Runner.run_streamed(agent, state, session=session) + + +@pytest.mark.asyncio +async def test_streaming_resume_supports_execution_only_override_with_previous_response_id() -> ( + None +): + async def test_tool(test: str) -> str: + return f"result:{test}" + + tool = function_tool(test_tool, name_override="test_tool", needs_approval=True) + model = FakeModel() + agent = Agent( + name="test", + model=model, + tools=[tool], + tool_use_behavior="stop_on_first_tool", + ) + + model.add_multiple_turn_outputs( + [[get_function_tool_call("test_tool", json.dumps({"test": "foo"}), call_id="call-resume")]] + ) + + first = Runner.run_streamed(agent, input="Use test_tool", previous_response_id="resp-root") + await consume_stream(first) + assert first.interruptions + + state = first.to_state() + state.approve( + first.interruptions[0], + override_arguments={"test": "bar"}, + save_override_arguments=False, + ) + + resumed = Runner.run_streamed(agent, state) await consume_stream(resumed) assert resumed.final_output == "result:bar" diff --git a/tests/test_agent_tracing.py b/tests/test_agent_tracing.py index bdc75044fd..e6db02fc03 100644 --- a/tests/test_agent_tracing.py +++ b/tests/test_agent_tracing.py @@ -777,6 +777,51 @@ def send_email(recipient: str) -> str: assert function_span["span_data"]["output"] is None +@pytest.mark.asyncio +async def test_resumed_streaming_run_preserves_sensitive_trace_flag_for_unrelated_run_config() -> ( + None +): + model = FakeModel() + + @function_tool(name_override="send_email", needs_approval=True) + def send_email(recipient: str) -> str: + return recipient + + agent = Agent(name="trace_agent", model=model, tools=[send_email]) + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call( + "send_email", '{"recipient":"alice@example.com"}', call_id="call-1" + ) + ], + [get_text_message("done")], + ] + ) + + first = Runner.run_streamed(agent, input="first_test") + async for _ in first.stream_events(): + pass + assert first.interruptions + + state = first.to_state() + state.set_trace_include_sensitive_data(False) + state.approve(first.interruptions[0], override_arguments={"recipient": "bob@example.com"}) + + resumed = Runner.run_streamed( + agent, + state, + run_config=RunConfig(workflow_name="override_workflow"), + ) + async for _ in resumed.stream_events(): + pass + + assert resumed.final_output == "done" + function_span = _get_last_function_span_export("send_email") + assert function_span["span_data"]["input"] is None + assert function_span["span_data"]["output"] is None + + @pytest.mark.asyncio async def test_wrapped_streaming_trace_is_single_trace(): model = FakeModel() diff --git a/tests/test_run_config.py b/tests/test_run_config.py index 31d6d0a46a..34ab570a3d 100644 --- a/tests/test_run_config.py +++ b/tests/test_run_config.py @@ -138,3 +138,18 @@ def test_trace_include_sensitive_data_explicit_override_takes_precedence(monkeyp monkeypatch.setenv("OPENAI_AGENTS_TRACE_INCLUDE_SENSITIVE_DATA", "true") config = RunConfig(trace_include_sensitive_data=False) assert config.trace_include_sensitive_data is False + + +def test_trace_include_sensitive_data_tracks_explicit_overrides(monkeypatch): + """RunConfig should distinguish explicit trace flag overrides from unrelated options.""" + monkeypatch.setenv("OPENAI_AGENTS_TRACE_INCLUDE_SENSITIVE_DATA", "true") + + default_config = RunConfig() + unrelated_config = RunConfig(workflow_name="custom-workflow") + explicit_true_config = RunConfig(trace_include_sensitive_data=True) + explicit_false_config = RunConfig(trace_include_sensitive_data=False) + + assert default_config._trace_include_sensitive_data_was_explicit is False + assert unrelated_config._trace_include_sensitive_data_was_explicit is False + assert explicit_true_config._trace_include_sensitive_data_was_explicit is True + assert explicit_false_config._trace_include_sensitive_data_was_explicit is True