From cdf4f0317f851e9e699cea5897a553fade5a2cc1 Mon Sep 17 00:00:00 2001 From: jawwad-ali Date: Sat, 6 Jun 2026 15:32:03 +0500 Subject: [PATCH] feat: pass run context wrapper to session get_items/add_items Sessions can now opt into receiving the current run's RunContextWrapper by accepting a keyword-only wrapper parameter on get_items/add_items. The runner inspects the session's signature and only forwards the wrapper to implementations that accept it, so existing custom sessions keep working unchanged. Wrapper-style sessions (EncryptedSession, OpenAIResponsesCompactionSession) forward the wrapper to opted-in underlying sessions. This revives the approach from #2690, which addressed maintainer review feedback before going stale, rebased onto the current runner internals and extended to MongoDBSession and the guardrail-trip persistence path. Closes #2072 --- docs/sessions/index.md | 46 ++ examples/memory/file_session.py | 15 +- .../memory/advanced_sqlite_session.py | 10 +- .../extensions/memory/async_sqlite_session.py | 17 +- src/agents/extensions/memory/dapr_session.py | 15 +- .../extensions/memory/encrypt_session.py | 39 +- .../extensions/memory/mongodb_session.py | 15 +- src/agents/extensions/memory/redis_session.py | 15 +- .../extensions/memory/sqlalchemy_session.py | 15 +- .../memory/openai_conversations_session.py | 17 +- .../openai_responses_compaction_session.py | 29 +- src/agents/memory/session.py | 63 ++- src/agents/memory/sqlite_session.py | 17 +- src/agents/run.py | 24 +- .../run_internal/agent_runner_helpers.py | 2 + src/agents/run_internal/run_loop.py | 16 +- .../run_internal/session_persistence.py | 50 +- ...est_openai_responses_compaction_session.py | 43 +- tests/memory/test_session.py | 8 +- tests/memory/test_session_context_wrapper.py | 433 ++++++++++++++++++ tests/test_agent_as_tool.py | 14 +- tests/test_agent_runner.py | 117 ++++- tests/test_agent_runner_streamed.py | 16 +- tests/utils/simple_session.py | 24 +- 24 files changed, 984 insertions(+), 76 deletions(-) create mode 100644 tests/memory/test_session_context_wrapper.py diff --git a/docs/sessions/index.md b/docs/sessions/index.md index 8916f85fab..f36bf08d97 100644 --- a/docs/sessions/index.md +++ b/docs/sessions/index.md @@ -680,6 +680,52 @@ result = await Runner.run( ) ``` +### Accessing the run context in custom sessions + +Custom sessions can opt into receiving the current run's [`RunContextWrapper`][agents.run_context.RunContextWrapper] by accepting a keyword-only `wrapper` parameter on `get_items` and `add_items`. The runner passes the wrapper only when the session's signature accepts it, so existing session implementations keep working unchanged: + +```python +from typing import Any + +from agents.memory.session import SessionABC +from agents.items import TResponseInputItem +from agents.run_context import RunContextWrapper + +class ContextAwareSession(SessionABC): + """Session that scopes storage by data from the run context.""" + + def __init__(self, session_id: str): + self.session_id = session_id + + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: + # Use wrapper.context (e.g. a user ID) to scope retrieval. + user_id = wrapper.context.user_id if wrapper is not None else None + return await self._load_items(user_id, limit) + + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + # Persist items together with data from the run context. + user_id = wrapper.context.user_id if wrapper is not None else None + await self._store_items(user_id, items) + + async def pop_item(self) -> TResponseInputItem | None: + ... + + async def clear_session(self) -> None: + ... +``` + +The `wrapper` parameter may be `None`, for example when session methods are called directly rather than through the runner, so implementations should always handle that case. Sessions that accept `**kwargs` on these methods also receive the wrapper through them. + ## Community session implementations The community has developed additional session implementations: diff --git a/examples/memory/file_session.py b/examples/memory/file_session.py index e62dbd167f..abd6d81717 100644 --- a/examples/memory/file_session.py +++ b/examples/memory/file_session.py @@ -15,6 +15,7 @@ from agents.memory.session import Session from agents.memory.session_settings import SessionSettings +from agents.run_context import RunContextWrapper class FileSession(Session): @@ -43,14 +44,24 @@ async def get_session_id(self) -> str: """Return the session id, creating one if needed.""" return await self._ensure_session_id() - async def get_items(self, limit: int | None = None) -> list[Any]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[Any]: session_id = await self._ensure_session_id() items = await self._read_items(session_id) if limit is not None and limit >= 0: return items[-limit:] return items - async def add_items(self, items: list[Any]) -> None: + async def add_items( + self, + items: list[Any], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: if not items: return session_id = await self._ensure_session_id() diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py index 83c289bdf8..ab560c43a6 100644 --- a/src/agents/extensions/memory/advanced_sqlite_session.py +++ b/src/agents/extensions/memory/advanced_sqlite_session.py @@ -15,6 +15,7 @@ from ...items import TResponseInputItem from ...memory import SQLiteSession from ...memory.session_settings import SessionSettings, resolve_session_limit +from ...run_context import RunContextWrapper class AdvancedSQLiteSession(SQLiteSession): @@ -121,7 +122,12 @@ def _init_structure_tables(self): conn.commit() - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add items to the session. Args: @@ -160,6 +166,8 @@ async def get_items( self, limit: int | None = None, branch_id: str | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, ) -> list[TResponseInputItem]: """Get items from current or specified branch. diff --git a/src/agents/extensions/memory/async_sqlite_session.py b/src/agents/extensions/memory/async_sqlite_session.py index 27a23b1cbe..57c206b044 100644 --- a/src/agents/extensions/memory/async_sqlite_session.py +++ b/src/agents/extensions/memory/async_sqlite_session.py @@ -5,13 +5,14 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from pathlib import Path -from typing import cast +from typing import Any, cast import aiosqlite from ...items import TResponseInputItem from ...memory import SessionABC from ...memory.session_settings import SessionSettings, resolve_session_limit +from ...run_context import RunContextWrapper class AsyncSQLiteSession(SessionABC): @@ -106,7 +107,12 @@ async def _locked_connection(self) -> AsyncIterator[aiosqlite.Connection]: conn = await self._get_connection() yield conn - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -156,7 +162,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: diff --git a/src/agents/extensions/memory/dapr_session.py b/src/agents/extensions/memory/dapr_session.py index 6ac68f6020..2e24c3f2ff 100644 --- a/src/agents/extensions/memory/dapr_session.py +++ b/src/agents/extensions/memory/dapr_session.py @@ -29,6 +29,7 @@ import time from typing import Any, Final, Literal +from ...run_context import RunContextWrapper from ._optional_imports import raise_optional_dependency_error try: @@ -250,7 +251,12 @@ async def _handle_concurrency_conflict(self, error: Exception, attempt: int) -> # Session protocol implementation # ------------------------------------------------------------------ - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -289,7 +295,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: continue return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: diff --git a/src/agents/extensions/memory/encrypt_session.py b/src/agents/extensions/memory/encrypt_session.py index 19ba7a5683..4ca7192114 100644 --- a/src/agents/extensions/memory/encrypt_session.py +++ b/src/agents/extensions/memory/encrypt_session.py @@ -37,8 +37,9 @@ from typing_extensions import TypedDict from ...items import TResponseInputItem -from ...memory.session import SessionABC +from ...memory.session import SessionABC, session_method_accepts_wrapper from ...memory.session_settings import SessionSettings, resolve_session_limit +from ...run_context import RunContextWrapper class EncryptedEnvelope(TypedDict): @@ -180,12 +181,28 @@ def _unwrap_valid_items( valid_items.append(item) return valid_items - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def _get_underlying_items( + self, limit: int | None, wrapper: RunContextWrapper[Any] | None + ) -> list[TResponseInputItem]: + # Forward the wrapper only when the underlying session opts in, so wrapping older + # custom sessions keeps working unchanged. + if wrapper is not None and session_method_accepts_wrapper( + self.underlying_session.get_items + ): + return await self.underlying_session.get_items(limit, wrapper=wrapper) + return await self.underlying_session.get_items(limit) + + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: effective_limit = resolve_session_limit(limit, self.session_settings) if effective_limit is not None and effective_limit > 0: window = effective_limit while True: - encrypted_items = await self.underlying_session.get_items(window) + encrypted_items = await self._get_underlying_items(window, wrapper) valid_items = self._unwrap_valid_items(encrypted_items) if len(valid_items) >= effective_limit: return valid_items[-effective_limit:] @@ -193,11 +210,23 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return valid_items window *= 2 - encrypted_items = await self.underlying_session.get_items(limit) + encrypted_items = await self._get_underlying_items(limit, wrapper) return self._unwrap_valid_items(encrypted_items) - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items] + if wrapper is not None and session_method_accepts_wrapper( + self.underlying_session.add_items + ): + await self.underlying_session.add_items( + cast(list[TResponseInputItem], wrapped), wrapper=wrapper + ) + return await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped)) async def pop_item(self) -> TResponseInputItem | None: diff --git a/src/agents/extensions/memory/mongodb_session.py b/src/agents/extensions/memory/mongodb_session.py index 113acdc6af..5a45ca8599 100644 --- a/src/agents/extensions/memory/mongodb_session.py +++ b/src/agents/extensions/memory/mongodb_session.py @@ -37,6 +37,7 @@ from datetime import datetime, timezone from typing import Any +from ...run_context import RunContextWrapper from ._optional_imports import raise_optional_dependency_error try: @@ -247,7 +248,12 @@ async def _deserialize_item(self, raw: str) -> TResponseInputItem: # Session protocol implementation # ------------------------------------------------------------------ - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -289,7 +295,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: diff --git a/src/agents/extensions/memory/redis_session.py b/src/agents/extensions/memory/redis_session.py index 11e2dd838b..c20dcf72a2 100644 --- a/src/agents/extensions/memory/redis_session.py +++ b/src/agents/extensions/memory/redis_session.py @@ -26,6 +26,7 @@ import time from typing import Any +from ...run_context import RunContextWrapper from ._optional_imports import raise_optional_dependency_error try: @@ -145,7 +146,12 @@ async def _set_ttl_if_configured(self, *keys: str) -> None: # Session protocol implementation # ------------------------------------------------------------------ - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -184,7 +190,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index fd2502e24b..dbfa042e5f 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -51,6 +51,7 @@ from ...items import TResponseInputItem from ...memory.session import SessionABC from ...memory.session_settings import SessionSettings, resolve_session_limit +from ...run_context import RunContextWrapper class SQLAlchemySession(SessionABC): @@ -274,7 +275,12 @@ async def _ensure_tables(self) -> None: finally: self._init_lock.release() - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -326,7 +332,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: continue return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: diff --git a/src/agents/memory/openai_conversations_session.py b/src/agents/memory/openai_conversations_session.py index 4d4fbaf635..e05caf88de 100644 --- a/src/agents/memory/openai_conversations_session.py +++ b/src/agents/memory/openai_conversations_session.py @@ -1,10 +1,13 @@ from __future__ import annotations +from typing import Any + from openai import AsyncOpenAI from agents.models._openai_shared import get_default_openai_client from ..items import TResponseInputItem +from ..run_context import RunContextWrapper from .session import SessionABC from .session_settings import SessionSettings, resolve_session_limit @@ -70,7 +73,12 @@ async def _get_session_id(self) -> str: async def _clear_session_id(self) -> None: self._session_id = None - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: session_id = await self._get_session_id() session_limit = resolve_session_limit(limit, self.session_settings) @@ -97,7 +105,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return all_items # type: ignore - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: session_id = await self._get_session_id() if not items: return diff --git a/src/agents/memory/openai_responses_compaction_session.py b/src/agents/memory/openai_responses_compaction_session.py index c112b706a1..48e2309156 100644 --- a/src/agents/memory/openai_responses_compaction_session.py +++ b/src/agents/memory/openai_responses_compaction_session.py @@ -8,12 +8,14 @@ from ..items import TResponseInputItem from ..models._openai_shared import get_default_openai_client +from ..run_context import RunContextWrapper from ..run_internal.items import normalize_input_items_for_api from .openai_conversations_session import OpenAIConversationsSession from .session import ( OpenAIResponsesCompactionArgs, OpenAIResponsesCompactionAwareSession, SessionABC, + session_method_accepts_wrapper, ) if TYPE_CHECKING: @@ -233,7 +235,18 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | None = None f"candidates={len(self._compaction_candidate_items)})" ) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: + # Forward the wrapper only when the underlying session opts in, so wrapping older + # custom sessions keeps working unchanged. + if wrapper is not None and session_method_accepts_wrapper( + self.underlying_session.get_items + ): + return await self.underlying_session.get_items(limit, wrapper=wrapper) return await self.underlying_session.get_items(limit) async def _get_all_underlying_session_items(self) -> list[TResponseInputItem]: @@ -331,8 +344,18 @@ def _get_deferred_compaction_response_id(self) -> str | None: def _clear_deferred_compaction(self) -> None: self._deferred_response_id = None - async def add_items(self, items: list[TResponseInputItem]) -> None: - await self.underlying_session.add_items(items) + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + if wrapper is not None and session_method_accepts_wrapper( + self.underlying_session.add_items + ): + await self.underlying_session.add_items(items, wrapper=wrapper) + else: + await self.underlying_session.add_items(items) if self._compaction_candidate_items is not None: new_items = _normalize_compaction_session_items(items) new_candidates = select_compaction_candidate_items(new_items) diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 1781b7ac9f..68ec558b93 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -1,12 +1,14 @@ from __future__ import annotations +import inspect from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Literal, Protocol, TypeGuard, runtime_checkable +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeGuard, runtime_checkable from typing_extensions import TypedDict if TYPE_CHECKING: from ..items import TResponseInputItem + from ..run_context import RunContextWrapper from .session_settings import SessionSettings @@ -21,23 +23,37 @@ class Session(Protocol): session_id: str session_settings: SessionSettings | None = None - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: limit: Maximum number of items to retrieve. If None, retrieves all items. When specified, returns the latest N items in chronological order. + wrapper: Optional run context wrapper for the current run. Implementations + may ignore this parameter. Returns: List of input items representing the conversation history """ ... - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: items: List of input items to add to the history + wrapper: Optional run context wrapper for the current run. Implementations + may ignore this parameter. """ ... @@ -68,12 +84,19 @@ class SessionABC(ABC): session_settings: SessionSettings | None = None @abstractmethod - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: limit: Maximum number of items to retrieve. If None, retrieves all items. When specified, returns the latest N items in chronological order. + wrapper: Optional run context wrapper for the current run. Implementations + may ignore this parameter. Returns: List of input items representing the conversation history @@ -81,11 +104,18 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: ... @abstractmethod - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: items: List of input items to add to the history + wrapper: Optional run context wrapper for the current run. Implementations + may ignore this parameter. """ ... @@ -148,3 +178,26 @@ def is_openai_responses_compaction_aware_session( except Exception: return False return callable(run_compaction) + + +def session_method_accepts_wrapper(method: Any) -> bool: + """Check if a session method accepts the keyword-only ``wrapper`` argument. + + The runner (and wrapper-style sessions such as ``EncryptedSession``) use this to pass + the current run context only to implementations that opt in, so older custom sessions + that predate the ``wrapper`` parameter keep working unchanged. Methods that accept + ``**kwargs`` are treated as opted in and receive the wrapper through them. + """ + try: + parameters = tuple(inspect.signature(method).parameters.values()) + except (TypeError, ValueError): + return False + return any( + parameter.kind is inspect.Parameter.VAR_KEYWORD + or ( + parameter.name == "wrapper" + and parameter.kind + in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + ) + for parameter in parameters + ) diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 3a69f9883a..695c3dfb7f 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -7,9 +7,10 @@ from collections.abc import Iterator from contextlib import contextmanager from pathlib import Path -from typing import ClassVar +from typing import Any, ClassVar from ..items import TResponseInputItem +from ..run_context import RunContextWrapper from .session import SessionABC from .session_settings import SessionSettings, resolve_session_limit @@ -199,7 +200,12 @@ def _insert_items(self, conn: sqlite3.Connection, items: list[TResponseInputItem (self.session_id,), ) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -254,7 +260,12 @@ def _get_items_sync(): return await asyncio.to_thread(_get_items_sync) - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: diff --git a/src/agents/run.py b/src/agents/run.py index 014271a5ea..27e3c9b039 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -101,6 +101,7 @@ NextStepRunAgain, ) from .run_internal.session_persistence import ( + _session_get_items, persist_session_items_for_guardrail_trip, prepare_input_with_session, resumed_turn_items, @@ -510,6 +511,10 @@ async def run( raw_input = cast(str | list[TResponseInputItem], input) original_user_input = raw_input + context_wrapper = ensure_context_wrapper(context) + context = context_wrapper.context + set_agent_tool_state_scope(context_wrapper, None) + validate_session_conversation_settings( session, conversation_id=conversation_id, @@ -531,6 +536,7 @@ async def run( run_config.session_settings, include_history_in_prepared_input=False, preserve_dropped_new_items=True, + wrapper=context_wrapper, ) original_input_for_state = raw_input session_input_items_for_persistence = [] @@ -543,6 +549,7 @@ async def run( session, run_config.session_input_callback, run_config.session_settings, + wrapper=context_wrapper, ) original_input_for_state = prepared_input @@ -579,7 +586,7 @@ async def run( session_input_items: list[TResponseInputItem] | None = None if session is not None: try: - session_input_items = await session.get_items() + session_input_items = await _session_get_items(session, wrapper=context_wrapper) except Exception: session_input_items = None server_conversation_tracker.hydrate_from_state( @@ -628,8 +635,6 @@ async def run( generated_items = [] session_items = [] model_responses = [] - context_wrapper = ensure_context_wrapper(context) - set_agent_tool_state_scope(context_wrapper, None) run_state = RunState( context=context_wrapper, original_input=original_input, @@ -754,6 +759,7 @@ def _finalize_result(result: RunResult) -> RunResult: [], run_state, store=store_setting, + wrapper=context_wrapper, ) session_input_items_for_persistence = [] except BaseException: @@ -796,6 +802,7 @@ def _finalize_result(result: RunResult) -> RunResult: original_user_input, run_state, store=store_setting, + wrapper=context_wrapper, ) ) raise @@ -835,6 +842,7 @@ def _finalize_result(result: RunResult) -> RunResult: [], run_state, store=store_setting, + wrapper=context_wrapper, ) session_input_items_for_persistence = [] if run_state is not None and run_state._current_step is not None: @@ -893,6 +901,7 @@ def _finalize_result(result: RunResult) -> RunResult: run_state._reasoning_item_id_policy ), store=store_setting, + wrapper=context_wrapper, ) ) @@ -1005,6 +1014,7 @@ def _finalize_result(result: RunResult) -> RunResult: run_state, response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) result._original_input = copy_input_items(original_input) return _finalize_result(result) @@ -1139,6 +1149,7 @@ def _finalize_result(result: RunResult) -> RunResult: run_state, response_id=None, store=store_setting, + wrapper=context_wrapper, ) result._original_input = copy_input_items(original_input) return _finalize_result(result) @@ -1187,6 +1198,7 @@ def _finalize_result(result: RunResult) -> RunResult: original_user_input, run_state, store=store_setting, + wrapper=context_wrapper, ) ) raise @@ -1240,6 +1252,7 @@ def _finalize_result(result: RunResult) -> RunResult: original_user_input, run_state, store=store_setting, + wrapper=context_wrapper, ) ) raise @@ -1347,6 +1360,7 @@ def _finalize_result(result: RunResult) -> RunResult: run_state._reasoning_item_id_policy ), store=store_setting, + wrapper=context_wrapper, ) run_state._current_turn_persisted_item_count += saved_count else: @@ -1357,6 +1371,7 @@ def _finalize_result(result: RunResult) -> RunResult: run_state, response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) # After the first resumed turn, treat subsequent turns as fresh @@ -1409,6 +1424,7 @@ def _finalize_result(result: RunResult) -> RunResult: items=session_items_for_turn(turn_result), response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) result._original_input = copy_input_items(original_input) return _finalize_result(result) @@ -1428,6 +1444,7 @@ def _finalize_result(result: RunResult) -> RunResult: run_state, response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) append_model_response_if_new( model_responses, turn_result.model_response @@ -1489,6 +1506,7 @@ def _finalize_result(result: RunResult) -> RunResult: items=session_items_for_turn(turn_result), response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) continue else: diff --git a/src/agents/run_internal/agent_runner_helpers.py b/src/agents/run_internal/agent_runner_helpers.py index 84c67d6b8f..cad1201562 100644 --- a/src/agents/run_internal/agent_runner_helpers.py +++ b/src/agents/run_internal/agent_runner_helpers.py @@ -467,6 +467,7 @@ async def save_turn_items_if_needed( items: list[RunItem], response_id: str | None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> None: """Persist turn items when persistence is enabled and guardrails allow it.""" if not session_persistence_enabled: @@ -482,6 +483,7 @@ async def save_turn_items_if_needed( run_state, response_id=response_id, store=store, + wrapper=wrapper, ) diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 45f09c0fa0..45fe354be1 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -143,6 +143,7 @@ ToolRunShellCall, ) from .session_persistence import ( + _session_get_items, persist_session_items_for_guardrail_trip, prepare_input_with_session, resumed_turn_items, @@ -322,6 +323,7 @@ async def _save_resumed_stream_items( response_id=response_id, reasoning_item_id_policy=streamed_result._reasoning_item_id_policy, store=store, + wrapper=streamed_result.context_wrapper, ) if run_state is not None: run_state._current_turn_persisted_item_count = ( @@ -353,6 +355,7 @@ async def _save_stream_items( run_state, response_id=response_id, store=store, + wrapper=streamed_result.context_wrapper, ) if update_persisted_count and streamed_result._state is not None: streamed_result._current_turn_persisted_item_count = ( @@ -575,7 +578,7 @@ def _sync_conversation_tracking_from_tracker() -> None: session_items: list[TResponseInputItem] | None = None if session is not None: try: - session_items = await session.get_items() + session_items = await _session_get_items(session, wrapper=context_wrapper) except Exception: session_items = None server_conversation_tracker.hydrate_from_state( @@ -603,6 +606,7 @@ def _sync_conversation_tracking_from_tracker() -> None: run_config.session_settings, include_history_in_prepared_input=not server_manages_conversation, preserve_dropped_new_items=True, + wrapper=context_wrapper, ) streamed_result.input = prepared_input streamed_result._original_input = copy_input_items(prepared_input) @@ -706,6 +710,7 @@ async def _save_stream_items_without_count( store=current_agent.model_settings.resolve( run_config.model_settings ).store, + wrapper=context_wrapper, ) ) raise InputGuardrailTripwireTriggered(result) @@ -978,6 +983,7 @@ async def _save_stream_items_without_count( store=current_agent.model_settings.resolve( run_config.model_settings ).store, + wrapper=context_wrapper, ) ) raise InputGuardrailTripwireTriggered(result) @@ -1420,7 +1426,13 @@ def _tool_search_fingerprint(raw_item: Any) -> str: ) ] if input_items_to_save: - await save_result_to_session(session, input_items_to_save, [], streamed_result._state) + await save_result_to_session( + session, + input_items_to_save, + [], + streamed_result._state, + wrapper=context_wrapper, + ) previous_response_id = ( server_conversation_tracker.previous_response_id diff --git a/src/agents/run_internal/session_persistence.py b/src/agents/run_internal/session_persistence.py index f483da13a3..dc7cf136b5 100644 --- a/src/agents/run_internal/session_persistence.py +++ b/src/agents/run_internal/session_persistence.py @@ -23,6 +23,8 @@ is_openai_responses_compaction_aware_session, ) from ..memory.openai_conversations_session import OpenAIConversationsSession +from ..memory.session import session_method_accepts_wrapper +from ..run_context import RunContextWrapper from ..run_state import RunState from .items import ( ReasoningItemIdPolicy, @@ -51,6 +53,39 @@ ] +async def _session_get_items( + session: Session, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, +) -> list[TResponseInputItem]: + """Call ``session.get_items``, passing the run context only when the session opts in. + + Custom sessions that predate the keyword-only ``wrapper`` parameter keep working + unchanged because the wrapper is only forwarded when their signature accepts it. + """ + if wrapper is not None and session_method_accepts_wrapper(session.get_items): + if limit is not None: + return await session.get_items(limit=limit, wrapper=wrapper) + return await session.get_items(wrapper=wrapper) + if limit is not None: + return await session.get_items(limit=limit) + return await session.get_items() + + +async def _session_add_items( + session: Session, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, +) -> None: + """Call ``session.add_items``, passing the run context only when the session opts in.""" + if wrapper is not None and session_method_accepts_wrapper(session.add_items): + await session.add_items(items, wrapper=wrapper) + return + await session.add_items(items) + + async def prepare_input_with_session( input: str | list[TResponseInputItem], session: Session | None, @@ -59,6 +94,7 @@ async def prepare_input_with_session( *, include_history_in_prepared_input: bool = True, preserve_dropped_new_items: bool = False, + wrapper: RunContextWrapper[Any] | None = None, ) -> tuple[str | list[TResponseInputItem], list[TResponseInputItem]]: """Prepare model input from session history plus the new turn input. @@ -83,9 +119,9 @@ async def prepare_input_with_session( resolved_settings = resolved_settings.resolve(session_settings) if resolved_settings.limit is not None: - history = await session.get_items(limit=resolved_settings.limit) + history = await _session_get_items(session, limit=resolved_settings.limit, wrapper=wrapper) else: - history = await session.get_items() + history = await _session_get_items(session, wrapper=wrapper) is_openai_conversation_session = isinstance(session, OpenAIConversationsSession) converted_history = [ strip_internal_input_item_metadata(ensure_input_item_format(item)) for item in history @@ -194,6 +230,7 @@ async def persist_session_items_for_guardrail_trip( original_user_input: str | list[TResponseInputItem] | None, run_state: RunState | None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> list[TResponseInputItem] | None: """ Persist input items when a guardrail tripwire is triggered. @@ -208,7 +245,9 @@ async def persist_session_items_for_guardrail_trip( input_items_for_save: list[TResponseInputItem] = ( updated_session_input_items if updated_session_input_items is not None else [] ) - await save_result_to_session(session, input_items_for_save, [], run_state, store=store) + await save_result_to_session( + session, input_items_for_save, [], run_state, store=store, wrapper=wrapper + ) return updated_session_input_items @@ -253,6 +292,7 @@ async def save_result_to_session( response_id: str | None = None, reasoning_item_id_policy: ReasoningItemIdPolicy | None = None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> int: """ Persist a turn to the session store, keeping track of what was already saved so retries @@ -346,7 +386,7 @@ async def save_result_to_session( run_state._current_turn_persisted_item_count = already_persisted + saved_run_items_count return saved_run_items_count - await session.add_items(items_to_save) + await _session_add_items(session, items_to_save, wrapper=wrapper) if run_state: run_state._current_turn_persisted_item_count = already_persisted + saved_run_items_count @@ -397,6 +437,7 @@ async def save_resumed_turn_items( response_id: str | None, reasoning_item_id_policy: ReasoningItemIdPolicy | None = None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> int: """Persist resumed turn items and return the updated persisted count.""" if session is None or not items: @@ -409,6 +450,7 @@ async def save_resumed_turn_items( response_id=response_id, reasoning_item_id_policy=reasoning_item_id_policy, store=store, + wrapper=wrapper, ) return persisted_count + saved_count diff --git a/tests/memory/test_openai_responses_compaction_session.py b/tests/memory/test_openai_responses_compaction_session.py index fe893cf88a..3c2a85380c 100644 --- a/tests/memory/test_openai_responses_compaction_session.py +++ b/tests/memory/test_openai_responses_compaction_session.py @@ -22,6 +22,7 @@ is_openai_model_name, select_compaction_candidate_items, ) +from agents.run_context import RunContextWrapper from agents.run_internal.items import ( TOOL_CALL_SESSION_DESCRIPTION_KEY, TOOL_CALL_SESSION_TITLE_KEY, @@ -510,7 +511,12 @@ def __init__(self, history: list[TResponseInputItem]) -> None: self.add_calls = 0 self.clear_calls = 0 - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.add_calls += 1 if self.add_calls == 1: await super().add_items(items[:1]) @@ -566,12 +572,22 @@ def __init__(self, history: list[TResponseInputItem]) -> None: self.add_calls = 0 self.clear_calls = 0 - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: if limit is None and self.session_settings is not None: limit = self.session_settings.limit return await super().get_items(limit) - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.add_calls += 1 if self.add_calls == 1: await super().add_items(items[:1]) @@ -624,7 +640,12 @@ def __init__(self, history: list[TResponseInputItem]) -> None: self.add_calls = 0 self.clear_calls = 0 - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.add_calls += 1 await super().add_items(items) @@ -674,7 +695,12 @@ def __init__(self, history: list[TResponseInputItem]) -> None: self.add_calls = 0 self.clear_calls = 0 - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.add_calls += 1 await super().add_items(items) @@ -725,7 +751,12 @@ def __init__(self, history: list[TResponseInputItem]) -> None: self.add_calls = 0 self.clear_calls = 0 - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.add_calls += 1 if self.add_calls == 1: await super().add_items(items[:1]) diff --git a/tests/memory/test_session.py b/tests/memory/test_session.py index f9cc324d2e..995e0a25d0 100644 --- a/tests/memory/test_session.py +++ b/tests/memory/test_session.py @@ -4,10 +4,12 @@ import sqlite3 import tempfile from pathlib import Path +from typing import Any import pytest from agents import Agent, RunConfig, Runner, SQLiteSession, TResponseInputItem +from agents.run_context import RunContextWrapper from tests.fake_model import FakeModel from tests.test_responses import get_text_message @@ -640,7 +642,11 @@ async def test_session_add_items_exception_propagates_in_streamed(): """ session = SQLiteSession("test_exception_session") - async def _failing_add_items(_items): + async def _failing_add_items( + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: raise RuntimeError("Simulated session.add_items failure") session.add_items = _failing_add_items # type: ignore[method-assign] diff --git a/tests/memory/test_session_context_wrapper.py b/tests/memory/test_session_context_wrapper.py new file mode 100644 index 0000000000..8a56013c09 --- /dev/null +++ b/tests/memory/test_session_context_wrapper.py @@ -0,0 +1,433 @@ +"""Tests for passing the run context wrapper to Session methods (issue #2072).""" + +from __future__ import annotations + +import asyncio +import inspect +from dataclasses import dataclass +from typing import Any + +import pytest + +from agents import ( + Agent, + GuardrailFunctionOutput, + InputGuardrailTripwireTriggered, + RunConfig, + Runner, + SQLiteSession, + TResponseInputItem, + input_guardrail, +) +from agents.memory.session import SessionABC, session_method_accepts_wrapper +from agents.memory.session_settings import SessionSettings +from agents.run_context import RunContextWrapper +from tests.fake_model import FakeModel +from tests.test_responses import get_text_message + + +@dataclass +class UserInfo: + """Sample user-defined run context object.""" + + user_id: str = "user-123" + + +class ContextAwareSession(SessionABC): + """Session that opts into the wrapper parameter and records what it receives.""" + + def __init__(self, session_id: str = "context-aware"): + self.session_id = session_id + self._items: list[TResponseInputItem] = [] + self.get_items_wrappers: list[RunContextWrapper[Any] | None] = [] + self.get_items_limits: list[int | None] = [] + self.add_items_wrappers: list[RunContextWrapper[Any] | None] = [] + + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: + self.get_items_wrappers.append(wrapper) + self.get_items_limits.append(limit) + if limit is not None: + return self._items[-limit:] + return list(self._items) + + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + self.add_items_wrappers.append(wrapper) + self._items.extend(items) + + async def pop_item(self) -> TResponseInputItem | None: + return self._items.pop() if self._items else None + + async def clear_session(self) -> None: + self._items.clear() + + +class LegacySession(SessionABC): + """Session with pre-wrapper signatures, as third-party implementations may still have.""" + + def __init__(self, session_id: str = "legacy"): + self.session_id = session_id + self._items: list[TResponseInputItem] = [] + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: # type: ignore[override] + if limit is not None: + return self._items[-limit:] + return list(self._items) + + async def add_items(self, items: list[TResponseInputItem]) -> None: # type: ignore[override] + self._items.extend(items) + + async def pop_item(self) -> TResponseInputItem | None: + return self._items.pop() if self._items else None + + async def clear_session(self) -> None: + self._items.clear() + + +class VarKwargsSession(LegacySession): + """Session that accepts the wrapper through ``**kwargs`` rather than a named parameter.""" + + def __init__(self, session_id: str = "var-kwargs"): + super().__init__(session_id) + self.received_kwargs: list[dict[str, Any]] = [] + + async def get_items(self, limit: int | None = None, **kwargs: Any) -> list[TResponseInputItem]: + self.received_kwargs.append(kwargs) + return await super().get_items(limit) + + async def add_items(self, items: list[TResponseInputItem], **kwargs: Any) -> None: + self.received_kwargs.append(kwargs) + await super().add_items(items) + + +def _run_sync_wrapper(agent, input_data, **kwargs): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + return Runner.run_sync(agent, input_data, **kwargs) + finally: + loop.close() + + +async def run_agent_async(runner_method: str, agent, input_data, **kwargs): + if runner_method == "run": + return await Runner.run(agent, input_data, **kwargs) + elif runner_method == "run_sync": + return await asyncio.to_thread(_run_sync_wrapper, agent, input_data, **kwargs) + elif runner_method == "run_streamed": + result = Runner.run_streamed(agent, input_data, **kwargs) + async for _ in result.stream_events(): + pass + return result + else: + raise ValueError(f"Unknown runner method: {runner_method}") + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_runner_passes_wrapper_to_context_aware_session(runner_method): + """Sessions that opt in receive the run context wrapper from every runner entrypoint.""" + session = ContextAwareSession() + context = UserInfo() + + model = FakeModel() + agent = Agent(name="test", model=model) + model.set_next_output([get_text_message("Hello")]) + + result = await run_agent_async( + runner_method, agent, "Hi there", session=session, context=context + ) + assert result.final_output == "Hello" + + assert len(session.get_items_wrappers) > 0 + assert len(session.add_items_wrappers) > 0 + for wrapper in session.get_items_wrappers + session.add_items_wrappers: + assert isinstance(wrapper, RunContextWrapper) + assert wrapper.context is context + + +@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) +@pytest.mark.asyncio +async def test_runner_keeps_legacy_session_working(runner_method): + """Sessions without the wrapper parameter keep working unchanged.""" + session = LegacySession() + + model = FakeModel() + agent = Agent(name="test", model=model) + + model.set_next_output([get_text_message("San Francisco")]) + result1 = await run_agent_async( + runner_method, agent, "What city is the Golden Gate Bridge in?", session=session + ) + assert result1.final_output == "San Francisco" + + model.set_next_output([get_text_message("California")]) + result2 = await run_agent_async(runner_method, agent, "What state is it in?", session=session) + assert result2.final_output == "California" + + # The second turn must include the persisted history from the first turn. + last_input = model.last_turn_args["input"] + assert len(last_input) > 1 + + +@pytest.mark.asyncio +async def test_runner_passes_wrapper_to_var_kwargs_session(): + """Sessions accepting **kwargs receive the wrapper through them.""" + session = VarKwargsSession() + context = UserInfo() + + model = FakeModel() + agent = Agent(name="test", model=model) + model.set_next_output([get_text_message("Hello")]) + + await Runner.run(agent, "Hi there", session=session, context=context) + + wrappers = [kwargs["wrapper"] for kwargs in session.received_kwargs if "wrapper" in kwargs] + assert len(wrappers) > 0 + for wrapper in wrappers: + assert isinstance(wrapper, RunContextWrapper) + assert wrapper.context is context + + +@pytest.mark.asyncio +async def test_runner_without_explicit_context_passes_wrapper(): + """The wrapper is passed even when the caller does not provide a context object.""" + session = ContextAwareSession() + + model = FakeModel() + agent = Agent(name="test", model=model) + model.set_next_output([get_text_message("Hello")]) + + await Runner.run(agent, "Hi there", session=session) + + assert len(session.add_items_wrappers) > 0 + for wrapper in session.add_items_wrappers: + assert isinstance(wrapper, RunContextWrapper) + + +def test_session_method_accepts_wrapper_helper(): + """The capability check recognizes opt-in signatures and rejects legacy ones.""" + context_aware = ContextAwareSession() + legacy = LegacySession() + var_kwargs = VarKwargsSession() + + assert session_method_accepts_wrapper(context_aware.get_items) is True + assert session_method_accepts_wrapper(context_aware.add_items) is True + assert session_method_accepts_wrapper(var_kwargs.get_items) is True + assert session_method_accepts_wrapper(var_kwargs.add_items) is True + assert session_method_accepts_wrapper(legacy.get_items) is False + assert session_method_accepts_wrapper(legacy.add_items) is False + # Callables without an introspectable signature must not be treated as opted in. + assert session_method_accepts_wrapper(max) is False + + +@pytest.mark.asyncio +async def test_sqlite_session_accepts_and_ignores_wrapper(): + """Built-in sessions accept the wrapper directly and behave the same with or without it.""" + session = SQLiteSession("direct-call-test") + try: + wrapper = RunContextWrapper(context=UserInfo()) + items: list[TResponseInputItem] = [{"role": "user", "content": "hello"}] + + await session.add_items(items, wrapper=wrapper) + retrieved_with = await session.get_items(wrapper=wrapper) + retrieved_without = await session.get_items() + + assert retrieved_with == retrieved_without + assert len(retrieved_with) == 1 + finally: + session.close() + + +_BUILTIN_SESSION_SPECS = [ + ("agents.memory.sqlite_session", "SQLiteSession", None), + ("agents.memory.openai_conversations_session", "OpenAIConversationsSession", None), + ( + "agents.memory.openai_responses_compaction_session", + "OpenAIResponsesCompactionSession", + None, + ), + ("agents.extensions.memory.async_sqlite_session", "AsyncSQLiteSession", None), + ("agents.extensions.memory.advanced_sqlite_session", "AdvancedSQLiteSession", None), + ("agents.extensions.memory.encrypt_session", "EncryptedSession", "cryptography"), + ("agents.extensions.memory.redis_session", "RedisSession", "redis"), + ("agents.extensions.memory.sqlalchemy_session", "SQLAlchemySession", "sqlalchemy"), + ("agents.extensions.memory.dapr_session", "DaprSession", "dapr"), + ("agents.extensions.memory.mongodb_session", "MongoDBSession", "pymongo"), +] + + +@pytest.mark.parametrize( + "module_name,class_name,required_package", + _BUILTIN_SESSION_SPECS, + ids=[spec[1] for spec in _BUILTIN_SESSION_SPECS], +) +def test_builtin_sessions_expose_keyword_only_wrapper(module_name, class_name, required_package): + """Every built-in session implementation exposes the keyword-only wrapper parameter.""" + if required_package is not None: + pytest.importorskip(required_package) + module = pytest.importorskip(module_name) + session_cls = getattr(module, class_name) + + for method_name in ("get_items", "add_items"): + signature = inspect.signature(getattr(session_cls, method_name)) + parameter = signature.parameters.get("wrapper") + assert parameter is not None, f"{class_name}.{method_name} is missing wrapper" + assert parameter.kind is inspect.Parameter.KEYWORD_ONLY + assert parameter.default is None + + +@pytest.mark.asyncio +async def test_guardrail_trip_persists_input_with_wrapper(): + """The guardrail-trip persistence path forwards the wrapper to the session.""" + session = ContextAwareSession() + context = UserInfo() + + @input_guardrail + def always_trip(ctx, agent, input) -> GuardrailFunctionOutput: + return GuardrailFunctionOutput(output_info=None, tripwire_triggered=True) + + model = FakeModel() + agent = Agent(name="test", model=model, input_guardrails=[always_trip]) + model.set_next_output([get_text_message("never returned")]) + + with pytest.raises(InputGuardrailTripwireTriggered): + await Runner.run(agent, "Hi there", session=session, context=context) + + assert len(session.add_items_wrappers) > 0 + for wrapper in session.add_items_wrappers: + assert isinstance(wrapper, RunContextWrapper) + assert wrapper.context is context + + +@pytest.mark.asyncio +async def test_session_input_callback_path_passes_wrapper(): + """The history-merge callback path still forwards the wrapper on get_items.""" + session = ContextAwareSession() + context = UserInfo() + + def keep_everything( + history: list[TResponseInputItem], new_items: list[TResponseInputItem] + ) -> list[TResponseInputItem]: + return history + new_items + + model = FakeModel() + agent = Agent(name="test", model=model) + + model.set_next_output([get_text_message("first")]) + await Runner.run(agent, "Turn one", session=session, context=context) + + model.set_next_output([get_text_message("second")]) + await Runner.run( + agent, + "Turn two", + session=session, + context=context, + run_config=RunConfig(session_input_callback=keep_everything), + ) + + assert len(session.get_items_wrappers) >= 2 + for wrapper in session.get_items_wrappers: + assert isinstance(wrapper, RunContextWrapper) + assert wrapper.context is context + + +@pytest.mark.asyncio +async def test_session_settings_limit_path_passes_wrapper(): + """The limited-history read passes both the limit and the wrapper to the session.""" + session = ContextAwareSession() + context = UserInfo() + + model = FakeModel() + agent = Agent(name="test", model=model) + + model.set_next_output([get_text_message("first")]) + await Runner.run(agent, "Turn one", session=session, context=context) + + model.set_next_output([get_text_message("second")]) + await Runner.run( + agent, + "Turn two", + session=session, + context=context, + run_config=RunConfig(session_settings=SessionSettings(limit=1)), + ) + + # The second run's history read must use the limit and still carry the wrapper. + assert session.get_items_limits[-1] == 1 + last_wrapper = session.get_items_wrappers[-1] + assert isinstance(last_wrapper, RunContextWrapper) + assert last_wrapper.context is context + + +@pytest.mark.asyncio +async def test_encrypted_session_forwards_wrapper_to_underlying_session(): + """EncryptedSession forwards the wrapper to underlying sessions that opt in.""" + pytest.importorskip("cryptography") + from agents.extensions.memory.encrypt_session import EncryptedSession + + underlying = ContextAwareSession() + session = EncryptedSession( + session_id="enc-forward", + underlying_session=underlying, + encryption_key="test-key", + ) + wrapper = RunContextWrapper(context=UserInfo()) + items: list[TResponseInputItem] = [{"role": "user", "content": "hello"}] + + await session.add_items(items, wrapper=wrapper) + await session.get_items(wrapper=wrapper) + + assert underlying.add_items_wrappers == [wrapper] + assert len(underlying.get_items_wrappers) > 0 + assert all(received is wrapper for received in underlying.get_items_wrappers) + + +@pytest.mark.asyncio +async def test_encrypted_session_does_not_break_legacy_underlying_session(): + """EncryptedSession never passes the wrapper to underlying sessions that predate it.""" + pytest.importorskip("cryptography") + from agents.extensions.memory.encrypt_session import EncryptedSession + + underlying = LegacySession() + session = EncryptedSession( + session_id="enc-legacy", + underlying_session=underlying, + encryption_key="test-key", + ) + wrapper = RunContextWrapper(context=UserInfo()) + items: list[TResponseInputItem] = [{"role": "user", "content": "hello"}] + + await session.add_items(items, wrapper=wrapper) + retrieved = await session.get_items(wrapper=wrapper) + + assert len(retrieved) == 1 + + +@pytest.mark.asyncio +async def test_compaction_session_forwards_wrapper_to_underlying_session(): + """OpenAIResponsesCompactionSession forwards the wrapper to opted-in underlying sessions.""" + from agents.memory.openai_responses_compaction_session import ( + OpenAIResponsesCompactionSession, + ) + + underlying = ContextAwareSession() + session = OpenAIResponsesCompactionSession("compaction-forward", underlying) + wrapper = RunContextWrapper(context=UserInfo()) + items: list[TResponseInputItem] = [{"role": "user", "content": "hello"}] + + await session.add_items(items, wrapper=wrapper) + retrieved = await session.get_items(wrapper=wrapper) + + assert underlying.add_items_wrappers == [wrapper] + assert underlying.get_items_wrappers == [wrapper] + assert len(retrieved) == 1 diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index c5cc123034..737a56c304 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -328,10 +328,20 @@ class DummySession(Session): session_id = "sess_123" session_settings = SessionSettings() - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: return None async def pop_item(self) -> TResponseInputItem | None: diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index eb22c70f14..874cddfd1d 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -7,7 +7,7 @@ from collections.abc import Callable from pathlib import Path from typing import Any, cast -from unittest.mock import patch +from unittest.mock import ANY, call, patch import httpx import pytest @@ -2123,12 +2123,22 @@ class DummyOpenAIConversationsSession(OpenAIConversationsSession): def __init__(self, history: list[TResponseInputItem]) -> None: self.history = history - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: if limit is None: return list(self.history) return self.history[-limit:] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.history.extend(items) async def pop_item(self) -> TResponseInputItem | None: @@ -2227,12 +2237,22 @@ class DummyOpenAIConversationsSession(OpenAIConversationsSession): def __init__(self, history: list[TResponseInputItem]) -> None: self.history = history - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: if limit is None: return list(self.history) return self.history[-limit:] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.history.extend(items) async def pop_item(self) -> TResponseInputItem | None: @@ -2280,12 +2300,22 @@ class DummyOpenAIConversationsSession(OpenAIConversationsSession): def __init__(self, history: list[TResponseInputItem]) -> None: self.history = history - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: if limit is None: return list(self.history) return self.history[-limit:] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.history.extend(items) async def pop_item(self) -> TResponseInputItem | None: @@ -2365,7 +2395,12 @@ def __init__(self) -> None: super().__init__() self.get_items_calls = 0 - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: self.get_items_calls += 1 if self.get_items_calls == 1: raise RuntimeError("temporary failure") @@ -2720,10 +2755,20 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.saved_items.extend(items) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] async def pop_item(self) -> TResponseInputItem | None: @@ -2807,10 +2852,20 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.saved_items.extend(items) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] async def pop_item(self) -> TResponseInputItem | None: @@ -2862,10 +2917,20 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.saved_items.extend(items) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] async def pop_item(self) -> TResponseInputItem | None: @@ -2908,10 +2973,20 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.saved_items.extend(items) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] async def pop_item(self) -> TResponseInputItem | None: @@ -4529,15 +4604,17 @@ async def echo_tool(text: str) -> str: }, ] + # The runner passes the run context wrapper to sessions that accept it, and the + # patched mock accepts any kwargs, so expect the wrapper keyword argument too. expected_calls = [ # First call is the initial input - (([expected_items[0]],),), + call([expected_items[0]], wrapper=ANY), # Second call is the first tool call and its result - (([expected_items[1], expected_items[2]],),), + call([expected_items[1], expected_items[2]], wrapper=ANY), # Third call is the second tool call and its result - (([expected_items[3], expected_items[4]],),), + call([expected_items[3], expected_items[4]], wrapper=ANY), # Fourth call is the final output - (([expected_items[5]],),), + call([expected_items[5]], wrapper=ANY), ] assert mock_add_items.call_args_list == expected_calls assert result.final_output == "Summary: Echoed foo and bar" diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py index 8ee3a55db4..b94b33005e 100644 --- a/tests/test_agent_runner_streamed.py +++ b/tests/test_agent_runner_streamed.py @@ -1263,7 +1263,12 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: for item in items: if isinstance(item, dict): assert "id" not in item, "IDs should be stripped before saving" @@ -1272,7 +1277,12 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: ) self.saved.append(items) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] async def pop_item(self) -> TResponseInputItem | None: @@ -1916,6 +1926,7 @@ async def save_wrapper( response_id: str | None, reasoning_item_id_policy: str | None = None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> int: observed_counts.append(persisted_count) result = await real_save_resumed( @@ -1925,6 +1936,7 @@ async def save_wrapper( response_id=response_id, reasoning_item_id_policy=reasoning_item_id_policy, store=store, + wrapper=wrapper, ) return int(result) diff --git a/tests/utils/simple_session.py b/tests/utils/simple_session.py index 94bcc97e9e..a9d3721259 100644 --- a/tests/utils/simple_session.py +++ b/tests/utils/simple_session.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import cast +from typing import Any, cast from agents.items import TResponseInputItem from agents.memory.session import Session from agents.memory.session_settings import SessionSettings +from agents.run_context import RunContextWrapper class SimpleListSession(Session): @@ -24,14 +25,24 @@ def __init__( # Mirror saved_items used by some tests for inspection. self.saved_items: list[TResponseInputItem] = self._items - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: if limit is None: return list(self._items) if limit <= 0: return [] return self._items[-limit:] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self._items.extend(items) async def pop_item(self) -> TResponseInputItem | None: @@ -70,7 +81,12 @@ def __init__( super().__init__(session_id=session_id, history=history) self._ignore_ids_for_matching = True - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: sanitized: list[TResponseInputItem] = [] for item in items: if isinstance(item, dict):