diff --git a/python/packages/mem0/agent_framework_mem0/_provider.py b/python/packages/mem0/agent_framework_mem0/_provider.py index 6d726f3a0f..48e508f411 100644 --- a/python/packages/mem0/agent_framework_mem0/_provider.py +++ b/python/packages/mem0/agent_framework_mem0/_provider.py @@ -150,11 +150,12 @@ async def invoking(self, messages: ChatMessage | MutableSequence[ChatMessage], * if not input_text.strip(): return Context(messages=None) + # Build filters from init parameters + filters = self._build_filters() + search_response: MemorySearchResponse_v1_1 | MemorySearchResponse_v2 = await self.mem0_client.search( # type: ignore[misc] query=input_text, - user_id=self.user_id, - agent_id=self.agent_id, - run_id=self._per_operation_thread_id if self.scope_to_per_operation_thread_id else self.thread_id, + filters=filters, ) # Depending on the API version, the response schema varies slightly @@ -185,6 +186,29 @@ def _validate_filters(self) -> None: "At least one of the filters: agent_id, user_id, application_id, or thread_id is required." ) + def _build_filters(self) -> dict[str, Any]: + """Build search filters from initialization parameters. + + Returns: + Filter dictionary for mem0 v2 search API containing initialization parameters. + In the v2 API, filters holds the user_id, agent_id, run_id (thread_id), and app_id + (application_id) which are required for scoping memory search operations. + """ + filters: dict[str, Any] = {} + + if self.user_id: + filters["user_id"] = self.user_id + if self.agent_id: + filters["agent_id"] = self.agent_id + if self.scope_to_per_operation_thread_id and self._per_operation_thread_id: + filters["run_id"] = self._per_operation_thread_id + elif self.thread_id: + filters["run_id"] = self.thread_id + if self.application_id: + filters["app_id"] = self.application_id + + return filters + def _validate_per_operation_thread_id(self, thread_id: str | None) -> None: """Validates that a new thread ID doesn't conflict with an existing one when scoped. diff --git a/python/packages/mem0/tests/test_mem0_context_provider.py b/python/packages/mem0/tests/test_mem0_context_provider.py index 4c1be141dc..7464aad913 100644 --- a/python/packages/mem0/tests/test_mem0_context_provider.py +++ b/python/packages/mem0/tests/test_mem0_context_provider.py @@ -338,7 +338,7 @@ async def test_model_invoking_single_message(self, mock_mem0_client: AsyncMock) mock_mem0_client.search.assert_called_once() call_args = mock_mem0_client.search.call_args assert call_args.kwargs["query"] == "What's the weather?" - assert call_args.kwargs["user_id"] == "user123" + assert call_args.kwargs["filters"] == {"user_id": "user123"} assert isinstance(context, Context) expected_instructions = ( @@ -373,8 +373,7 @@ async def test_model_invoking_with_agent_id(self, mock_mem0_client: AsyncMock) - await provider.invoking(message) call_args = mock_mem0_client.search.call_args - assert call_args.kwargs["agent_id"] == "agent123" - assert call_args.kwargs["user_id"] is None + assert call_args.kwargs["filters"] == {"agent_id": "agent123"} async def test_model_invoking_with_scope_to_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None: """Test invoking with scope_to_per_operation_thread_id enabled.""" @@ -392,7 +391,7 @@ async def test_model_invoking_with_scope_to_per_operation_thread_id(self, mock_m await provider.invoking(message) call_args = mock_mem0_client.search.call_args - assert call_args.kwargs["run_id"] == "operation_thread" + assert call_args.kwargs["filters"] == {"user_id": "user123", "run_id": "operation_thread"} async def test_model_invoking_no_memories_returns_none_instructions(self, mock_mem0_client: AsyncMock) -> None: """Test that no memories returns context with None instructions.""" @@ -510,3 +509,87 @@ def test_validate_per_operation_thread_id_disabled_scope(self, mock_mem0_client: # Should not raise exception even with different thread ID provider._validate_per_operation_thread_id("different_thread") + + +class TestMem0ProviderBuildFilters: + """Test the _build_filters method.""" + + def test_build_filters_with_user_id_only(self, mock_mem0_client: AsyncMock) -> None: + """Test building filters with only user_id.""" + provider = Mem0Provider(user_id="user123", mem0_client=mock_mem0_client) + + filters = provider._build_filters() + assert filters == {"user_id": "user123"} + + def test_build_filters_with_all_parameters(self, mock_mem0_client: AsyncMock) -> None: + """Test building filters with all initialization parameters.""" + provider = Mem0Provider( + user_id="user123", + agent_id="agent456", + thread_id="thread789", + application_id="app999", + mem0_client=mock_mem0_client, + ) + + filters = provider._build_filters() + assert filters == { + "user_id": "user123", + "agent_id": "agent456", + "run_id": "thread789", + "app_id": "app999", + } + + def test_build_filters_excludes_none_values(self, mock_mem0_client: AsyncMock) -> None: + """Test that None values are excluded from filters.""" + provider = Mem0Provider( + user_id="user123", + agent_id=None, + thread_id=None, + application_id=None, + mem0_client=mock_mem0_client, + ) + + filters = provider._build_filters() + assert filters == {"user_id": "user123"} + assert "agent_id" not in filters + assert "run_id" not in filters + assert "app_id" not in filters + + def test_build_filters_with_per_operation_thread_id(self, mock_mem0_client: AsyncMock) -> None: + """Test that per-operation thread ID takes precedence over base thread_id.""" + provider = Mem0Provider( + user_id="user123", + thread_id="base_thread", + scope_to_per_operation_thread_id=True, + mem0_client=mock_mem0_client, + ) + provider._per_operation_thread_id = "operation_thread" + + filters = provider._build_filters() + assert filters == { + "user_id": "user123", + "run_id": "operation_thread", # Per-operation thread, not base_thread + } + + def test_build_filters_uses_base_thread_when_no_per_operation(self, mock_mem0_client: AsyncMock) -> None: + """Test that base thread_id is used when per-operation thread is not set.""" + provider = Mem0Provider( + user_id="user123", + thread_id="base_thread", + scope_to_per_operation_thread_id=True, + mem0_client=mock_mem0_client, + ) + # _per_operation_thread_id is None + + filters = provider._build_filters() + assert filters == { + "user_id": "user123", + "run_id": "base_thread", # Falls back to base thread_id + } + + def test_build_filters_returns_empty_dict_when_no_parameters(self, mock_mem0_client: AsyncMock) -> None: + """Test that _build_filters returns an empty dict when no parameters are set.""" + provider = Mem0Provider(mem0_client=mock_mem0_client) + + filters = provider._build_filters() + assert filters == {} diff --git a/python/samples/getting_started/context_providers/mem0/mem0_basic.py b/python/samples/getting_started/context_providers/mem0/mem0_basic.py index 0c2252d66a..5fef82d390 100644 --- a/python/samples/getting_started/context_providers/mem0/mem0_basic.py +++ b/python/samples/getting_started/context_providers/mem0/mem0_basic.py @@ -54,6 +54,13 @@ async def main() -> None: result = await agent.run(query) print(f"Agent: {result}\n") + # Mem0 processes and indexes memories asynchronously. + # Wait for memories to be indexed before querying in a new thread. + # In production, consider implementing retry logic or using Mem0's + # eventual consistency handling instead of a fixed delay. + print("Waiting for memories to be processed...") + await asyncio.sleep(12) # Empirically determined delay for Mem0 indexing + print("\nRequest within a new thread:") # Create a new thread for the agent. # The new thread has no context of the previous conversation.