Skip to content

Commit b372dcc

Browse files
Feat: Add mem0_memory Support (Vector) for Neptune Analytics (#262)
* Add config Signed-off-by: Andy Kwok <[email protected]> * Update llm Signed-off-by: Andy Kwok <[email protected]> * Add error msg Signed-off-by: Andy Kwok <[email protected]> * Update vector config handling Signed-off-by: Andy Kwok <[email protected]> * Update Graph store Signed-off-by: Andy Kwok <[email protected]> * Remove debug Signed-off-by: Andy Kwok <[email protected]> * Fix lint Signed-off-by: Andy Kwok <[email protected]> * Update test Signed-off-by: Andy Kwok <[email protected]> * Update tests Signed-off-by: Andy Kwok <[email protected]> * Update lint Signed-off-by: Andy Kwok <[email protected]> * Update test Signed-off-by: Andy Kwok <[email protected]> * Remove default config Signed-off-by: Andy Kwok <[email protected]> * Update config Signed-off-by: Andy Kwok <[email protected]> * Update display of stored message Signed-off-by: Andy Kwok <[email protected]> * Consolidate var Signed-off-by: Andy Kwok <[email protected]> * Update test Signed-off-by: Andy Kwok <[email protected]> * Update src/strands_tools/mem0_memory.py Co-authored-by: Andrew Carbonetto <[email protected]> * Update table layout Signed-off-by: Andy Kwok <[email protected]> * Update wording Signed-off-by: Andy Kwok <[email protected]> * Rename method Signed-off-by: Andy Kwok <[email protected]> * Update rename Signed-off-by: Andy Kwok <[email protected]> * Pach dict handling Signed-off-by: Andy Kwok <[email protected]> --------- Signed-off-by: Andy Kwok <[email protected]> Co-authored-by: Andrew Carbonetto <[email protected]>
1 parent 319c8b0 commit b372dcc

File tree

2 files changed

+113
-39
lines changed

2 files changed

+113
-39
lines changed

src/strands_tools/mem0_memory.py

Lines changed: 98 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -162,19 +162,6 @@ class Mem0ServiceClient:
162162
"max_tokens": int(os.environ.get("MEM0_LLM_MAX_TOKENS", 2000)),
163163
},
164164
},
165-
"vector_store": {
166-
"provider": "opensearch",
167-
"config": {
168-
"port": 443,
169-
"collection_name": os.environ.get("OPENSEARCH_COLLECTION", "mem0"),
170-
"host": os.environ.get("OPENSEARCH_HOST"),
171-
"embedding_model_dims": 1024,
172-
"connection_class": RequestsHttpConnection,
173-
"pool_maxsize": 20,
174-
"use_ssl": True,
175-
"verify_certs": True,
176-
},
177-
},
178165
}
179166

180167
def __init__(self, config: Optional[Dict] = None):
@@ -204,19 +191,32 @@ def _initialize_client(self, config: Optional[Dict] = None) -> Any:
204191
logger.debug("Using Mem0 Platform backend (MemoryClient)")
205192
return MemoryClient()
206193

207-
if os.environ.get("NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER"):
208-
logger.debug("Using Neptune Analytics graph backend (Mem0Memory with Neptune Analytics)")
209-
config = self._configure_neptune_analytics_backend(config)
194+
if os.environ.get("NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER") and os.environ.get("OPENSEARCH_HOST"):
195+
raise RuntimeError("""Conflicting backend configurations:
196+
Only one environment variable of NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER or OPENSEARCH_HOST can be set.""")
210197

198+
# Vector search providers
211199
if os.environ.get("OPENSEARCH_HOST"):
212200
logger.debug("Using OpenSearch backend (Mem0Memory with OpenSearch)")
213-
return self._initialize_opensearch_client(config)
201+
merged_config = self._append_opensearch_config(config)
214202

215-
logger.debug("Using FAISS backend (Mem0Memory with FAISS)")
216-
return self._initialize_faiss_client(config)
203+
elif os.environ.get("NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER"):
204+
logger.debug("Using Neptune Analytics vector backend (Mem0Memory with Neptune Analytics)")
205+
merged_config = self._append_neptune_analytics_vector_config(config)
217206

