Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from threading import Event
from typing import Optional, Protocol
from typing import Any, Optional, Protocol

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.invocation_services import InvocationServices
Expand Down Expand Up @@ -36,10 +36,13 @@ def run(self, queue_item: SessionQueueItem) -> None:
pass

@abstractmethod
def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem) -> None:
def run_node(
self, transient_storage: dict[str, Any], invocation: BaseInvocation, queue_item: SessionQueueItem
) -> None:
"""Run a single node in the graph.

Args:
transient_storage: Transient storage passed to each node that executes.
invocation: The invocation to run.
queue_item: The session queue item.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from contextlib import suppress
from threading import BoundedSemaphore, Thread
from threading import Event as ThreadEvent
from typing import Optional
from typing import Any, Optional

from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput
from invokeai.app.services.events.events_common import (
Expand Down Expand Up @@ -74,6 +74,8 @@ def run(self, queue_item: SessionQueueItem):

self._on_before_run_session(queue_item=queue_item)

transient_storage = {}

# Loop over invocations until the session is complete or canceled
while True:
try:
Expand All @@ -95,7 +97,7 @@ def run(self, queue_item: SessionQueueItem):
if invocation is None or self._is_canceled():
break

self.run_node(invocation, queue_item)
self.run_node(transient_storage, invocation, queue_item)

# The session is complete if all invocations have been run or there is an error on the session.
# At this time, the queue item may be canceled, but the object itself here won't be updated yet. We must
Expand All @@ -109,7 +111,7 @@ def run(self, queue_item: SessionQueueItem):

self._on_after_run_session(queue_item=queue_item)

def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
def run_node(self, transient_storage: dict[str, Any], invocation: BaseInvocation, queue_item: SessionQueueItem):
try:
# Any unhandled exception in this scope is an invocation error & will fail the graph
with self._services.performance_statistics.collect_stats(invocation, queue_item.session_id):
Expand All @@ -123,6 +125,7 @@ def run_node(self, invocation: BaseInvocation, queue_item: SessionQueueItem):
context = build_invocation_context(
data=data,
services=self._services,
transient_storage=transient_storage,
is_canceled=self._is_canceled,
)

Expand Down
9 changes: 8 additions & 1 deletion invokeai/app/services/shared/invocation_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@
Wrapping these services provides a simpler and safer interface for nodes to use.

When a node executes, a fresh `InvocationContext` is built for it, ensuring nodes cannot interfere
with each other.
with each other — with the exception of the shared per-graph transient_storage used for
intra-workflow node communication.

Many of the wrappers have the same signature as the methods they wrap. This allows us to write
user-facing docstrings and not need to go and update the internal services to match.
Expand Down Expand Up @@ -732,6 +733,7 @@ def __init__(
config: ConfigInterface,
util: UtilInterface,
boards: BoardsInterface,
transient_storage: dict,
data: InvocationContextData,
services: InvocationServices,
) -> None:
Expand All @@ -751,6 +753,8 @@ def __init__(
"""Utility methods, including a method to check if an invocation was canceled and step callbacks."""
self.boards = boards
"""Methods to interact with boards."""
self.transient_storage = transient_storage
"""Transient storage for all nodes in this execution."""
self._data = data
"""An internal API providing access to data about the current queue item and invocation. You probably shouldn't use this. It may change without warning."""
self._services = services
Expand All @@ -760,13 +764,15 @@ def __init__(
def build_invocation_context(
services: InvocationServices,
data: InvocationContextData,
transient_storage: dict,
is_canceled: Callable[[], bool],
) -> InvocationContext:
"""Builds the invocation context for a specific invocation execution.

Args:
services: The invocation services to wrap.
data: The invocation context data.
transient_storage: Transient storage passed along to every executed node in this workflow.

Returns:
The invocation context.
Expand All @@ -792,6 +798,7 @@ def build_invocation_context(
conditioning=conditioning,
services=services,
boards=boards,
transient_storage=transient_storage,
)

return ctx
1 change: 1 addition & 0 deletions tests/app/services/model_load/test_load_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def mock_context(
services=mock_services,
data=InvocationContextData(queue_item=None, invocation=None, source_invocation_id=None), # type: ignore
is_canceled=None, # type: ignore
transient_storage={},
)


Expand Down
157 changes: 155 additions & 2 deletions tests/test_graph_execution_state.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
from typing import Optional
from unittest.mock import Mock

Expand Down Expand Up @@ -37,7 +38,7 @@ def invoke_next(g: GraphExecutionState) -> tuple[Optional[BaseInvocation], Optio
return (None, None)

print(f"invoking {n.id}: {type(n)}")
o = n.invoke(Mock(InvocationContext))
o = n.invoke(Mock(spec=InvocationContext))
g.complete(n.id, o)

return (n, o)
Expand Down Expand Up @@ -193,7 +194,7 @@ def assert_topo_order_and_all_executed(state: GraphExecutionState, order: list[s
n = g.next()
if n is None:
break
o = n.invoke(Mock(InvocationContext))
o = n.invoke(Mock(spec=InvocationContext))
g.complete(n.id, o)
order.append(n.id)

Expand Down Expand Up @@ -314,3 +315,155 @@ class Child(Base):
pass

assert are_connection_types_compatible(Child, Base) is True


def _make_invocation_context(transient_storage: Optional[dict] = None) -> InvocationContext:
    """
    Best-effort constructor that adapts to InvocationContext's signature by supplying Mock()
    for any required parameters.

    If transient_storage is provided, it is reused (per-workflow lifetime); otherwise a fresh
    empty dict is supplied, mirroring a brand-new workflow execution.
    """
    sig = inspect.signature(InvocationContext)
    kwargs: dict = {}
    for name, param in sig.parameters.items():
        # Bound signatures normally omit `self`, but guard anyway.
        if name == "self":
            continue

        if name == "transient_storage":
            kwargs[name] = {} if transient_storage is None else transient_storage
            continue

        # Parameters with defaults need no value; use the public sentinel
        # `inspect.Parameter.empty` rather than the private `inspect._empty`.
        if param.default is not inspect.Parameter.empty:
            continue

        kwargs[name] = Mock(name=f"InvocationContext.{name}")

    ctx = InvocationContext(**kwargs)
    assert isinstance(ctx.transient_storage, dict)
    return ctx


def _ts_get(storage, key: str):
try:
return storage[key]
except Exception:
# KeyError for dict-like; some storages may raise different exceptions
return None


def _ts_contains(storage, key: str) -> bool:
try:
return key in storage
except Exception:
try:
storage[key]
return True
except Exception:
return False


def test_invocation_context_transient_storage_is_per_instance_and_starts_empty():
    """Each InvocationContext must get its own, initially-empty transient_storage dict."""
    first = _make_invocation_context()
    second = _make_invocation_context()

    for ctx in (first, second):
        assert hasattr(ctx, "transient_storage")
        assert ctx.transient_storage == {}
        # A fresh execution context must start with no prior keys.
        assert not _ts_contains(ctx.transient_storage, "__sentinel__")

    # Guard against a class-level mutable default / singleton being shared.
    assert first.transient_storage is not second.transient_storage

    # A write into one context must be invisible to the other.
    first.transient_storage["__sentinel__"] = "x"
    assert not _ts_contains(second.transient_storage, "__sentinel__")


def test_transient_storage_persists_within_one_graph_execution_and_resets_for_next():
    """
    Verify the lifetime of transient_storage relative to workflow executions:

    a) Persists across multiple node invocations during a single workflow execution.
    b) Does not leak into a second workflow execution (fresh InvocationContext).
    """

    # Local imports to avoid perturbing module import ordering.
    from invokeai.app.invocations.baseinvocation import invocation, invocation_output
    from invokeai.app.invocations.fields import InputField, OutputField

    # Minimal output type shared by both test invocations below.
    @invocation_output("transient_storage_test_output")
    class _TransientStorageTestOutput(BaseInvocationOutput):
        value: Optional[str] = OutputField(default=None)

    # Writes key -> value into transient_storage and echoes the value as output.
    @invocation(
        "transient_storage_write_test",
        title="Transient Storage Write Test",
        tags=["test"],
        version="1.0.0",
        use_cache=False,
    )
    class _TransientStorageWriteTestInvocation(BaseInvocation):
        key: str = InputField(default="k")
        value: str = InputField(default="v")

        def invoke(self, context: InvocationContext) -> _TransientStorageTestOutput:
            context.transient_storage[self.key] = self.value
            return _TransientStorageTestOutput(value=self.value)

    # Reads key from transient_storage; `trigger` exists only to force an edge
    # so the reader is scheduled after the writer.
    @invocation(
        "transient_storage_read_test",
        title="Transient Storage Read Test",
        tags=["test"],
        version="1.0.0",
        use_cache=False,
    )
    class _TransientStorageReadTestInvocation(BaseInvocation):
        key: str = InputField(default="k")
        trigger: str = InputField(default="")  # only to enforce dependency

        def invoke(self, context: InvocationContext) -> _TransientStorageTestOutput:
            return _TransientStorageTestOutput(value=_ts_get(context.transient_storage, self.key))

    # Drives a graph to completion, building a fresh context per node but sharing
    # one transient_storage dict — the same pattern the session runner uses.
    def run_graph(graph: Graph, transient_storage: dict) -> list[tuple[BaseInvocation, BaseInvocationOutput]]:
        state = GraphExecutionState(graph=graph)
        out: list[tuple[BaseInvocation, BaseInvocationOutput]] = []
        while True:
            n = state.next()
            if n is None:
                break
            # Mirror production: new InvocationContext per node, shared transient_storage per workflow
            ctx = _make_invocation_context(transient_storage=transient_storage)
            o = n.invoke(ctx)
            state.complete(n.id, o)
            out.append((n, o))
        assert state.is_complete()
        return out

    # Execution 1: write then read in the same workflow should succeed.
    g1 = Graph()
    g1.add_node(_TransientStorageWriteTestInvocation(id="write", key="k", value="hello"))
    g1.add_node(_TransientStorageReadTestInvocation(id="read", key="k"))
    g1.add_edge(create_edge("write", "value", "read", "trigger"))

    ts1: dict = {}
    trace1 = run_graph(g1, ts1)

    # Exactly one reader ran, and it saw the writer's value via transient_storage.
    read_outputs_1 = [o for (n, o) in trace1 if isinstance(n, _TransientStorageReadTestInvocation)]
    assert len(read_outputs_1) == 1
    assert read_outputs_1[0].value == "hello"

    # Execution 2: read-only graph with a fresh InvocationContext must not see prior state.
    g2 = Graph()
    g2.add_node(_TransientStorageReadTestInvocation(id="read_only", key="k"))

    ts2: dict = {}
    trace2 = run_graph(g2, ts2)

    # The fresh storage has no "k", so the reader reports None — no cross-run leak.
    read_outputs_2 = [o for (n, o) in trace2 if isinstance(n, _TransientStorageReadTestInvocation)]
    assert len(read_outputs_2) == 1
    assert read_outputs_2[0].value is None