diff --git a/src/google/adk/sessions/database_session_service.py b/src/google/adk/sessions/database_session_service.py index 2ccd60083..b8587c667 100644 --- a/src/google/adk/sessions/database_session_service.py +++ b/src/google/adk/sessions/database_session_service.py @@ -444,6 +444,7 @@ async def get_session( storage_events = ( session_factory.query(StorageEvent) .filter(StorageEvent.session_id == storage_session.id) + .filter(StorageEvent.user_id == user_id) .filter(timestamp_filter) .order_by(StorageEvent.timestamp.desc()) .limit( diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index ec93caafb..5f1e2f487 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -117,6 +117,7 @@ async def test_session_state(service_type): app_name = 'my_app' user_id_1 = 'user1' user_id_2 = 'user2' + user_id_malicious = 'malicious' session_id_11 = 'session11' session_id_12 = 'session12' session_id_2 = 'session2' @@ -139,6 +140,10 @@ async def test_session_state(service_type): app_name=app_name, user_id=user_id_2, session_id=session_id_2 ) + await session_service.create_session( + app_name=app_name, user_id=user_id_malicious, session_id=session_id_11 + ) + assert session_11.state.get('key11') == 'value11' event = Event( @@ -187,6 +192,13 @@ async def test_session_state(service_type): assert session_11.state.get('user:key1') == 'value1' assert not session_11.state.get('temp:key') + # Make sure a malicious user can obtain a session and events not belonging to them + session_mismatch = await session_service.get_session( + app_name=app_name, user_id=user_id_malicious, session_id=session_id_11 + ) + + assert len(session_mismatch.events) == 0 + @pytest.mark.asyncio @pytest.mark.parametrize(