Skip to content

Commit 558c83a

Browse files
authored
INTPYTHON-728 Fix handling of metadata in saver (#195)
1 parent a087b11 commit 558c83a

File tree

4 files changed

+704
-679
lines changed

4 files changed

+704
-679
lines changed

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/aio.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,8 @@ async def aput(
357357
checkpoint_ns = config["configurable"]["checkpoint_ns"]
358358
checkpoint_id = checkpoint["id"]
359359
type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)
360+
metadata = metadata.copy()
361+
metadata.update(config.get("metadata", {}))
360362
doc = {
361363
"parent_checkpoint_id": config["configurable"].get("checkpoint_id"),
362364
"type": type_,
@@ -368,6 +370,7 @@ async def aput(
368370
"checkpoint_ns": checkpoint_ns,
369371
"checkpoint_id": checkpoint_id,
370372
}
373+
371374
if self.ttl:
372375
doc["created_at"] = datetime.now()
373376
# Perform your operations here

libs/langgraph-checkpoint-mongodb/langgraph/checkpoint/mongodb/saver.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,8 @@ def put(
376376
checkpoint_ns = config["configurable"]["checkpoint_ns"]
377377
checkpoint_id = checkpoint["id"]
378378
type_, serialized_checkpoint = self.serde.dumps_typed(checkpoint)
379+
metadata = metadata.copy()
380+
metadata.update(config.get("metadata", {}))
379381
doc = {
380382
"parent_checkpoint_id": config["configurable"].get("checkpoint_id"),
381383
"type": type_,

libs/langgraph-checkpoint-mongodb/tests/integration_tests/test_highlevel_graph.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def fanout(state: OverallState) -> list:
8787
return [Send("generate_joke", {"subject": s}) for s in state["subjects"]]
8888

8989
parentgraph = StateGraph(OverallState)
90-
parentgraph.add_node("generate_joke", subgraphc) # type: ignore[arg-type]
90+
parentgraph.add_node("generate_joke", subgraphc)
9191
parentgraph.add_conditional_edges(START, fanout)
9292
parentgraph.add_edge("generate_joke", END)
9393
return parentgraph
@@ -176,3 +176,69 @@ async def test_fanout(
176176
)
177177
end = time.monotonic()
178178
print(f"{cname}: {end - start:.4f} seconds")
179+
180+
181+
async def test_custom_properties_async(
182+
checkpointer_mongodb: MongoDBSaver, checkpointer_mongodb_async: AsyncMongoDBSaver
183+
) -> None:
184+
# Create the state graph
185+
state_graph = fanout_to_subgraph()
186+
187+
# Define configuration with thread ID and assistant ID
188+
assistant_id = "456"
189+
user_id = "789"
190+
config: RunnableConfig = {
191+
"configurable": {
192+
"thread_id": "123",
193+
"assistant_id": assistant_id,
194+
"user_id": user_id,
195+
}
196+
}
197+
198+
# Compile the state graph with the provided checkpointing mechanism
199+
compiled_state_graph = state_graph.compile(checkpointer=checkpointer_mongodb_async)
200+
201+
# Invoke the compiled state graph with user input
202+
await compiled_state_graph.ainvoke(
203+
input={"subjects": [], "step": 0}, # type:ignore[arg-type]
204+
config=config,
205+
stream_mode="values",
206+
debug=False,
207+
)
208+
209+
checkpoint_tuple = await checkpointer_mongodb_async.aget_tuple(config)
210+
assert checkpoint_tuple is not None
211+
assert checkpoint_tuple.metadata["user_id"] == user_id
212+
assert checkpoint_tuple.metadata["assistant_id"] == assistant_id
213+
214+
215+
def test_custom_properties(checkpointer_mongodb: MongoDBSaver) -> None:
216+
# Create the state graph
217+
state_graph = fanout_to_subgraph()
218+
219+
# Define configuration with thread ID and assistant ID
220+
assistant_id = "456"
221+
user_id = "789"
222+
config: RunnableConfig = {
223+
"configurable": {
224+
"thread_id": "123",
225+
"assistant_id": assistant_id,
226+
"user_id": user_id,
227+
}
228+
}
229+
230+
# Compile the state graph with the provided checkpointing mechanism
231+
compiled_state_graph = state_graph.compile(checkpointer=checkpointer_mongodb)
232+
233+
# Invoke the compiled state graph with user input
234+
compiled_state_graph.invoke(
235+
input={"subjects": [], "step": 0}, # type:ignore[arg-type]
236+
config=config,
237+
stream_mode="values",
238+
debug=False,
239+
)
240+
241+
checkpoint_tuple = checkpointer_mongodb.get_tuple(config)
242+
assert checkpoint_tuple is not None
243+
assert checkpoint_tuple.metadata["user_id"] == user_id
244+
assert checkpoint_tuple.metadata["assistant_id"] == assistant_id

0 commit comments

Comments
 (0)