Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GraphRun object to make use of next more ergonomic #833

Draft
wants to merge 1 commit into
base: dmontagu/stream-tool-calls
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 65 additions & 11 deletions pydantic_graph/pydantic_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import inspect
import types
from collections.abc import Sequence
from collections.abc import AsyncGenerator, Sequence
from contextlib import ExitStack
from dataclasses import dataclass, field
from functools import cached_property
Expand Down Expand Up @@ -170,7 +170,7 @@ async def main():
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())

history: list[HistoryStep[StateT, T]] = []
graph_run = GraphRun[StateT, DepsT, T](self, state=state, deps=deps)
with ExitStack() as stack:
run_span: logfire_api.LogfireSpan | None = None
if self._auto_instrument:
Expand All @@ -184,19 +184,12 @@ async def main():

next_node = start_node
while True:
next_node = await self.next(next_node, history, state=state, deps=deps, infer_name=False)
next_node = await graph_run.next(next_node)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The simplicity of this loop is the primary reason I want to have a graph_run object — it just feels like there's so much less cruft.

if isinstance(next_node, End):
history.append(EndStep(result=next_node))
history = graph_run.history
if run_span is not None:
run_span.set_attribute('history', history)
return next_node.data, history
elif not isinstance(next_node, BaseNode):
if TYPE_CHECKING:
typing_extensions.assert_never(next_node)
else:
raise exceptions.GraphRuntimeError(
f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.'
)

def run_sync(
self: Graph[StateT, DepsT, T],
Expand Down Expand Up @@ -510,3 +503,64 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None:
if item is self:
self.name = name
return


class GraphRun(Generic[StateT, DepsT, RunEndT]):
def __init__(
self,
graph: Graph[StateT, DepsT, RunEndT],
*,
state: StateT = None,
deps: DepsT = None,
):
self.graph = graph
self.state = state
self.deps = deps

self.history: list[HistoryStep[StateT, RunEndT]] = []
self.final_result: End[RunEndT] | None = None

self._agen: (
AsyncGenerator[BaseNode[StateT, DepsT, RunEndT] | End[RunEndT], BaseNode[StateT, DepsT, RunEndT]] | None
) = None

async def next(
self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T]
) -> BaseNode[StateT, DepsT, Any] | End[T]:
agen = await self._get_primed_agen()
return await agen.asend(node)

async def _get_primed_agen(
self: GraphRun[StateT, DepsT, T],
) -> AsyncGenerator[BaseNode[StateT, DepsT, T] | End[T], BaseNode[StateT, DepsT, T]]:
graph = self.graph
state = self.state
deps = self.deps
history = self.history

if self._agen is None:

async def _agen() -> AsyncGenerator[BaseNode[StateT, DepsT, T] | End[T], BaseNode[StateT, DepsT, T]]:
next_node = yield # pyright: ignore[reportReturnType] # we prime the generator immediately below
while True:
next_node = await graph.next(next_node, history, state=state, deps=deps, infer_name=False)
if isinstance(next_node, End):
history.append(EndStep(result=next_node))
self.final_result = next_node
yield next_node
return
elif isinstance(next_node, BaseNode):
next_node = yield next_node # Give user a chance to modify the next node
else:
if TYPE_CHECKING:
typing_extensions.assert_never(next_node)
else:
raise exceptions.GraphRuntimeError(
f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.'
)

agen = _agen()
await agen.__anext__() # prime the generator

self._agen = agen
return self._agen
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,4 @@ skip = '.git*,*.svg,*.lock,*.css'
check-hidden = true
# Ignore "formatting" like **L**anguage
ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b'
# ignore-words-list = ''
ignore-words-list = 'asend'
Loading