Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions docs/sessions/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,52 @@ result = await Runner.run(
)
```

### Accessing the run context in custom sessions

Custom sessions can opt into receiving the current run's [`RunContextWrapper`][agents.run_context.RunContextWrapper] by accepting a keyword-only `wrapper` parameter on `get_items` and `add_items`. The runner passes the wrapper only when the session's signature accepts it, so existing session implementations keep working unchanged:

```python
from typing import Any

from agents.memory.session import SessionABC
from agents.items import TResponseInputItem
from agents.run_context import RunContextWrapper

class ContextAwareSession(SessionABC):
"""Session that scopes storage by data from the run context."""

def __init__(self, session_id: str):
self.session_id = session_id

async def get_items(
self,
limit: int | None = None,
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
# Use wrapper.context (e.g. a user ID) to scope retrieval.
user_id = wrapper.context.user_id if wrapper is not None else None
return await self._load_items(user_id, limit)

async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
# Persist items together with data from the run context.
user_id = wrapper.context.user_id if wrapper is not None else None
await self._store_items(user_id, items)

async def pop_item(self) -> TResponseInputItem | None:
...

async def clear_session(self) -> None:
...
```

The `wrapper` parameter may be `None`, for example when session methods are called directly rather than through the runner, so implementations should always handle that case. Sessions that accept `**kwargs` on these methods also receive the wrapper through them.

## Community session implementations

The community has developed additional session implementations:
Expand Down
15 changes: 13 additions & 2 deletions examples/memory/file_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from agents.memory.session import Session
from agents.memory.session_settings import SessionSettings
from agents.run_context import RunContextWrapper


class FileSession(Session):
Expand Down Expand Up @@ -43,14 +44,24 @@ async def get_session_id(self) -> str:
"""Return the session id, creating one if needed."""
return await self._ensure_session_id()

async def get_items(self, limit: int | None = None) -> list[Any]:
async def get_items(
self,
limit: int | None = None,
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[Any]:
session_id = await self._ensure_session_id()
items = await self._read_items(session_id)
if limit is not None and limit >= 0:
return items[-limit:]
return items

async def add_items(self, items: list[Any]) -> None:
async def add_items(
self,
items: list[Any],
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
if not items:
return
session_id = await self._ensure_session_id()
Expand Down
10 changes: 9 additions & 1 deletion src/agents/extensions/memory/advanced_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ...items import TResponseInputItem
from ...memory import SQLiteSession
from ...memory.session_settings import SessionSettings, resolve_session_limit
from ...run_context import RunContextWrapper


class AdvancedSQLiteSession(SQLiteSession):
Expand Down Expand Up @@ -121,7 +122,12 @@ def _init_structure_tables(self):

conn.commit()

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Add items to the session.

Args:
Expand Down Expand Up @@ -160,6 +166,8 @@ async def get_items(
self,
limit: int | None = None,
branch_id: str | None = None,
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
"""Get items from current or specified branch.

Expand Down
17 changes: 14 additions & 3 deletions src/agents/extensions/memory/async_sqlite_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import cast
from typing import Any, cast

import aiosqlite

from ...items import TResponseInputItem
from ...memory import SessionABC
from ...memory.session_settings import SessionSettings, resolve_session_limit
from ...run_context import RunContextWrapper


class AsyncSQLiteSession(SessionABC):
Expand Down Expand Up @@ -106,7 +107,12 @@ async def _locked_connection(self) -> AsyncIterator[aiosqlite.Connection]:
conn = await self._get_connection()
yield conn

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self,
limit: int | None = None,
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -156,7 +162,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down
15 changes: 13 additions & 2 deletions src/agents/extensions/memory/dapr_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import time
from typing import Any, Final, Literal

from ...run_context import RunContextWrapper
from ._optional_imports import raise_optional_dependency_error

try:
Expand Down Expand Up @@ -250,7 +251,12 @@ async def _handle_concurrency_conflict(self, error: Exception, attempt: int) ->
# Session protocol implementation
# ------------------------------------------------------------------

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self,
limit: int | None = None,
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -289,7 +295,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
continue
return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down
39 changes: 34 additions & 5 deletions src/agents/extensions/memory/encrypt_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@
from typing_extensions import TypedDict

from ...items import TResponseInputItem
from ...memory.session import SessionABC
from ...memory.session import SessionABC, session_method_accepts_wrapper
from ...memory.session_settings import SessionSettings, resolve_session_limit
from ...run_context import RunContextWrapper


class EncryptedEnvelope(TypedDict):
Expand Down Expand Up @@ -180,24 +181,52 @@ def _unwrap_valid_items(
valid_items.append(item)
return valid_items

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def _get_underlying_items(
self, limit: int | None, wrapper: RunContextWrapper[Any] | None
) -> list[TResponseInputItem]:
# Forward the wrapper only when the underlying session opts in, so wrapping older
# custom sessions keeps working unchanged.
if wrapper is not None and session_method_accepts_wrapper(
self.underlying_session.get_items
):
return await self.underlying_session.get_items(limit, wrapper=wrapper)
return await self.underlying_session.get_items(limit)

async def get_items(
self,
limit: int | None = None,
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
effective_limit = resolve_session_limit(limit, self.session_settings)
if effective_limit is not None and effective_limit > 0:
window = effective_limit
while True:
encrypted_items = await self.underlying_session.get_items(window)
encrypted_items = await self._get_underlying_items(window, wrapper)
valid_items = self._unwrap_valid_items(encrypted_items)
if len(valid_items) >= effective_limit:
return valid_items[-effective_limit:]
if len(encrypted_items) < window:
return valid_items
window *= 2

encrypted_items = await self.underlying_session.get_items(limit)
encrypted_items = await self._get_underlying_items(limit, wrapper)
return self._unwrap_valid_items(encrypted_items)

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
wrapped: list[EncryptedEnvelope] = [self._wrap(it) for it in items]
if wrapper is not None and session_method_accepts_wrapper(
self.underlying_session.add_items
):
await self.underlying_session.add_items(
cast(list[TResponseInputItem], wrapped), wrapper=wrapper
)
return
await self.underlying_session.add_items(cast(list[TResponseInputItem], wrapped))

async def pop_item(self) -> TResponseInputItem | None:
Expand Down
15 changes: 13 additions & 2 deletions src/agents/extensions/memory/mongodb_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from datetime import datetime, timezone
from typing import Any

from ...run_context import RunContextWrapper
from ._optional_imports import raise_optional_dependency_error

try:
Expand Down Expand Up @@ -247,7 +248,12 @@ async def _deserialize_item(self, raw: str) -> TResponseInputItem:
# Session protocol implementation
# ------------------------------------------------------------------

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self,
limit: int | None = None,
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -289,7 +295,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down
15 changes: 13 additions & 2 deletions src/agents/extensions/memory/redis_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import time
from typing import Any

from ...run_context import RunContextWrapper
from ._optional_imports import raise_optional_dependency_error

try:
Expand Down Expand Up @@ -145,7 +146,12 @@ async def _set_ttl_if_configured(self, *keys: str) -> None:
# Session protocol implementation
# ------------------------------------------------------------------

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self,
limit: int | None = None,
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -184,7 +190,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:

return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down
15 changes: 13 additions & 2 deletions src/agents/extensions/memory/sqlalchemy_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
from ...items import TResponseInputItem
from ...memory.session import SessionABC
from ...memory.session_settings import SessionSettings, resolve_session_limit
from ...run_context import RunContextWrapper


class SQLAlchemySession(SessionABC):
Expand Down Expand Up @@ -274,7 +275,12 @@ async def _ensure_tables(self) -> None:
finally:
self._init_lock.release()

async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
async def get_items(
self,
limit: int | None = None,
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> list[TResponseInputItem]:
"""Retrieve the conversation history for this session.

Args:
Expand Down Expand Up @@ -326,7 +332,12 @@ async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]:
continue
return items

async def add_items(self, items: list[TResponseInputItem]) -> None:
async def add_items(
self,
items: list[TResponseInputItem],
*,
wrapper: RunContextWrapper[Any] | None = None,
) -> None:
"""Add new items to the conversation history.

Args:
Expand Down
Loading