diff --git a/examples/memory/file_session.py b/examples/memory/file_session.py index e62dbd167f..abd6d81717 100644 --- a/examples/memory/file_session.py +++ b/examples/memory/file_session.py @@ -15,6 +15,7 @@ from agents.memory.session import Session from agents.memory.session_settings import SessionSettings +from agents.run_context import RunContextWrapper class FileSession(Session): @@ -43,14 +44,24 @@ async def get_session_id(self) -> str: """Return the session id, creating one if needed.""" return await self._ensure_session_id() - async def get_items(self, limit: int | None = None) -> list[Any]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[Any]: session_id = await self._ensure_session_id() items = await self._read_items(session_id) if limit is not None and limit >= 0: return items[-limit:] return items - async def add_items(self, items: list[Any]) -> None: + async def add_items( + self, + items: list[Any], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: if not items: return session_id = await self._ensure_session_id() diff --git a/src/agents/extensions/memory/advanced_sqlite_session.py b/src/agents/extensions/memory/advanced_sqlite_session.py index fcb4743cb3..0a69de0cde 100644 --- a/src/agents/extensions/memory/advanced_sqlite_session.py +++ b/src/agents/extensions/memory/advanced_sqlite_session.py @@ -15,6 +15,7 @@ from ...items import TResponseInputItem from ...memory import SQLiteSession from ...memory.session_settings import SessionSettings, resolve_session_limit +from ...run_context import RunContextWrapper class AdvancedSQLiteSession(SQLiteSession): @@ -122,14 +123,19 @@ def _init_structure_tables(self): conn.commit() - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add items to the session. 
Args: items: The items to add to the session """ # Add to base table first - await super().add_items(items) + await super().add_items(items, wrapper=wrapper) # Extract structure metadata with precise sequencing if items: @@ -138,12 +144,15 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: async def get_items( self, limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, branch_id: str | None = None, ) -> list[TResponseInputItem]: """Get items from current or specified branch. Args: limit: Maximum number of items to return. If None, uses session_settings.limit. + wrapper: Optional runtime wrapper for the current run context. branch_id: Branch to get items from. If None, uses current branch. Returns: diff --git a/src/agents/extensions/memory/async_sqlite_session.py b/src/agents/extensions/memory/async_sqlite_session.py index 2eef596264..388ad41818 100644 --- a/src/agents/extensions/memory/async_sqlite_session.py +++ b/src/agents/extensions/memory/async_sqlite_session.py @@ -5,13 +5,14 @@ from collections.abc import AsyncIterator from contextlib import asynccontextmanager from pathlib import Path -from typing import cast +from typing import Any, cast import aiosqlite from ...items import TResponseInputItem from ...memory import SessionABC from ...memory.session_settings import SessionSettings +from ...run_context import RunContextWrapper class AsyncSQLiteSession(SessionABC): @@ -102,7 +103,12 @@ async def _locked_connection(self) -> AsyncIterator[aiosqlite.Connection]: conn = await self._get_connection() yield conn - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. 
Args: @@ -150,7 +156,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: diff --git a/src/agents/extensions/memory/dapr_session.py b/src/agents/extensions/memory/dapr_session.py index ce6bf754a3..e8ea52e900 100644 --- a/src/agents/extensions/memory/dapr_session.py +++ b/src/agents/extensions/memory/dapr_session.py @@ -29,6 +29,8 @@ import time from typing import Any, Final, Literal +from ...run_context import RunContextWrapper + try: from dapr.aio.clients import DaprClient from dapr.clients.grpc._state import Concurrency, Consistency, StateOptions @@ -232,7 +234,12 @@ async def _handle_concurrency_conflict(self, error: Exception, attempt: int) -> # Session protocol implementation # ------------------------------------------------------------------ - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -271,7 +278,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: continue return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. 
Args: diff --git a/src/agents/extensions/memory/encrypt_session.py b/src/agents/extensions/memory/encrypt_session.py index d7f2e8edb9..461d65b49f 100644 --- a/src/agents/extensions/memory/encrypt_session.py +++ b/src/agents/extensions/memory/encrypt_session.py @@ -1,196 +1,275 @@ -"""Encrypted Session wrapper for secure conversation storage. - -This module provides transparent encryption for session storage with automatic -expiration of old data. When TTL expires, expired items are silently skipped. - -Usage:: - - from agents.extensions.memory import EncryptedSession, SQLAlchemySession - - # Create underlying session (e.g. SQLAlchemySession) - underlying_session = SQLAlchemySession.from_url( - session_id="user-123", - url="postgresql+asyncpg://app:secret@db.example.com/agents", - create_tables=True, - ) - - # Wrap with encryption and TTL-based expiration - session = EncryptedSession( - session_id="user-123", - underlying_session=underlying_session, - encryption_key="your-encryption-key", - ttl=600, # 10 minutes - ) - - await Runner.run(agent, "Hello", session=session) -""" - -from __future__ import annotations - -import base64 -import json -from typing import Any, cast - -from cryptography.fernet import Fernet, InvalidToken -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.kdf.hkdf import HKDF -from typing_extensions import Literal, TypedDict, TypeGuard - -from ...items import TResponseInputItem -from ...memory.session import SessionABC -from ...memory.session_settings import SessionSettings - - -class EncryptedEnvelope(TypedDict): - """TypedDict for encrypted message envelopes stored in the underlying session.""" - - __enc__: Literal[1] - v: int - kid: str - payload: str - - -def _ensure_fernet_key_bytes(master_key: str) -> bytes: - """ - Accept either a Fernet key (urlsafe-b64, 32 bytes after decode) or a raw string. - Returns raw bytes suitable for HKDF input. 
- """ - if not master_key: - raise ValueError("encryption_key not set; required for EncryptedSession.") - try: - key_bytes = base64.urlsafe_b64decode(master_key) - if len(key_bytes) == 32: - return key_bytes - except Exception: - pass - return master_key.encode("utf-8") - - -def _derive_session_fernet_key(master_key_bytes: bytes, session_id: str) -> Fernet: - hkdf = HKDF( - algorithm=hashes.SHA256(), - length=32, - salt=session_id.encode("utf-8"), - info=b"agents.session-store.hkdf.v1", - ) - derived = hkdf.derive(master_key_bytes) - return Fernet(base64.urlsafe_b64encode(derived)) - - -def _to_json_bytes(obj: Any) -> bytes: - return json.dumps(obj, ensure_ascii=False, separators=(",", ":"), default=str).encode("utf-8") - - -def _from_json_bytes(data: bytes) -> Any: - return json.loads(data.decode("utf-8")) - - -def _is_encrypted_envelope(item: object) -> TypeGuard[EncryptedEnvelope]: - """Type guard to check if an item is an encrypted envelope.""" - return ( - isinstance(item, dict) - and item.get("__enc__") == 1 - and "payload" in item - and "kid" in item - and "v" in item - ) - - -class EncryptedSession(SessionABC): - """Encrypted wrapper for Session implementations with TTL-based expiration. - - This class wraps any SessionABC implementation to provide transparent - encryption/decryption of stored items using Fernet encryption with - per-session key derivation and automatic expiration of old data. - - When items expire (exceed TTL), they are silently skipped during retrieval. - - Note: Expired tokens are rejected based on the system clock of the application server. - To avoid valid tokens being rejected due to clock drift, ensure all servers in - your environment are synchronized using NTP. - """ - - def __init__( - self, - session_id: str, - underlying_session: SessionABC, - encryption_key: str, - ttl: int = 600, - ): - """ - Args: - session_id: ID for this session - underlying_session: The real session store (e.g. 
SQLiteSession, SQLAlchemySession) - encryption_key: Master key (Fernet key or raw secret) - ttl: Token time-to-live in seconds (default 10 min) - """ - self.session_id = session_id - self.underlying_session = underlying_session - self.ttl = ttl - - master = _ensure_fernet_key_bytes(encryption_key) - self.cipher = _derive_session_fernet_key(master, session_id) - self._kid = "hkdf-v1" - self._ver = 1 - - def __getattr__(self, name): - return getattr(self.underlying_session, name) - - @property - def session_settings(self) -> SessionSettings | None: - """Get session settings from the underlying session.""" - return self.underlying_session.session_settings - - @session_settings.setter - def session_settings(self, value: SessionSettings | None) -> None: - """Set session settings on the underlying session.""" - self.underlying_session.session_settings = value - - def _wrap(self, item: TResponseInputItem) -> EncryptedEnvelope: - if isinstance(item, dict): - payload = item - elif hasattr(item, "model_dump"): - payload = item.model_dump() - elif hasattr(item, "__dict__"): - payload = item.__dict__ - else: - payload = dict(item) - - token = self.cipher.encrypt(_to_json_bytes(payload)).decode("utf-8") - return {"__enc__": 1, "v": self._ver, "kid": self._kid, "payload": token} - - def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInputItem | None: - if not _is_encrypted_envelope(item): - return cast(TResponseInputItem, item) - - try: - token = item["payload"].encode("utf-8") - plaintext = self.cipher.decrypt(token, ttl=self.ttl) - return cast(TResponseInputItem, _from_json_bytes(plaintext)) - except (InvalidToken, KeyError): - return None - - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: - encrypted_items = await self.underlying_session.get_items(limit) - valid_items: list[TResponseInputItem] = [] - for enc in encrypted_items: - item = self._unwrap(enc) - if item is not None: - valid_items.append(item) - return 
valid_items - - async def add_items(self, items: list[TResponseInputItem]) -> None: - wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items] - await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped)) - - async def pop_item(self) -> TResponseInputItem | None: - while True: - enc = await self.underlying_session.pop_item() - if not enc: - return None - item = self._unwrap(enc) - if item is not None: - return item - - async def clear_session(self) -> None: - await self.underlying_session.clear_session() +"""Encrypted Session wrapper for secure conversation storage. + +This module provides transparent encryption for session storage with automatic +expiration of old data. When TTL expires, expired items are silently skipped. + +Usage:: + + from agents.extensions.memory import EncryptedSession, SQLAlchemySession + + # Create underlying session (e.g. SQLAlchemySession) + underlying_session = SQLAlchemySession.from_url( + session_id="user-123", + url="postgresql+asyncpg://app:secret@db.example.com/agents", + create_tables=True, + ) + + # Wrap with encryption and TTL-based expiration + session = EncryptedSession( + session_id="user-123", + underlying_session=underlying_session, + encryption_key="your-encryption-key", + ttl=600, # 10 minutes + ) + + await Runner.run(agent, "Hello", session=session) +""" + +from __future__ import annotations + +import base64 +import inspect +import json +from typing import Any, cast + +from cryptography.fernet import Fernet, InvalidToken +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.kdf.hkdf import HKDF +from typing_extensions import Literal, TypedDict, TypeGuard + +from ...items import TResponseInputItem +from ...memory.session import SessionABC +from ...memory.session_settings import SessionSettings +from ...run_context import RunContextWrapper + + +class EncryptedEnvelope(TypedDict): + """TypedDict for encrypted message envelopes stored in the underlying 
session.""" + + __enc__: Literal[1] + v: int + kid: str + payload: str + + +def _ensure_fernet_key_bytes(master_key: str) -> bytes: + """ + Accept either a Fernet key (urlsafe-b64, 32 bytes after decode) or a raw string. + Returns raw bytes suitable for HKDF input. + """ + if not master_key: + raise ValueError("encryption_key not set; required for EncryptedSession.") + try: + key_bytes = base64.urlsafe_b64decode(master_key) + if len(key_bytes) == 32: + return key_bytes + except Exception: + pass + return master_key.encode("utf-8") + + +def _derive_session_fernet_key(master_key_bytes: bytes, session_id: str) -> Fernet: + hkdf = HKDF( + algorithm=hashes.SHA256(), + length=32, + salt=session_id.encode("utf-8"), + info=b"agents.session-store.hkdf.v1", + ) + derived = hkdf.derive(master_key_bytes) + return Fernet(base64.urlsafe_b64encode(derived)) + + +def _to_json_bytes(obj: Any) -> bytes: + return json.dumps(obj, ensure_ascii=False, separators=(",", ":"), default=str).encode("utf-8") + + +def _from_json_bytes(data: bytes) -> Any: + return json.loads(data.decode("utf-8")) + + +def _is_encrypted_envelope(item: object) -> TypeGuard[EncryptedEnvelope]: + """Type guard to check if an item is an encrypted envelope.""" + return ( + isinstance(item, dict) + and item.get("__enc__") == 1 + and "payload" in item + and "kid" in item + and "v" in item + ) + + +def _method_accepts_wrapper(method: Any) -> bool: + try: + parameters = tuple(inspect.signature(method).parameters.values()) + except (TypeError, ValueError): + return False + + return any( + parameter.kind is inspect.Parameter.VAR_KEYWORD or parameter.name == "wrapper" + for parameter in parameters + ) + + +def _method_accepts_limit(method: Any) -> bool: + try: + parameters = tuple(inspect.signature(method).parameters.values()) + except (TypeError, ValueError): + return False + + return any( + ( + parameter.kind + in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + and parameter.name == "limit" + 
) + or parameter.kind is inspect.Parameter.VAR_KEYWORD + for parameter in parameters + ) + + +async def _delegate_get_items( + session: SessionABC, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, +) -> list[TResponseInputItem]: + accepts_wrapper = wrapper is not None and _method_accepts_wrapper(session.get_items) + accepts_limit = _method_accepts_limit(session.get_items) + + if limit is None: + if accepts_wrapper: + return await session.get_items(wrapper=wrapper) + return await session.get_items() + + if accepts_limit: + if accepts_wrapper: + return await session.get_items(limit=limit, wrapper=wrapper) + return await session.get_items(limit=limit) + + if accepts_wrapper: + items = await session.get_items(wrapper=wrapper) + else: + items = await session.get_items() + + return items[-limit:] if limit > 0 else [] + + +class EncryptedSession(SessionABC): + """Encrypted wrapper for Session implementations with TTL-based expiration. + + This class wraps any SessionABC implementation to provide transparent + encryption/decryption of stored items using Fernet encryption with + per-session key derivation and automatic expiration of old data. + + When items expire (exceed TTL), they are silently skipped during retrieval. + + Note: Expired tokens are rejected based on the system clock of the application server. + To avoid valid tokens being rejected due to clock drift, ensure all servers in + your environment are synchronized using NTP. + """ + + def __init__( + self, + session_id: str, + underlying_session: SessionABC, + encryption_key: str, + ttl: int = 600, + ): + """ + Args: + session_id: ID for this session + underlying_session: The real session store (e.g. 
SQLiteSession, SQLAlchemySession) + encryption_key: Master key (Fernet key or raw secret) + ttl: Token time-to-live in seconds (default 10 min) + """ + self.session_id = session_id + self.underlying_session = underlying_session + self.ttl = ttl + + master = _ensure_fernet_key_bytes(encryption_key) + self.cipher = _derive_session_fernet_key(master, session_id) + self._kid = "hkdf-v1" + self._ver = 1 + + def __getattr__(self, name): + return getattr(self.underlying_session, name) + + @property + def session_settings(self) -> SessionSettings | None: + """Get session settings from the underlying session.""" + return self.underlying_session.session_settings + + @session_settings.setter + def session_settings(self, value: SessionSettings | None) -> None: + """Set session settings on the underlying session.""" + self.underlying_session.session_settings = value + + def _wrap(self, item: TResponseInputItem) -> EncryptedEnvelope: + if isinstance(item, dict): + payload = item + elif hasattr(item, "model_dump"): + payload = item.model_dump() + elif hasattr(item, "__dict__"): + payload = item.__dict__ + else: + payload = dict(item) + + token = self.cipher.encrypt(_to_json_bytes(payload)).decode("utf-8") + return {"__enc__": 1, "v": self._ver, "kid": self._kid, "payload": token} + + def _unwrap(self, item: TResponseInputItem | EncryptedEnvelope) -> TResponseInputItem | None: + if not _is_encrypted_envelope(item): + return cast(TResponseInputItem, item) + + try: + token = item["payload"].encode("utf-8") + plaintext = self.cipher.decrypt(token, ttl=self.ttl) + return cast(TResponseInputItem, _from_json_bytes(plaintext)) + except (InvalidToken, KeyError): + return None + + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: + encrypted_items = await _delegate_get_items( + self.underlying_session, + limit=limit, + wrapper=wrapper, + ) + valid_items: list[TResponseInputItem] = [] + for enc in 
encrypted_items: + item = self._unwrap(enc) + if item is not None: + valid_items.append(item) + return valid_items + + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items] + wrapped_items = cast(list[TResponseInputItem], wrapped) + if wrapper is not None and _method_accepts_wrapper(self.underlying_session.add_items): + await self.underlying_session.add_items( + wrapped_items, + wrapper=wrapper, + ) + else: + await self.underlying_session.add_items(wrapped_items) + + async def pop_item(self) -> TResponseInputItem | None: + while True: + enc = await self.underlying_session.pop_item() + if not enc: + return None + item = self._unwrap(enc) + if item is not None: + return item + + async def clear_session(self) -> None: + await self.underlying_session.clear_session() diff --git a/src/agents/extensions/memory/redis_session.py b/src/agents/extensions/memory/redis_session.py index 1eee549e11..22519afde5 100644 --- a/src/agents/extensions/memory/redis_session.py +++ b/src/agents/extensions/memory/redis_session.py @@ -26,6 +26,8 @@ import time from typing import Any +from ...run_context import RunContextWrapper + try: import redis.asyncio as redis from redis.asyncio import Redis @@ -140,7 +142,12 @@ async def _set_ttl_if_configured(self, *keys: str) -> None: # Session protocol implementation # ------------------------------------------------------------------ - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. 
Args: @@ -179,7 +186,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: diff --git a/src/agents/extensions/memory/sqlalchemy_session.py b/src/agents/extensions/memory/sqlalchemy_session.py index 8bfaa95769..a00de58632 100644 --- a/src/agents/extensions/memory/sqlalchemy_session.py +++ b/src/agents/extensions/memory/sqlalchemy_session.py @@ -48,6 +48,7 @@ from ...items import TResponseInputItem from ...memory.session import SessionABC from ...memory.session_settings import SessionSettings, resolve_session_limit +from ...run_context import RunContextWrapper class SQLAlchemySession(SessionABC): @@ -187,7 +188,12 @@ async def _ensure_tables(self) -> None: await conn.run_sync(self._metadata.create_all) self._create_tables = False # Only create once - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -239,7 +245,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: continue return items - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. 
Args: diff --git a/src/agents/memory/openai_conversations_session.py b/src/agents/memory/openai_conversations_session.py index 4d4fbaf635..e05caf88de 100644 --- a/src/agents/memory/openai_conversations_session.py +++ b/src/agents/memory/openai_conversations_session.py @@ -1,10 +1,13 @@ from __future__ import annotations +from typing import Any + from openai import AsyncOpenAI from agents.models._openai_shared import get_default_openai_client from ..items import TResponseInputItem +from ..run_context import RunContextWrapper from .session import SessionABC from .session_settings import SessionSettings, resolve_session_limit @@ -70,7 +73,12 @@ async def _get_session_id(self) -> str: async def _clear_session_id(self) -> None: self._session_id = None - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: session_id = await self._get_session_id() session_limit = resolve_session_limit(limit, self.session_settings) @@ -97,7 +105,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: return all_items # type: ignore - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: session_id = await self._get_session_id() if not items: return diff --git a/src/agents/memory/openai_responses_compaction_session.py b/src/agents/memory/openai_responses_compaction_session.py index e2148f4868..305eddd067 100644 --- a/src/agents/memory/openai_responses_compaction_session.py +++ b/src/agents/memory/openai_responses_compaction_session.py @@ -1,11 +1,13 @@ from __future__ import annotations +import inspect import logging from typing import TYPE_CHECKING, Any, Callable, Literal from openai import AsyncOpenAI from ..models._openai_shared import 
get_default_openai_client +from ..run_context import RunContextWrapper from .openai_conversations_session import OpenAIConversationsSession from .session import ( OpenAIResponsesCompactionArgs, @@ -24,6 +26,61 @@ OpenAIResponsesCompactionMode = Literal["previous_response_id", "input", "auto"] +def _method_signature(method: Any) -> tuple[inspect.Parameter, ...]: + try: + return tuple(inspect.signature(method).parameters.values()) + except (TypeError, ValueError): + return () + + +def _method_accepts_wrapper(method: Any) -> bool: + parameters = _method_signature(method) + return any( + parameter.kind is inspect.Parameter.VAR_KEYWORD or parameter.name == "wrapper" + for parameter in parameters + ) + + +def _method_accepts_limit(method: Any) -> bool: + parameters = _method_signature(method) + return any( + ( + parameter.kind + in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + and parameter.name == "limit" + ) + or parameter.kind is inspect.Parameter.VAR_KEYWORD + for parameter in parameters + ) + + +async def _delegate_get_items( + session: Session, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, +) -> list[TResponseInputItem]: + accepts_wrapper = wrapper is not None and _method_accepts_wrapper(session.get_items) + accepts_limit = _method_accepts_limit(session.get_items) + + if limit is None: + if accepts_wrapper: + return await session.get_items(wrapper=wrapper) + return await session.get_items() + + if accepts_limit: + if accepts_wrapper: + return await session.get_items(limit=limit, wrapper=wrapper) + return await session.get_items(limit=limit) + + if accepts_wrapper: + items = await session.get_items(wrapper=wrapper) + else: + items = await session.get_items() + + return items[-limit:] if limit > 0 else [] + + def select_compaction_candidate_items( items: list[TResponseInputItem], ) -> list[TResponseInputItem]: @@ -236,8 +293,17 @@ async def run_compaction(self, args: OpenAIResponsesCompactionArgs | 
None = None f"candidates={len(self._compaction_candidate_items)})" ) - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: - return await self.underlying_session.get_items(limit) + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: + return await _delegate_get_items( + self.underlying_session, + limit=limit, + wrapper=wrapper, + ) async def _defer_compaction(self, response_id: str, store: bool | None = None) -> None: if self._deferred_response_id is not None: @@ -265,8 +331,16 @@ def _get_deferred_compaction_response_id(self) -> str | None: def _clear_deferred_compaction(self) -> None: self._deferred_response_id = None - async def add_items(self, items: list[TResponseInputItem]) -> None: - await self.underlying_session.add_items(items) + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + if wrapper is not None and _method_accepts_wrapper(self.underlying_session.add_items): + await self.underlying_session.add_items(items, wrapper=wrapper) + else: + await self.underlying_session.add_items(items) if self._compaction_candidate_items is not None: new_candidates = select_compaction_candidate_items(items) if new_candidates: diff --git a/src/agents/memory/session.py b/src/agents/memory/session.py index 85a65a1690..d7ccd8d621 100644 --- a/src/agents/memory/session.py +++ b/src/agents/memory/session.py @@ -1,12 +1,13 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Literal, Protocol, runtime_checkable +from typing import TYPE_CHECKING, Any, Literal, Protocol, runtime_checkable from typing_extensions import TypedDict, TypeGuard if TYPE_CHECKING: from ..items import TResponseInputItem + from ..run_context import RunContextWrapper from .session_settings import SessionSettings @@ -21,23 +22,37 @@ class 
Session(Protocol): session_id: str session_settings: SessionSettings | None = None - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: limit: Maximum number of items to retrieve. If None, retrieves all items. When specified, returns the latest N items in chronological order. + wrapper: Optional runtime wrapper for the current run context. Implementations may + ignore this parameter. Returns: List of input items representing the conversation history """ ... - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: items: List of input items to add to the history + wrapper: Optional runtime wrapper for the current run context. Implementations may + ignore this parameter. """ ... @@ -68,12 +83,19 @@ class SessionABC(ABC): session_settings: SessionSettings | None = None @abstractmethod - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: limit: Maximum number of items to retrieve. If None, retrieves all items. When specified, returns the latest N items in chronological order. + wrapper: Optional runtime wrapper for the current run context. Implementations may + ignore this parameter. Returns: List of input items representing the conversation history @@ -81,11 +103,18 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: ... 
@abstractmethod - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. Args: items: List of input items to add to the history + wrapper: Optional runtime wrapper for the current run context. Implementations may + ignore this parameter. """ ... diff --git a/src/agents/memory/sqlite_session.py b/src/agents/memory/sqlite_session.py index 92c9630c9b..5877d7fa71 100644 --- a/src/agents/memory/sqlite_session.py +++ b/src/agents/memory/sqlite_session.py @@ -5,8 +5,10 @@ import sqlite3 import threading from pathlib import Path +from typing import Any from ..items import TResponseInputItem +from ..run_context import RunContextWrapper from .session import SessionABC from .session_settings import SessionSettings, resolve_session_limit @@ -114,7 +116,12 @@ def _init_db_for_connection(self, conn: sqlite3.Connection) -> None: conn.commit() - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: """Retrieve the conversation history for this session. Args: @@ -170,7 +177,12 @@ def _get_items_sync(): return await asyncio.to_thread(_get_items_sync) - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: """Add new items to the conversation history. 
Args: diff --git a/src/agents/run.py b/src/agents/run.py index 047d454d35..e05b325837 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -454,6 +454,8 @@ async def run( max_turns = run_state._max_turns else: + context_wrapper = ensure_context_wrapper(context) + context = context_wrapper.context raw_input = cast(Union[str, list[TResponseInputItem]], input) original_user_input = raw_input @@ -478,6 +480,7 @@ async def run( run_config.session_settings, include_history_in_prepared_input=False, preserve_dropped_new_items=True, + wrapper=context_wrapper, ) original_input_for_state = raw_input session_input_items_for_persistence = [] @@ -490,6 +493,7 @@ async def run( session, run_config.session_input_callback, run_config.session_settings, + wrapper=context_wrapper, ) original_input_for_state = prepared_input @@ -569,7 +573,6 @@ async def run( generated_items = [] session_items = [] model_responses = [] - context_wrapper = ensure_context_wrapper(context) set_agent_tool_state_scope(context_wrapper, None) run_state = RunState( context=context_wrapper, @@ -631,6 +634,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: [], run_state, store=store_setting, + wrapper=context_wrapper, ) session_input_items_for_persistence = [] @@ -700,6 +704,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: run_state._reasoning_item_id_policy ), store=store_setting, + wrapper=context_wrapper, ) ) @@ -818,6 +823,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: run_state, response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) result._original_input = copy_input_items(original_input) return finalize_conversation_tracking( @@ -955,6 +961,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: run_state, response_id=None, store=store_setting, + wrapper=context_wrapper, ) result._original_input = copy_input_items(original_input) return 
finalize_conversation_tracking( @@ -1159,6 +1166,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: run_state._reasoning_item_id_policy ), store=store_setting, + wrapper=context_wrapper, ) run_state._current_turn_persisted_item_count += saved_count else: @@ -1169,6 +1177,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: run_state, response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) # After the first resumed turn, treat subsequent turns as fresh @@ -1223,6 +1232,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: items=session_items_for_turn(turn_result), response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) result._original_input = copy_input_items(original_input) return finalize_conversation_tracking( @@ -1246,6 +1256,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: run_state, response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) append_model_response_if_new( model_responses, turn_result.model_response @@ -1311,6 +1322,7 @@ def _with_reasoning_item_id_policy(result: RunResult) -> RunResult: items=session_items_for_turn(turn_result), response_id=turn_result.model_response.response_id, store=store_setting, + wrapper=context_wrapper, ) continue else: diff --git a/src/agents/run_internal/agent_runner_helpers.py b/src/agents/run_internal/agent_runner_helpers.py index 776e406703..e6a133271e 100644 --- a/src/agents/run_internal/agent_runner_helpers.py +++ b/src/agents/run_internal/agent_runner_helpers.py @@ -322,6 +322,7 @@ async def save_turn_items_if_needed( items: list[RunItem], response_id: str | None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> None: """Persist turn items when persistence is enabled and guardrails allow it.""" if not session_persistence_enabled: @@ -337,6 +338,7 @@ async def 
save_turn_items_if_needed( run_state, response_id=response_id, store=store, + wrapper=wrapper, ) diff --git a/src/agents/run_internal/run_loop.py b/src/agents/run_internal/run_loop.py index 3d21d89fda..a61acc30bc 100644 --- a/src/agents/run_internal/run_loop.py +++ b/src/agents/run_internal/run_loop.py @@ -7,6 +7,7 @@ import asyncio import dataclasses as _dc +import inspect import json from collections.abc import Awaitable, Callable, Mapping from typing import Any, TypeVar, cast @@ -264,6 +265,18 @@ def _complete_stream_interruption( streamed_result._event_queue.put_nowait(QueueCompleteSentinel()) +def _call_supports_wrapper(callable_obj: Any) -> bool: + try: + parameters = tuple(inspect.signature(callable_obj).parameters.values()) + except (TypeError, ValueError): + return False + + return any( + parameter.kind is inspect.Parameter.VAR_KEYWORD or parameter.name == "wrapper" + for parameter in parameters + ) + + async def _save_resumed_stream_items( *, session: Session | None, @@ -273,6 +286,7 @@ async def _save_resumed_stream_items( items: list[RunItem], response_id: str | None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> None: if not await _should_persist_stream_items( session=session, @@ -280,13 +294,18 @@ async def _save_resumed_stream_items( streamed_result=streamed_result, ): return + save_kwargs: dict[str, Any] = { + "session": session, + "items": items, + "persisted_count": streamed_result._current_turn_persisted_item_count, + "response_id": response_id, + "reasoning_item_id_policy": streamed_result._reasoning_item_id_policy, + "store": store, + } + if wrapper is not None and _call_supports_wrapper(save_resumed_turn_items): + save_kwargs["wrapper"] = wrapper streamed_result._current_turn_persisted_item_count = await save_resumed_turn_items( - session=session, - items=items, - persisted_count=streamed_result._current_turn_persisted_item_count, - response_id=response_id, - 
reasoning_item_id_policy=streamed_result._reasoning_item_id_policy, - store=store, + **save_kwargs, ) if run_state is not None: run_state._current_turn_persisted_item_count = ( @@ -304,6 +323,7 @@ async def _save_stream_items( response_id: str | None, update_persisted_count: bool, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> None: if not await _should_persist_stream_items( session=session, @@ -311,13 +331,18 @@ async def _save_stream_items( streamed_result=streamed_result, ): return + save_kwargs: dict[str, Any] = { + "response_id": response_id, + "store": store, + } + if wrapper is not None and _call_supports_wrapper(save_result_to_session): + save_kwargs["wrapper"] = wrapper await save_result_to_session( session, [], list(items), run_state, - response_id=response_id, - store=store, + **save_kwargs, ) if update_persisted_count and streamed_result._state is not None: streamed_result._current_turn_persisted_item_count = ( @@ -545,6 +570,7 @@ def _sync_conversation_tracking_from_tracker() -> None: run_config.session_settings, include_history_in_prepared_input=not server_manages_conversation, preserve_dropped_new_items=True, + wrapper=context_wrapper, ) streamed_result.input = prepared_input streamed_result._original_input = copy_input_items(prepared_input) @@ -566,6 +592,7 @@ async def _save_resumed_items( items=items, response_id=response_id, store=store_setting, + wrapper=context_wrapper, ) async def _save_stream_items_with_count( @@ -580,6 +607,7 @@ async def _save_stream_items_with_count( response_id=response_id, update_persisted_count=True, store=store_setting, + wrapper=context_wrapper, ) async def _save_stream_items_without_count( @@ -594,6 +622,7 @@ async def _save_stream_items_without_count( response_id=response_id, update_persisted_count=False, store=store_setting, + wrapper=context_wrapper, ) try: @@ -1237,7 +1266,13 @@ def _tool_search_fingerprint(raw_item: Any) -> str: ) ] if input_items_to_save: - await 
save_result_to_session(session, input_items_to_save, [], streamed_result._state) + await save_result_to_session( + session, + input_items_to_save, + [], + streamed_result._state, + wrapper=context_wrapper, + ) previous_response_id = ( server_conversation_tracker.previous_response_id diff --git a/src/agents/run_internal/session_persistence.py b/src/agents/run_internal/session_persistence.py index d63c5f0526..a43995155b 100644 --- a/src/agents/run_internal/session_persistence.py +++ b/src/agents/run_internal/session_persistence.py @@ -23,6 +23,7 @@ is_openai_responses_compaction_aware_session, ) from ..memory.openai_conversations_session import OpenAIConversationsSession +from ..run_context import RunContextWrapper from ..run_state import RunState from .items import ( ReasoningItemIdPolicy, @@ -50,6 +51,70 @@ ] +def _session_method_signature(method: Any) -> tuple[inspect.Parameter, ...]: + try: + return tuple(inspect.signature(method).parameters.values()) + except (TypeError, ValueError): + return () + + +def _session_method_accepts_wrapper(method: Any) -> bool: + parameters = _session_method_signature(method) + return any( + parameter.kind is inspect.Parameter.VAR_KEYWORD or parameter.name == "wrapper" + for parameter in parameters + ) + + +def _session_method_accepts_limit(method: Any) -> bool: + parameters = _session_method_signature(method) + return any( + parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY) + and parameter.name == "limit" + or parameter.kind is inspect.Parameter.VAR_KEYWORD + for parameter in parameters + ) + + +async def _session_get_items( + session: Session, + limit: int | None = None, + *, + wrapper: RunContextWrapper[Any] | None = None, +) -> list[TResponseInputItem]: + accepts_wrapper = wrapper is not None and _session_method_accepts_wrapper(session.get_items) + accepts_limit = _session_method_accepts_limit(session.get_items) + + if limit is None: + if accepts_wrapper: + return await 
session.get_items(wrapper=wrapper) + return await session.get_items() + + if accepts_limit: + if accepts_wrapper: + return await session.get_items(limit=limit, wrapper=wrapper) + return await session.get_items(limit=limit) + + if accepts_wrapper: + items = await session.get_items(wrapper=wrapper) + else: + items = await session.get_items() + + return items[-limit:] if limit > 0 else [] + + +async def _session_add_items( + session: Session, + items: list[TResponseInputItem], + *, + wrapper: RunContextWrapper[Any] | None = None, +) -> None: + if wrapper is not None and _session_method_accepts_wrapper(session.add_items): + await session.add_items(items, wrapper=wrapper) + return + await session.add_items(items) + + async def prepare_input_with_session( input: str | list[TResponseInputItem], session: Session | None, @@ -58,6 +123,7 @@ async def prepare_input_with_session( *, include_history_in_prepared_input: bool = True, preserve_dropped_new_items: bool = False, + wrapper: RunContextWrapper[Any] | None = None, ) -> tuple[str | list[TResponseInputItem], list[TResponseInputItem]]: """Prepare model input from session history plus the new turn input. 
@@ -82,9 +148,13 @@ async def prepare_input_with_session( resolved_settings = resolved_settings.resolve(session_settings) if resolved_settings.limit is not None: - history = await session.get_items(limit=resolved_settings.limit) + history = await _session_get_items( + session, + limit=resolved_settings.limit, + wrapper=wrapper, + ) else: - history = await session.get_items() + history = await _session_get_items(session, wrapper=wrapper) converted_history = [ensure_input_item_format(item) for item in history] new_input_list = [ @@ -233,6 +303,7 @@ async def save_result_to_session( response_id: str | None = None, reasoning_item_id_policy: ReasoningItemIdPolicy | None = None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> int: """ Persist a turn to the session store, keeping track of what was already saved so retries @@ -316,7 +387,7 @@ async def save_result_to_session( run_state._current_turn_persisted_item_count = already_persisted + saved_run_items_count return saved_run_items_count - await session.add_items(items_to_save) + await _session_add_items(session, items_to_save, wrapper=wrapper) if run_state: run_state._current_turn_persisted_item_count = already_persisted + saved_run_items_count @@ -367,6 +438,7 @@ async def save_resumed_turn_items( response_id: str | None, reasoning_item_id_policy: ReasoningItemIdPolicy | None = None, store: bool | None = None, + wrapper: RunContextWrapper[Any] | None = None, ) -> int: """Persist resumed turn items and return the updated persisted count.""" if session is None or not items: @@ -379,6 +451,7 @@ async def save_resumed_turn_items( response_id=response_id, reasoning_item_id_policy=reasoning_item_id_policy, store=store, + wrapper=wrapper, ) return persisted_count + saved_count @@ -475,7 +548,7 @@ async def rewind_session_items( return try: - latest_items = await session.get_items(limit=1) + latest_items = await _session_get_items(session, limit=1) except Exception as exc: 
logger.debug("Failed to peek session items while rewinding: %s", exc) return @@ -523,7 +596,7 @@ async def wait_for_session_cleanup( for attempt in range(max_attempts): try: - tail_items = await session.get_items(limit=window) + tail_items = await _session_get_items(session, limit=window) except Exception as exc: logger.debug("Failed to verify session cleanup (attempt %d): %s", attempt + 1, exc) await asyncio.sleep(0.1 * (attempt + 1)) diff --git a/tests/extensions/memory/test_advanced_sqlite_wrapper_compat.py b/tests/extensions/memory/test_advanced_sqlite_wrapper_compat.py new file mode 100644 index 0000000000..1aad6ebfb7 --- /dev/null +++ b/tests/extensions/memory/test_advanced_sqlite_wrapper_compat.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import tempfile +from pathlib import Path + +import pytest + +from agents.extensions.memory.advanced_sqlite_session import AdvancedSQLiteSession + +pytestmark = pytest.mark.asyncio + + +async def test_advanced_sqlite_get_items_branch_id_kwarg() -> None: + with tempfile.TemporaryDirectory() as temp_dir: + db_path = Path(temp_dir) / "advanced.db" + session = AdvancedSQLiteSession(session_id="test", db_path=db_path, create_tables=True) + + await session.add_items( + [ + {"role": "user", "content": "main message"}, + ] + ) + branch_id = await session.create_branch_from_turn(1, "branch-a") + assert branch_id == "branch-a" + await session.add_items( + [ + {"role": "user", "content": "branch message"}, + ] + ) + await session.switch_to_branch("main") + + branch_items = await session.get_items(50, branch_id="branch-a") + contents = [item.get("content") for item in branch_items if isinstance(item, dict)] + + assert "branch message" in contents + assert "main message" not in contents diff --git a/tests/extensions/memory/test_encrypt_session.py b/tests/extensions/memory/test_encrypt_session.py index ac2a27da6b..7f3124638d 100644 --- a/tests/extensions/memory/test_encrypt_session.py +++ 
b/tests/extensions/memory/test_encrypt_session.py @@ -7,10 +7,14 @@ pytest.importorskip("cryptography") # Skip tests if cryptography is not installed +from typing import cast + from cryptography.fernet import Fernet from agents import Agent, Runner, SQLiteSession, TResponseInputItem from agents.extensions.memory.encrypt_session import EncryptedSession +from agents.memory.session import SessionABC +from agents.run_context import RunContextWrapper from tests.fake_model import FakeModel from tests.test_responses import get_text_message @@ -111,6 +115,86 @@ async def test_encrypted_session_with_runner( underlying_session.close() +async def test_encrypted_session_preserves_wrapper_only_underlying_with_limit( + encryption_key: str, +): + class WrapperOnlyUnderlyingSession: + def __init__(self) -> None: + self.session_id = "test_session" + self.items: list[TResponseInputItem] = [] + self.get_wrappers: list[object | None] = [] + + async def get_items(self, wrapper: object = None) -> list[TResponseInputItem]: + self.get_wrappers.append(wrapper) + return list(self.items) + + async def add_items(self, items: list[TResponseInputItem]) -> None: + self.items.extend(items) + + async def pop_item(self) -> TResponseInputItem | None: + return None + + async def clear_session(self) -> None: + self.items.clear() + + underlying = WrapperOnlyUnderlyingSession() + underlying.items = [ + {"role": "user", "content": "one"}, + {"role": "assistant", "content": "two"}, + ] + session = EncryptedSession( + session_id="test_session", + underlying_session=cast(SessionABC, underlying), + encryption_key=encryption_key, + ) + + wrapper = RunContextWrapper(context={"request_id": "encrypt"}) + items = await session.get_items(limit=1, wrapper=wrapper) + + assert items[-1].get("content") == "two" + assert underlying.get_wrappers == [wrapper] + + +async def test_encrypted_session_preserves_legacy_underlying_signatures( + agent: Agent, + encryption_key: str, +): + class LegacyUnderlyingSession: + def 
__init__(self) -> None: + self.session_id = "test_session" + self.items: list[TResponseInputItem] = [] + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + if limit is None: + return list(self.items) + return list(self.items[-limit:]) if limit > 0 else [] + + async def add_items(self, items: list[TResponseInputItem]) -> None: + self.items.extend(items) + + async def pop_item(self) -> TResponseInputItem | None: + if not self.items: + return None + return self.items.pop() + + async def clear_session(self) -> None: + self.items.clear() + + legacy_underlying = LegacyUnderlyingSession() + session = EncryptedSession( + session_id="test_session", + underlying_session=cast(SessionABC, legacy_underlying), + encryption_key=encryption_key, + ) + + assert isinstance(agent.model, FakeModel) + agent.model.set_next_output([get_text_message("Hello")]) + result = await Runner.run(agent, "Hi there", session=session) + + assert result.final_output == "Hello" + assert legacy_underlying.items + + async def test_encrypted_session_pop_item(encryption_key: str, underlying_session: SQLiteSession): """Test pop_item functionality.""" session = EncryptedSession( diff --git a/tests/memory/test_openai_responses_compaction_session.py b/tests/memory/test_openai_responses_compaction_session.py index 7af406a602..6ac01e72ce 100644 --- a/tests/memory/test_openai_responses_compaction_session.py +++ b/tests/memory/test_openai_responses_compaction_session.py @@ -19,6 +19,7 @@ is_openai_model_name, select_compaction_candidate_items, ) +from agents.run_context import RunContextWrapper from tests.fake_model import FakeModel from tests.test_responses import get_function_tool, get_function_tool_call, get_text_message from tests.utils.simple_session import SimpleListSession @@ -107,6 +108,83 @@ def test_init_accepts_valid_model(self) -> None: ) assert session.model == "gpt-4.1" + @pytest.mark.asyncio + async def 
test_get_items_preserves_legacy_wrapper_only_delegate_shape(self) -> None: + class WrapperOnlySession: + session_id = "test-session" + + def __init__(self) -> None: + self.calls: list[tuple[int | None, Any]] = [] + + async def get_items( + self, + wrapper: Any = None, + ) -> list[TResponseInputItem]: + self.calls.append((None, wrapper)) + return [] + + async def add_items(self, items: list[TResponseInputItem]) -> None: + return None + + async def pop_item(self) -> TResponseInputItem | None: + return None + + async def clear_session(self) -> None: + return None + + underlying = WrapperOnlySession() + session = OpenAIResponsesCompactionSession( + session_id="test", + underlying_session=cast(Session, underlying), + ) + + wrapper = RunContextWrapper(context={"request_id": "abc"}) + items = await session.get_items(wrapper=wrapper) + + assert items == [] + assert underlying.calls == [(None, wrapper)] + + @pytest.mark.asyncio + async def test_get_items_with_limit_preserves_wrapper_only_delegate_shape(self) -> None: + class WrapperOnlySession: + session_id = "test-session" + + def __init__(self) -> None: + self.calls: list[tuple[int | None, Any]] = [] + self.items: list[TResponseInputItem] = [ + cast(TResponseInputItem, {"role": "user", "content": "one"}), + cast(TResponseInputItem, {"role": "assistant", "content": "two"}), + ] + + async def get_items( + self, + wrapper: Any = None, + ) -> list[TResponseInputItem]: + self.calls.append((None, wrapper)) + return list(self.items) + + async def add_items(self, items: list[TResponseInputItem]) -> None: + return None + + async def pop_item(self) -> TResponseInputItem | None: + return None + + async def clear_session(self) -> None: + return None + + underlying = WrapperOnlySession() + session = OpenAIResponsesCompactionSession( + session_id="test", + underlying_session=cast(Session, underlying), + ) + + wrapper = RunContextWrapper(context={"request_id": "abc"}) + items = await session.get_items(limit=1, wrapper=wrapper) + + 
assert len(items) == 1 + assert items[0].get("content") == "two" + assert underlying.calls == [(None, wrapper)] + @pytest.mark.asyncio async def test_add_items_delegates(self) -> None: mock_session = self.create_mock_session() diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index a09dccc382..6fb520fcca 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -319,10 +319,18 @@ class DummySession(Session): session_id = "sess_123" session_settings = SessionSettings() - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: return None async def pop_item(self) -> TResponseInputItem | None: diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index 8b07297167..514450d50d 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -1997,7 +1997,11 @@ def __init__(self) -> None: super().__init__() self.get_items_calls = 0 - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: self.get_items_calls += 1 if self.get_items_calls == 1: raise RuntimeError("temporary failure") @@ -2279,10 +2283,18 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self.saved_items.extend(items) - async def get_items(self, limit: int | None = None) -> 
list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] async def pop_item(self) -> TResponseInputItem | None: @@ -3672,17 +3684,16 @@ async def echo_tool(text: str) -> str: }, ] - expected_calls = [ - # First call is the initial input - (([expected_items[0]],),), - # Second call is the first tool call and its result - (([expected_items[1], expected_items[2]],),), - # Third call is the second tool call and its result - (([expected_items[3], expected_items[4]],),), - # Fourth call is the final output - (([expected_items[5]],),), + expected_item_batches = [ + [expected_items[0]], + [expected_items[1], expected_items[2]], + [expected_items[3], expected_items[4]], + [expected_items[5]], ] - assert mock_add_items.call_args_list == expected_calls + assert len(mock_add_items.call_args_list) == len(expected_item_batches) + paired_calls = zip(mock_add_items.call_args_list, expected_item_batches) + for actual_call, expected_batch in paired_calls: + assert actual_call.args == (expected_batch,) assert result.final_output == "Summary: Echoed foo and bar" assert (await session.get_items()) == expected_items diff --git a/tests/test_agent_runner_streamed.py b/tests/test_agent_runner_streamed.py index 0e729fed37..0c336d892f 100644 --- a/tests/test_agent_runner_streamed.py +++ b/tests/test_agent_runner_streamed.py @@ -1126,7 +1126,11 @@ def __init__(self) -> None: async def _get_session_id(self) -> str: return "conv_test" - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: for item in items: if isinstance(item, dict): assert "id" not in item, "IDs should be stripped before saving" @@ -1135,7 +1139,11 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: ) self.saved.append(items) - async def 
get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: return [] async def pop_item(self) -> TResponseInputItem | None: diff --git a/tests/test_session.py b/tests/test_session.py index aaa80ec7aa..f750ae5070 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -3,10 +3,15 @@ import asyncio import tempfile from pathlib import Path +from typing import Any, cast import pytest from agents import Agent, RunConfig, Runner, SQLiteSession, TResponseInputItem +from agents.memory.session import SessionABC +from agents.memory.session_settings import SessionSettings +from agents.run_context import RunContextWrapper +from agents.run_internal.session_persistence import prepare_input_with_session from .fake_model import FakeModel from .test_responses import get_text_message @@ -49,6 +54,147 @@ async def run_agent_async(runner_method: str, agent, input_data, **kwargs): raise ValueError(f"Unknown runner method: {runner_method}") +class WrapperAwareSession(SessionABC): + session_id = "wrapper-aware" + + def __init__(self) -> None: + self.items: list[TResponseInputItem] = [] + self.get_wrappers: list[RunContextWrapper[Any] | None] = [] + self.add_wrappers: list[RunContextWrapper[Any] | None] = [] + + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: + self.get_wrappers.append(wrapper) + if limit is None: + return list(self.items) + return list(self.items[-limit:]) if limit > 0 else [] + + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: + self.add_wrappers.append(wrapper) + self.items.extend(items) + + async def pop_item(self) -> TResponseInputItem | None: + return self.items.pop() if self.items else None + + async def clear_session(self) -> None: + 
self.items.clear()  # tail of a method whose `def` line lies above this chunk; kept verbatim


class LegacyCompatibleSession:
    """Session test double whose methods take no ``wrapper`` parameter.

    ``get_items`` accepts only ``limit`` and ``add_items`` accepts only
    ``items``. Every call to either bumps a counter so a test can confirm the
    method was reached.
    """

    session_id = "legacy-compatible"

    def __init__(self) -> None:
        self.get_call_count = 0
        self.add_call_count = 0
        self.items: list[TResponseInputItem] = []

    async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
        """Return a copy of all stored items, or of the last ``limit`` items."""
        self.get_call_count += 1
        if limit is None:
            return list(self.items)
        if limit <= 0:
            # A non-positive limit yields an empty result, not the full history.
            return []
        return list(self.items[-limit:])

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Append ``items`` to the stored list and count the call."""
        self.add_call_count += 1
        self.items.extend(items)

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the newest item, or ``None`` when nothing is stored."""
        if not self.items:
            return None
        return self.items.pop()

    async def clear_session(self) -> None:
        """Empty the stored list in place."""
        self.items.clear()


class LegacyNoLimitKeywordSession:
    """Session test double whose ``get_items`` takes no arguments at all.

    ``add_items`` accepts only ``items``. Both methods count their calls.
    """

    session_id = "legacy-no-limit-keyword"

    def __init__(self) -> None:
        self.get_call_count = 0
        self.add_call_count = 0
        self.items: list[TResponseInputItem] = []

    async def get_items(self) -> list[TResponseInputItem]:
        """Return a copy of every stored item and count the call."""
        self.get_call_count += 1
        snapshot = list(self.items)
        return snapshot

    async def add_items(self, items: list[TResponseInputItem]) -> None:
        """Append ``items`` to the stored list and count the call."""
        self.add_call_count += 1
        self.items.extend(items)

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the newest item, or ``None`` when nothing is stored."""
        if not self.items:
            return None
        return self.items.pop()

    async def clear_session(self) -> None:
        """Empty the stored list in place."""
        self.items.clear()


class DefaultLimitedSession:
    """Session test double that applies its own limit of 1 when ``limit`` is ``None``.

    ``get_items`` counts its calls; ``wrapper`` is accepted by ``get_items`` and
    ``add_items`` but not used by either.
    """

    session_id = "default-limited"
    session_settings: SessionSettings | None = None

    def __init__(self) -> None:
        self.get_call_count = 0
        self.items: list[TResponseInputItem] = []

    async def get_items(
        self,
        limit: int | None = None,
        wrapper: RunContextWrapper[Any] | None = None,
    ) -> list[TResponseInputItem]:
        """Return the last ``limit`` items, treating ``limit=None`` as ``1``."""
        self.get_call_count += 1
        window = limit if limit is not None else 1
        if window <= 0:
            return []
        return list(self.items[-window:])

    async def add_items(
        self,
        items: list[TResponseInputItem],
        wrapper: RunContextWrapper[Any] | None = None,
    ) -> None:
        """Append ``items`` to the stored list."""
        self.items.extend(items)

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the newest item, or ``None`` when nothing is stored."""
        if not self.items:
            return None
        return self.items.pop()

    async def clear_session(self) -> None:
        """Empty the stored list in place."""
        self.items.clear()


class WrapperOnlySession:
    """Session test double whose ``get_items`` accepts ``wrapper`` but no ``limit``.

    Each ``wrapper`` value handed to ``get_items`` is recorded in
    ``get_wrappers`` for later inspection.
    """

    session_id = "wrapper-only"
    session_settings: SessionSettings | None = None

    def __init__(self) -> None:
        self.get_wrappers: list[RunContextWrapper[Any] | None] = []
        self.items: list[TResponseInputItem] = []

    async def get_items(
        self,
        wrapper: RunContextWrapper[Any] | None = None,
    ) -> list[TResponseInputItem]:
        """Record ``wrapper`` and return a copy of every stored item."""
        self.get_wrappers.append(wrapper)
        snapshot = list(self.items)
        return snapshot

    async def add_items(
        self,
        items: list[TResponseInputItem],
        wrapper: RunContextWrapper[Any] | None = None,
    ) -> None:
        """Append ``items`` to the stored list; ``wrapper`` is not used."""
        self.items.extend(items)

    async def pop_item(self) -> TResponseInputItem | None:
        """Remove and return the newest item, or ``None`` when nothing is stored."""
        if not self.items:
            return None
        return self.items.pop()

    async def clear_session(self) -> None:
        """Empty the stored list in place."""
        self.items.clear()


# Parametrized tests for different runner methods
#
# NOTE: the lines below are unchanged diff context for tests whose bodies are
# elided by the hunk boundary in this chunk; they are kept verbatim for reference.
#   @pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
#   @pytest.mark.asyncio
#   @@ -142,6 +288,123 @@ async def test_session_memory_disabled_parametrized(runner_method):
#       assert len(last_input) == 1  # Should only have the current message


@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
@pytest.mark.asyncio
async def test_session_history_paths_receive_context_wrapper_when_supported(runner_method):
    """Check that a ``WrapperAwareSession`` sees wrappers carrying the run context.

    After one run through ``run_agent_async``, the first wrapper recorded for
    ``get_items`` and the first recorded for ``add_items`` must both be
    non-``None`` and expose the same ``context`` that was passed to the run.
    """
    context = {"request_id": "ctx-123"}
    session = WrapperAwareSession()
    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model)
    fake_model.set_next_output([get_text_message("Hello")])

    result = await run_agent_async(
        runner_method,
        agent,
        "Hi there",
        session=session,
        context=context,
    )

    assert result.final_output == "Hello"
    assert session.get_wrappers
    assert not all(recorded is None for recorded in session.add_wrappers)

    first_get_wrapper = session.get_wrappers[0]
    assert first_get_wrapper is not None
    assert first_get_wrapper.context == context

    first_add_wrapper = session.add_wrappers[0]
    assert first_add_wrapper is not None
    assert first_add_wrapper.context == context


@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
@pytest.mark.asyncio
async def test_legacy_session_signatures_remain_compatible(runner_method):
    """Run with a ``LegacyCompatibleSession`` and check both of its methods were used.

    The run is given a context; the session's ``get_items`` and ``add_items``
    take no ``wrapper`` parameter, yet both must be called and items stored.
    """
    session = LegacyCompatibleSession()
    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model)
    fake_model.set_next_output([get_text_message("Hello")])

    result = await run_agent_async(
        runner_method,
        agent,
        "Hi there",
        session=session,
        context={"request_id": "legacy"},
    )

    assert result.final_output == "Hello"
    assert session.get_call_count > 0
    assert session.add_call_count > 0
    assert session.items


@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
@pytest.mark.asyncio
async def test_legacy_session_without_limit_keyword_remains_compatible(runner_method):
    """Run with a ``LegacyNoLimitKeywordSession`` and check both methods were used.

    That session's ``get_items`` takes no arguments; the run must still call
    ``get_items`` and ``add_items`` and leave items stored.
    """
    session = LegacyNoLimitKeywordSession()
    fake_model = FakeModel()
    agent = Agent(name="test", model=fake_model)
    fake_model.set_next_output([get_text_message("Hello")])

    result = await run_agent_async(
        runner_method,
        agent,
        "Hi there",
        session=session,
        context={"request_id": "legacy-no-limit"},
    )

    assert result.final_output == "Hello"
    assert session.get_call_count > 0
    assert session.add_call_count > 0
    assert session.items


@pytest.mark.asyncio
async def test_get_items_preserves_default_limit_when_none_is_unset() -> None:
    """Check the session's own default limit applies when no settings are supplied.

    ``DefaultLimitedSession`` returns one item when ``limit`` is ``None``, so the
    prepared input must start with the newest stored item ("two") followed by
    the new input ("new"), and ``get_items`` must have been called exactly once.
    """
    stored_history = [
        {"role": "user", "content": "one"},
        {"role": "assistant", "content": "two"},
    ]
    session = DefaultLimitedSession()
    session.items = stored_history

    prepared_input, _unused = await prepare_input_with_session(
        "new",
        session,
        session_input_callback=None,
    )

    assert isinstance(prepared_input, list)
    first, second = prepared_input[0], prepared_input[1]

    assert isinstance(first, dict)
    assert first.get("content") == "two"
    assert isinstance(second, dict)
    assert second.get("content") == "new"
    assert session.get_call_count == 1
+@pytest.mark.asyncio +async def test_get_items_with_limit_preserves_wrapper_only_delegate_shape() -> None: + session = WrapperOnlySession() + session.items = [ + {"role": "user", "content": "one"}, + {"role": "assistant", "content": "two"}, + ] + wrapper = RunContextWrapper(context={"request_id": "wrapper-only"}) + + prepared_input, _ = await prepare_input_with_session( + "new", + cast(SessionABC, session), + session_input_callback=None, + session_settings=SessionSettings(limit=1), + wrapper=wrapper, + ) + + assert isinstance(prepared_input, list) + first = prepared_input[0] + second = prepared_input[1] + assert isinstance(first, dict) and first.get("content") == "two" + assert isinstance(second, dict) and second.get("content") == "new" + assert session.get_wrappers == [wrapper] + + @pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"]) @pytest.mark.asyncio async def test_session_memory_different_sessions_parametrized(runner_method): @@ -545,7 +808,7 @@ async def test_session_add_items_exception_propagates_in_streamed(): """ session = SQLiteSession("test_exception_session") - async def _failing_add_items(_items): + async def _failing_add_items(_items, wrapper=None): raise RuntimeError("Simulated session.add_items failure") session.add_items = _failing_add_items # type: ignore[method-assign] diff --git a/tests/utils/simple_session.py b/tests/utils/simple_session.py index 94bcc97e9e..658b81aad4 100644 --- a/tests/utils/simple_session.py +++ b/tests/utils/simple_session.py @@ -1,10 +1,11 @@ from __future__ import annotations -from typing import cast +from typing import Any, cast from agents.items import TResponseInputItem from agents.memory.session import Session from agents.memory.session_settings import SessionSettings +from agents.run_context import RunContextWrapper class SimpleListSession(Session): @@ -24,14 +25,22 @@ def __init__( # Mirror saved_items used by some tests for inspection. 
self.saved_items: list[TResponseInputItem] = self._items - async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + async def get_items( + self, + limit: int | None = None, + wrapper: RunContextWrapper[Any] | None = None, + ) -> list[TResponseInputItem]: if limit is None: return list(self._items) if limit <= 0: return [] return self._items[-limit:] - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: self._items.extend(items) async def pop_item(self) -> TResponseInputItem | None: @@ -70,7 +79,11 @@ def __init__( super().__init__(session_id=session_id, history=history) self._ignore_ids_for_matching = True - async def add_items(self, items: list[TResponseInputItem]) -> None: + async def add_items( + self, + items: list[TResponseInputItem], + wrapper: RunContextWrapper[Any] | None = None, + ) -> None: sanitized: list[TResponseInputItem] = [] for item in items: if isinstance(item, dict): @@ -79,4 +92,4 @@ async def add_items(self, items: list[TResponseInputItem]) -> None: sanitized.append(cast(TResponseInputItem, clean)) else: sanitized.append(item) - await super().add_items(sanitized) + await super().add_items(sanitized, wrapper=wrapper)