Skip to content

Commit 39a6009

Browse files
committed
Make graph.run(...) return an instance of GraphRun
1 parent 93260fe commit 39a6009

File tree

9 files changed

+80
-108
lines changed

9 files changed

+80
-108
lines changed

pydantic_ai_slim/pydantic_ai/agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -337,12 +337,13 @@ async def main():
337337
)
338338

339339
# Actually run
340-
end_result, _ = await graph.run(
340+
graph_run = await graph.run(
341341
start_node,
342342
state=state,
343343
deps=graph_deps,
344344
infer_name=False,
345345
)
346+
end_result = graph_run.result
346347

347348
# Build final run result
348349
# We don't do any advanced checking if the data is actually from a final result or not

pydantic_graph/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ class Increment(BaseNode):
5050

5151

5252
fives_graph = Graph(nodes=[DivisibleBy5, Increment])
53-
result, history = fives_graph.run_sync(DivisibleBy5(4))
54-
print(result)
53+
graph_run = fives_graph.run_sync(DivisibleBy5(4))
54+
print(graph_run.result)
5555
#> 5
5656
# the full history is quite verbose (see below), so we'll just print the summary
5757
print([item.data_snapshot() for item in history])

pydantic_graph/pydantic_graph/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
from .exceptions import GraphRuntimeError, GraphSetupError
2-
from .graph import Graph, GraphRun, GraphRunner
2+
from .graph import Graph, GraphRun
33
from .nodes import BaseNode, Edge, End, GraphRunContext
44
from .state import EndStep, HistoryStep, NodeStep
55

66
__all__ = (
77
'Graph',
88
'GraphRun',
9-
'GraphRunner',
109
'BaseNode',
1110
'End',
1211
'GraphRunContext',

pydantic_graph/pydantic_graph/graph.py

Lines changed: 51 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),)
3131

3232

33-
__all__ = ('Graph', 'GraphRun', 'GraphRunner')
33+
__all__ = ('Graph', 'GraphRun')
3434

3535
_logfire = logfire_api.Logfire(otel_scope='pydantic-graph')
3636

