From 04fc74cdd514344c7dfb9a842e993d729c18f923 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Fri, 31 Jan 2025 18:08:36 -0700 Subject: [PATCH] Add GraphRun object --- pydantic_graph/pydantic_graph/graph.py | 76 ++++++++++++++++++++++---- pyproject.toml | 2 +- 2 files changed, 66 insertions(+), 12 deletions(-) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index a670c3d3..1206fe52 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -3,7 +3,7 @@ import asyncio import inspect import types -from collections.abc import Sequence +from collections.abc import AsyncGenerator, Sequence from contextlib import ExitStack from dataclasses import dataclass, field from functools import cached_property @@ -170,7 +170,7 @@ async def main(): if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - history: list[HistoryStep[StateT, T]] = [] + graph_run = GraphRun[StateT, DepsT, T](self, state=state, deps=deps) with ExitStack() as stack: run_span: logfire_api.LogfireSpan | None = None if self._auto_instrument: @@ -184,19 +184,12 @@ async def main(): next_node = start_node while True: - next_node = await self.next(next_node, history, state=state, deps=deps, infer_name=False) + next_node = await graph_run.next(next_node) if isinstance(next_node, End): - history.append(EndStep(result=next_node)) + history = graph_run.history if run_span is not None: run_span.set_attribute('history', history) return next_node.data, history - elif not isinstance(next_node, BaseNode): - if TYPE_CHECKING: - typing_extensions.assert_never(next_node) - else: - raise exceptions.GraphRuntimeError( - f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' - ) def run_sync( self: Graph[StateT, DepsT, T], @@ -510,3 +503,64 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None: if item is self: self.name = name return + + +class GraphRun(Generic[StateT, DepsT, RunEndT]): + def __init__( + self, + graph: Graph[StateT, DepsT, RunEndT], + *, + state: StateT = None, + deps: DepsT = None, + ): + self.graph = graph + self.state = state + self.deps = deps + + self.history: list[HistoryStep[StateT, RunEndT]] = [] + self.final_result: End[RunEndT] | None = None + + self._agen: ( + AsyncGenerator[BaseNode[StateT, DepsT, RunEndT] | End[RunEndT], BaseNode[StateT, DepsT, RunEndT]] | None + ) = None + + async def next( + self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T] + ) -> BaseNode[StateT, DepsT, Any] | End[T]: + agen = await self._get_primed_agen() + return await agen.asend(node) + + async def _get_primed_agen( + self: GraphRun[StateT, DepsT, T], + ) -> AsyncGenerator[BaseNode[StateT, DepsT, T] | End[T], BaseNode[StateT, DepsT, T]]: + graph = self.graph + state = self.state + deps = self.deps + history = self.history + + if self._agen is None: + + async def _agen() -> AsyncGenerator[BaseNode[StateT, DepsT, T] | End[T], BaseNode[StateT, DepsT, T]]: + next_node = yield # pyright: ignore[reportReturnType] # we prime the generator immediately below + while True: + next_node = await graph.next(next_node, history, state=state, deps=deps, infer_name=False) + if isinstance(next_node, End): + history.append(EndStep(result=next_node)) + self.final_result = next_node + yield next_node + return + elif isinstance(next_node, BaseNode): + next_node = yield next_node # Give user a chance to modify the next node + else: + if TYPE_CHECKING: + typing_extensions.assert_never(next_node) + else: + raise exceptions.GraphRuntimeError( + f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.' + ) + + agen = _agen() + await agen.__anext__() # prime the generator + + self._agen = agen + return self._agen diff --git a/pyproject.toml b/pyproject.toml index e0acff80..e2c7d029 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -193,4 +193,4 @@ skip = '.git*,*.svg,*.lock,*.css' check-hidden = true # Ignore "formatting" like **L**anguage ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b' -# ignore-words-list = '' + ignore-words-list = 'asend'