218-
def _configure_neptune_analytics_backend(self, config: Optional[Dict] = None) -> Dict:
219-
"""Initialize a Mem0 client with Neptune Analytics graph backend.
207+
else:
208+
logger.debug("Using FAISS backend (Mem0Memory with FAISS)")
209+
merged_config = self._append_faiss_config(config)
210+
211+
# Graph backend providers
212+
if os.environ.get("NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER"):
213+
logger.debug("Using Neptune Analytics graph backend (Mem0Memory with Neptune Analytics)")
214+
merged_config = self._append_neptune_analytics_graph_config(merged_config)
215+
216+
return Mem0Memory.from_config(config_dict=merged_config)
217+
218+
def _append_neptune_analytics_vector_config(self, config: Optional[Dict] = None) -> Dict:
219+
"""Update incoming configuration dictionary to include the configuration of Neptune Analytics vector backend.
220220
221221
Args:
222222
config: Optional configuration dictionary to override defaults.
@@ -225,21 +225,40 @@ def _configure_neptune_analytics_backend(self, config: Optional[Dict] = None) ->
225225
An configuration dict with graph backend.
226226
"""
227227
config = config or {}
228-
config["graph_store"] = {
228+
config["vector_store"] = {
229229
"provider": "neptune",
230-
"config": {"endpoint": f"neptune-graph://{os.environ.get('NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER')}"},
230+
"config": {
231+
"collection_name": os.environ.get("NEPTUNE_ANALYTICS_VECTOR_COLLECTION", "mem0"),
232+
"endpoint": f"neptune-graph://{os.environ.get('NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER')}",
233+
},
231234
}
232-
return config
235+
return self._merge_config(config)
233236

234-
def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Memory:
235-
"""Initialize a Mem0 client with OpenSearch backend.
237+
def _append_opensearch_config(self, config: Optional[Dict] = None) -> Dict:
238+
"""Update incoming configuration dictionary to include the configuration of OpenSearch vector backend.
236239
237240
Args:
238241
config: Optional configuration dictionary to override defaults.
239242
240243
Returns:
241244
An initialized Mem0Memory instance configured for OpenSearch.
242245
"""
246+
# Add vector portion of the config
247+
config = config or {}
248+
config["vector_store"] = {
249+
"provider": "opensearch",
250+
"config": {
251+
"port": 443,
252+
"collection_name": os.environ.get("OPENSEARCH_COLLECTION", "mem0"),
253+
"host": os.environ.get("OPENSEARCH_HOST"),
254+
"embedding_model_dims": 1024,
255+
"connection_class": RequestsHttpConnection,
256+
"pool_maxsize": 20,
257+
"use_ssl": True,
258+
"verify_certs": True,
259+
},
260+
}
261+
243262
# Set up AWS region
244263
self.region = os.environ.get("AWS_REGION", "us-west-2")
245264
if not os.environ.get("AWS_REGION"):
@@ -254,10 +273,11 @@ def _initialize_opensearch_client(self, config: Optional[Dict] = None) -> Mem0Me
254273
merged_config = self._merge_config(config)
255274
merged_config["vector_store"]["config"].update({"http_auth": auth, "host": os.environ["OPENSEARCH_HOST"]})
256275

257-
return Mem0Memory.from_config(config_dict=merged_config)
276+
return merged_config
277+
278+
def _append_faiss_config(self, config: Optional[Dict] = None) -> Dict:
279+
"""Update incoming configuration dictionary to include the configuration of FAISS vector backend.
258280
259-
def _initialize_faiss_client(self, config: Optional[Dict] = None) -> Mem0Memory:
260-
"""Initialize a Mem0 client with FAISS backend.
261281
262282
Args:
263283
config: Optional configuration dictionary to override defaults.
@@ -284,8 +304,22 @@ def _initialize_faiss_client(self, config: Optional[Dict] = None) -> Mem0Memory:
284304
"path": "/tmp/mem0_384_faiss",
285305
},
286306
}
307+
return merged_config
287308

288-
return Mem0Memory.from_config(config_dict=merged_config)
309+
def _append_neptune_analytics_graph_config(self, config: Dict) -> Dict:
310+
"""Update incoming configuration dictionary to include the configuration of Neptune Analytics graph backend.
311+
312+
Args:
313+
config: Configuration dictionary to add Neptune Analytics graph backend
314+
315+
Returns:
316+
An configuration dict with graph backend.
317+
"""
318+
config["graph_store"] = {
319+
"provider": "neptune",
320+
"config": {"endpoint": f"neptune-graph://{os.environ.get('NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER')}"},
321+
}
322+
return config
289323

