3030 logfire ._internal .stack_info .NON_USER_CODE_PREFIXES += (str (Path (__file__ ).parent .absolute ()),)
3131
3232
33- __all__ = ('Graph' , 'GraphRun' , 'GraphRunner' )
33+ __all__ = ('Graph' , 'GraphRun' )
3434
3535_logfire = logfire_api .Logfire (otel_scope = 'pydantic-graph' )
3636
@@ -133,7 +133,7 @@ def run(
133133 state : StateT = None ,
134134 deps : DepsT = None ,
135135 infer_name : bool = True ,
136- ) -> GraphRunner [StateT , DepsT , T ]:
136+ ) -> GraphRun [StateT , DepsT , T ]:
137137 """Run the graph from a starting node until it ends.
138138
139139 Args:
@@ -170,7 +170,7 @@ async def main():
170170 if infer_name and self .name is None :
171171 self ._infer_name (inspect .currentframe ())
172172
173- return GraphRunner [StateT , DepsT , T ](
173+ return GraphRun [StateT , DepsT , T ](
174174 self , start_node , history = [], state = state , deps = deps , auto_instrument = self ._auto_instrument
175175 )
176176
@@ -181,7 +181,7 @@ def run_sync(
181181 state : StateT = None ,
182182 deps : DepsT = None ,
183183 infer_name : bool = True ,
184- ) -> tuple [ T , list [ HistoryStep [ StateT , T ]] ]:
184+ ) -> GraphRun [ StateT , DepsT , T ]:
185185 """Run the graph synchronously.
186186
187187 This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`.
@@ -499,11 +499,10 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None:
499499 return
500500
501501
502- class GraphRunner (Generic [StateT , DepsT , RunEndT ]):
503- """An object that can be awaited to perform a graph run .
502+ class GraphRun (Generic [StateT , DepsT , RunEndT ]):
503+ """A stateful run of a graph.
504504
505- This object can also be used as a contextmanager to get a handle to a specific graph run,
506- allowing you to iterate over nodes, and possibly perform modifications to the nodes as they are run.
505+ After being entered, can be used like an async generator to listen to / modify nodes as the run is executed.
507506 """
508507
509508 def __init__ (
@@ -517,84 +516,25 @@ def __init__(
517516 auto_instrument : bool ,
518517 ):
519518 self .graph = graph
520- self .first_node = first_node
521519 self .history = history
522520 self .state = state
523521 self .deps = deps
524-
525- self ._run : GraphRun [StateT , DepsT , RunEndT ] | None = None
526-
527522 self ._auto_instrument = auto_instrument
528- self ._span : LogfireSpan | None = None
529-
530- @property
531- def run (self ) -> GraphRun [StateT , DepsT , RunEndT ]:
532- if self ._run is None :
533- raise exceptions .GraphRuntimeError ('GraphRunner has not been awaited yet.' )
534- return self ._run
535-
536- def __await__ (self ) -> Generator [Any , Any , tuple [RunEndT , list [HistoryStep [StateT , RunEndT ]]]]:
537- """Run the graph until it ends, and return the final result."""
538-
539- async def _run () -> tuple [RunEndT , list [HistoryStep [StateT , RunEndT ]]]:
540- async with self as run :
541- self ._run = run
542- async for _next_node in run :
543- pass
544-
545- return run .final_result , run .history
546-
547- return _run ().__await__ ()
548-
549- async def __aenter__ (self ) -> GraphRun [StateT , DepsT , RunEndT ]:
550- if self ._run is not None :
551- raise exceptions .GraphRuntimeError ('A GraphRunner can only start a GraphRun once.' )
552-
553- if self ._auto_instrument :
554- self ._span = logfire_api .span ('run graph {graph.name}' , graph = self .graph )
555- self ._span .__enter__ ()
556-
557- self ._run = run = GraphRun (self .graph , self .first_node , history = self .history , state = self .state , deps = self .deps )
558- return run
559-
560- async def __aexit__ (self , exc_type : Any , exc_val : Any , exc_tb : Any ) -> None :
561- if self ._span is not None :
562- self ._span .__exit__ (exc_type , exc_val , exc_tb )
563- self ._span = None # make it more obvious if you try to use it after exiting
564523
565-
566- class GraphRun (Generic [StateT , DepsT , RunEndT ]):
567- """A stateful run of a graph.
568-
569- Can be used like an async generator to listen to / modify nodes as the run is executed.
570- """
571-
572- def __init__ (
573- self ,
574- graph : Graph [StateT , DepsT , RunEndT ],
575- next_node : BaseNode [StateT , DepsT , RunEndT ],
576- * ,
577- history : list [HistoryStep [StateT , RunEndT ]],
578- state : StateT ,
579- deps : DepsT ,
580- ):
581- self .graph = graph
582- self .next_node = next_node
583- self .history = history
584- self .state = state
585- self .deps = deps
586-
587- self ._final_result : End [RunEndT ] | None = None
524+ self ._next_node = first_node
525+ self ._started : bool = False
526+ self ._result : End [RunEndT ] | None = None
527+ self ._span : LogfireSpan | None = None
588528
589529 @property
590530 def is_ended (self ):
591- return self ._final_result is not None
531+ return self ._result is not None
592532
593533 @property
594- def final_result (self ) -> RunEndT :
595- if self ._final_result is None :
534+ def result (self ) -> RunEndT :
535+ if self ._result is None :
596536 raise exceptions .GraphRuntimeError ('GraphRun has not ended yet.' )
597- return self ._final_result .data
537+ return self ._result .data
598538
599539 async def next (
600540 self : GraphRun [StateT , DepsT , T ], node : BaseNode [StateT , DepsT , T ]
@@ -607,16 +547,48 @@ async def next(
607547 next_node = await self .graph .next (node , history , state = state , deps = deps , infer_name = False )
608548
609549 if isinstance (next_node , End ):
610- self ._final_result = next_node
550+ self ._result = next_node
611551 else :
612- self .next_node = next_node
552+ self ._next_node = next_node
613553 return next_node
614554
555+ def __await__ (self ) -> Generator [Any , Any , typing_extensions .Self ]:
556+ """Run the graph until it ends, and return the final result."""
557+
558+ async def _run () -> typing_extensions .Self :
559+ with self :
560+ async for _next_node in self :
561+ pass
562+
563+ return self
564+
565+ return _run ().__await__ ()
566+
567+ def __enter__ (self ) -> typing_extensions .Self :
568+ if self ._started :
569+ raise exceptions .GraphRuntimeError ('A GraphRun can only be started once.' )
570+
571+ if self ._auto_instrument :
572+ self ._span = logfire_api .span ('run graph {graph.name}' , graph = self .graph )
573+ self ._span .__enter__ ()
574+
575+ self ._started = True
576+ return self
577+
578+ def __exit__ (self , exc_type : Any , exc_val : Any , exc_tb : Any ) -> None :
579+ if self ._span is not None :
580+ self ._span .__exit__ (exc_type , exc_val , exc_tb )
581+ self ._span = None # make it more obvious if you try to use it after exiting
582+
615583 def __aiter__ (self ) -> AsyncIterator [BaseNode [StateT , DepsT , RunEndT ] | End [RunEndT ]]:
616584 return self
617585
618586 async def __anext__ (self ) -> BaseNode [StateT , DepsT , RunEndT ] | End [RunEndT ]:
619587 """Use the last returned node as the input to `Graph.next`."""
620- if self ._final_result :
588+ if self ._result :
621589 raise StopAsyncIteration
622- return await self .next (self .next_node )
590+ if not self ._started :
591+ raise exceptions .GraphRuntimeError (
592+ 'You must enter the GraphRun as a contextmanager before you can iterate over it.'
593+ )
594+ return await self .next (self ._next_node )
0 commit comments