@@ -133,7 +133,7 @@ def run(
133133
state: StateT = None,
134134
deps: DepsT = None,
135135
infer_name: bool = True,
136-
) -> GraphRunner[StateT, DepsT, T]:
136+
) -> GraphRun[StateT, DepsT, T]:
137137
"""Run the graph from a starting node until it ends.
138138
139139
Args:
@@ -170,7 +170,7 @@ async def main():
170170
if infer_name and self.name is None:
171171
self._infer_name(inspect.currentframe())
172172

173-
return GraphRunner[StateT, DepsT, T](
173+
return GraphRun[StateT, DepsT, T](
174174
self, start_node, history=[], state=state, deps=deps, auto_instrument=self._auto_instrument
175175
)
176176

@@ -181,7 +181,7 @@ def run_sync(
181181
state: StateT = None,
182182
deps: DepsT = None,
183183
infer_name: bool = True,
184-
) -> tuple[T, list[HistoryStep[StateT, T]]]:
184+
) -> GraphRun[StateT, DepsT, T]:
185185
"""Run the graph synchronously.
186186
187187
This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`.
@@ -499,11 +499,10 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None:
499499
return
500500

501501

502-
class GraphRunner(Generic[StateT, DepsT, RunEndT]):
503-
"""An object that can be awaited to perform a graph run.
502+
class GraphRun(Generic[StateT, DepsT, RunEndT]):
503+
"""A stateful run of a graph.
504504
505-
This object can also be used as a contextmanager to get a handle to a specific graph run,
506-
allowing you to iterate over nodes, and possibly perform modifications to the nodes as they are run.
505+
After being entered, can be used like an async generator to listen to / modify nodes as the run is executed.
507506
"""
508507

509508
def __init__(
@@ -517,84 +516,25 @@ def __init__(
517516
auto_instrument: bool,
518517
):
519518
self.graph = graph
520-
self.first_node = first_node
521519
self.history = history
522520
self.state = state
523521
self.deps = deps
524-
525-
self._run: GraphRun[StateT, DepsT, RunEndT] | None = None
526-
527522
self._auto_instrument = auto_instrument
528-
self._span: LogfireSpan | None = None
529-
530-
@property
531-
def run(self) -> GraphRun[StateT, DepsT, RunEndT]:
532-
if self._run is None:
533-
raise exceptions.GraphRuntimeError('GraphRunner has not been awaited yet.')
534-
return self._run
535-
536-
def __await__(self) -> Generator[Any, Any, tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]]:
537-
"""Run the graph until it ends, and return the final result."""
538-
539-
async def _run() -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]:
540-
async with self as run:
541-
self._run = run
542-
async for _next_node in run:
543-
pass
544-
545-
return run.final_result, run.history
546-
547-
return _run().__await__()
548-
549-
async def __aenter__(self) -> GraphRun[StateT, DepsT, RunEndT]:
550-
if self._run is not None:
551-
raise exceptions.GraphRuntimeError('A GraphRunner can only start a GraphRun once.')
552-
553-
if self._auto_instrument:
554-
self._span = logfire_api.span('run graph {graph.name}', graph=self.graph)
555-
self._span.__enter__()
556-
557-
self._run = run = GraphRun(self.graph, self.first_node, history=self.history, state=self.state, deps=self.deps)
558-
return run
559-
560-
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
561-
if self._span is not None:
562-
self._span.__exit__(exc_type, exc_val, exc_tb)
563-
self._span = None # make it more obvious if you try to use it after exiting
564523

565-
566-
class GraphRun(Generic[StateT, DepsT, RunEndT]):
567-
"""A stateful run of a graph.
568-
569-
Can be used like an async generator to listen to / modify nodes as the run is executed.
570-
"""
571-
572-
def __init__(
573-
self,
574-
graph: Graph[StateT, DepsT, RunEndT],
575-
next_node: BaseNode[StateT, DepsT, RunEndT],
576-
*,
577-
history: list[HistoryStep[StateT, RunEndT]],
578-
state: StateT,
579-
deps: DepsT,
580-
):
581-
self.graph = graph
582-
self.next_node = next_node
583-
self.history = history
584-
self.state = state
585-
self.deps = deps
586-
587-
self._final_result: End[RunEndT] | None = None
524+
self._next_node = first_node
525+
self._started: bool = False
526+
self._result: End[RunEndT] | None = None
527+
self._span: LogfireSpan | None = None
588528

589529
@property
590530
def is_ended(self):
591-
return self._final_result is not None
531+
return self._result is not None
592532

593533
@property
594-
def final_result(self) -> RunEndT:
595-
if self._final_result is None:
534+
def result(self) -> RunEndT:
535+
if self._result is None:
596536
raise exceptions.GraphRuntimeError('GraphRun has not ended yet.')
597-
return self._final_result.data
537+
return self._result.data
598538

599539
async def next(
600540
self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T]
@@ -607,16 +547,48 @@ async def next(
607547
next_node = await self.graph.next(node, history, state=state, deps=deps, infer_name=False)
608548

609549
if isinstance(next_node, End):
610-
self._final_result = next_node
550+
self._result = next_node
611551
else:
612-
self.next_node = next_node
552+
self._next_node = next_node
613553
return next_node
614554

555+
def __await__(self) -> Generator[Any, Any, typing_extensions.Self]:
556+
"""Run the graph until it ends, and return the final result."""
557+
558+
async def _run() -> typing_extensions.Self:
559+
with self:
560+
async for _next_node in self:
561+
pass
562+
563+
return self
564+
565+
return _run().__await__()
566+
567+
def __enter__(self) -> typing_extensions.Self:
568+
if self._started:
569+
raise exceptions.GraphRuntimeError('A GraphRun can only be started once.')
570+
571+
if self._auto_instrument:
572+
self._span = logfire_api.span('run graph {graph.name}', graph=self.graph)
573+
self._span.__enter__()
574+
575+
self._started = True
576+
return self
577+
578+
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
579+
if self._span is not None:
580+
self._span.__exit__(exc_type, exc_val, exc_tb)
581+
self._span = None # make it more obvious if you try to use it after exiting
582+
615583
def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]]:
616584
return self
617585

618586
async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
619587
"""Use the last returned node as the input to `Graph.next`."""
620-
if self._final_result:
588+
if self._result:
621589
raise StopAsyncIteration
622-
return await self.next(self.next_node)
590+
if not self._started:
591+
raise exceptions.GraphRuntimeError(
592+
'You must enter the GraphRun as a contextmanager before you can iterate over it.'
593+
)
594+
return await self.next(self._next_node)

tests/graph/test_graph.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,11 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: #
5757
assert my_graph.name is None
5858
assert my_graph._get_state_type() is type(None)
5959
assert my_graph._get_run_end_type() is int
60-
result, history = await my_graph.run(Float2String(3.14))
60+
graph_run = await my_graph.run(Float2String(3.14))
6161
# len('3.14') * 2 == 8
62-
assert result == 8
62+
assert graph_run.result == 8
6363
assert my_graph.name == 'my_graph'
64-
assert history == snapshot(
64+
assert graph_run.history == snapshot(
6565
[
6666
NodeStep(
6767
state=None,
@@ -84,10 +84,10 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: #
8484
EndStep(result=End(data=8), ts=IsNow(tz=timezone.utc)),
8585
]
8686
)
87-
result, history = await my_graph.run(Float2String(3.14159))
87+
graph_run = await my_graph.run(Float2String(3.14159))
8888
# len('3.14159') == 7, 21 * 2 == 42
89-
assert result == 42
90-
assert history == snapshot(
89+
assert graph_run.result == 42
90+
assert graph_run.history == snapshot(
9191
[
9292
NodeStep(
9393
state=None,
@@ -122,7 +122,7 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: #
122122
EndStep(result=End(data=42), ts=IsNow(tz=timezone.utc)),
123123
]
124124
)
125-
assert [e.data_snapshot() for e in history] == snapshot(
125+
assert [e.data_snapshot() for e in graph_run.history] == snapshot(
126126
[
127127
Float2String(input_data=3.14159),
128128
String2Length(input_data='3.14159'),
@@ -320,10 +320,10 @@ async def run(self, ctx: GraphRunContext[None, Deps]) -> End[int]:
320320
return End(123)
321321

322322
g = Graph(nodes=(Foo, Bar))
323-
result, history = await g.run(Foo(), deps=Deps(1, 2))
323+
graph_run = await g.run(Foo(), deps=Deps(1, 2))
324324

325-
assert result == 123
326-
assert history == snapshot(
325+
assert graph_run.result == 123
326+
assert graph_run.history == snapshot(
327327
[
328328
NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()),
329329
NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()),

tests/graph/test_history.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,16 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[int]:
4646
],
4747
)
4848
async def test_dump_load_history(graph: Graph[MyState, None, int]):
49-
result, history = await graph.run(Foo(), state=MyState(1, ''))
50-
assert result == snapshot(4)
51-
assert history == snapshot(
49+
graph_run = await graph.run(Foo(), state=MyState(1, ''))
50+
assert graph_run.result == snapshot(4)
51+
assert graph_run.history == snapshot(
5252
[
5353
NodeStep(state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()),
5454
NodeStep(state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()),
5555
EndStep(result=End(4), ts=IsNow(tz=timezone.utc)),
5656
]
5757
)
58-
history_json = graph.dump_history(history)
58+
history_json = graph.dump_history(graph_run.history)
5959
assert json.loads(history_json) == snapshot(
6060
[
6161
{
@@ -76,7 +76,7 @@ async def test_dump_load_history(graph: Graph[MyState, None, int]):
7676
]
7777
)
7878
history_loaded = graph.load_history(history_json)
79-
assert history == history_loaded
79+
assert graph_run.history == history_loaded
8080

8181
custom_history = [
8282
{

tests/graph/test_mermaid.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ async def run(self, ctx: GraphRunContext) -> Annotated[End[None], Edge(label='eg
5858

5959

6060
async def test_run_graph():
61-
result, history = await graph1.run(Foo())
62-
assert result is None
63-
assert history == snapshot(
61+
graph_run = await graph1.run(Foo())
62+
assert graph_run.result is None
63+
assert graph_run.history == snapshot(
6464
[
6565
NodeStep(
6666
state=None,

tests/graph/test_state.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,9 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[str]:
3636
assert graph._get_state_type() is MyState
3737
assert graph._get_run_end_type() is str
3838
state = MyState(1, '')
39-
result, history = await graph.run(Foo(), state=state)
40-
assert result == snapshot('x=2 y=y')
41-
assert history == snapshot(
39+
graph_run = await graph.run(Foo(), state=state)
40+
assert graph_run.result == snapshot('x=2 y=y')
41+
assert graph_run.history == snapshot(
4242
[
4343
NodeStep(
4444
state=MyState(x=2, y=''),

tests/typed_graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,6 @@ def run_g5() -> None:
109109
g5.run_sync(A()) # pyright: ignore[reportArgumentType]
110110
g5.run_sync(A(), state=MyState(x=1)) # pyright: ignore[reportArgumentType]
111111
g5.run_sync(A(), deps=MyDeps(y='y')) # pyright: ignore[reportArgumentType]
112-
answer, history = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y'))
113-
assert_type(answer, int)
114-
assert_type(history, list[HistoryStep[MyState, int]])
112+
graph_run = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y'))
113+
assert_type(graph_run.result, int)
114+
assert_type(graph_run.history, list[HistoryStep[MyState, int]])

0 commit comments

Comments
 (0)