290324
def _merge_config(self, config: Optional[Dict] = None) -> Dict:
291325
"""Merge user-provided configuration with default configuration.
@@ -462,9 +496,9 @@ def format_retrieve_graph_response(memories: List[Dict]) -> Panel:
462496
)
463497

464498
table = Table(title="Search Results", show_header=True, header_style="bold magenta")
465-
table.add_column("Source", style="cyan")
466-
table.add_column("Relationship", style="yellow", width=50)
467-
table.add_column("Destination", style="green")
499+
table.add_column("Source", style="cyan", width=25)
500+
table.add_column("Relationship", style="yellow", width=45)
501+
table.add_column("Destination", style="green", width=30)
468502

469503
for memory in memories:
470504
source = memory.get("source", "N/A")
@@ -482,9 +516,9 @@ def format_list_graph_response(memories: List[Dict]) -> Panel:
482516
return Panel("No graph memories found.", title="[bold yellow]No Memories", border_style="yellow")
483517

484518
table = Table(title="Graph Memories", show_header=True, header_style="bold magenta")
485-
table.add_column("Source", style="cyan")
486-
table.add_column("Relationship", style="yellow", width=50)
487-
table.add_column("Target", style="green")
519+
table.add_column("Source", style="cyan", width=25)
520+
table.add_column("Relationship", style="yellow", width=45)
521+
table.add_column("Target", style="green", width=30)
488522

489523
for memory in memories:
490524
source = memory.get("source", "N/A")
@@ -545,6 +579,26 @@ def format_store_response(results: List[Dict]) -> Panel:
545579
return Panel(table, title="[bold green]Memory Stored", border_style="green")
546580

547581

582+
def format_store_graph_response(memories: List[Dict]) -> Panel:
583+
"""Format store response for graph data"""
584+
if not memories:
585+
return Panel("No graph memories stored.", title="[bold yellow]No Memories Stored", border_style="yellow")
586+
587+
table = Table(title="Graph Memories Stored", show_header=True, header_style="bold magenta")
588+
table.add_column("Source", style="cyan", width=25)
589+
table.add_column("Relationship", style="yellow", width=45)
590+
table.add_column("Target", style="green", width=30)
591+
592+
for memory in memories:
593+
source = memory[0].get("source", "N/A")
594+
relationship = memory[0].get("relationship", "N/A")
595+
destination = memory[0].get("target", "N/A")
596+
597+
table.add_row(source, relationship, destination)
598+
599+
return Panel(table, title="[bold green]Memories Stored (Graph)", border_style="green")
600+
601+
548602
def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
549603
"""
550604
Memory management tool for storing, retrieving, and managing memories in Mem0.
@@ -656,6 +710,14 @@ def mem0_memory(tool: ToolUse, **kwargs: Any) -> ToolResult:
656710
if results_list:
657711
panel = format_store_response(results_list)
658712
console.print(panel)
713+
714+
# Process graph relations (If any)
715+
if "relations" in results:
716+
relationships_list = results.get("relations").get("added_entities", [])
717+
results_list.extend(relationships_list)
718+
panel_graph = format_store_graph_response(relationships_list)
719+
console.print(panel_graph)
720+
659721
return ToolResult(
660722
toolUseId=tool_use_id,
661723
status="success",

tests/test_mem0.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,13 +424,25 @@ def test_mem0_service_client_init(mock_opensearch, mock_mem0_memory, mock_sessio
424424
client = Mem0ServiceClient()
425425
assert client.region == os.environ.get("AWS_REGION", "us-west-2")
426426

427-
# Test with optional Graph backend
427+
# Test with conflict scenario
428428
with patch.dict(
429429
os.environ,
430-
{"OPENSEARCH_HOST": "test.opensearch.amazonaws.com", "NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER": "g-5aaaaa1234"},
430+
{
431+
"OPENSEARCH_HOST": "test.opensearch.amazonaws.com",
432+
"NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER": "g-5aaaaa1234",
433+
},
434+
):
435+
with pytest.raises(RuntimeError):
436+
Mem0ServiceClient()
437+
438+
# Test with Neptune Analytics for both vector and graph
439+
with patch.dict(
440+
os.environ,
441+
{
442+
"NEPTUNE_ANALYTICS_GRAPH_IDENTIFIER": "g-5aaaaa1234",
443+
},
431444
):
432445
client = Mem0ServiceClient()
433-
assert client.region == os.environ.get("AWS_REGION", "us-west-2")
434446
assert client.mem0 is not None
435447

436448
# Test with custom config (OpenSearch)

0 commit comments

Comments
 (0)