Skip to content

Commit 07e160a

Browse files
committed
Fix graph execution state resume after JSON round-trip
1 parent 33ec16d commit 07e160a

2 files changed

Lines changed: 93 additions & 0 deletions

File tree

invokeai/app/services/shared/graph.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1819,6 +1819,67 @@ def _prepare_until_node_ready(self) -> Optional[BaseInvocation]:
18191819

18201820
return next_node
18211821

1822+
def _reset_runtime_caches(self) -> None:
    """Drop all in-memory execution caches so they can be rebuilt.

    These attributes are derived runtime state (not part of the serialized
    model), so after a JSON round-trip they must be cleared before the
    rehydration helpers repopulate them.
    """
    # Dict-valued caches start over as fresh, empty mappings.
    for mapping_attr in (
        "_ready_queues",
        "_iteration_path_cache",
        "_if_branch_exclusive_sources",
        "_resolved_if_exec_branches",
        "_prepared_exec_metadata",
    ):
        setattr(self, mapping_attr, {})
    # Lazily-constructed helpers/singletons reset to "not built yet".
    for lazy_attr in (
        "_active_class",
        "_prepared_exec_registry",
        "_if_branch_scheduler",
        "_execution_materializer",
        "_execution_scheduler",
        "_execution_runtime",
    ):
        setattr(self, lazy_attr, None)
1834+
1835+
def _rehydrate_prepared_exec_metadata(self) -> None:
    """Rebuild per-execution-node metadata from the serialized mappings.

    For each prepared execution node, restore its link back to its source
    node and derive its lifecycle state from what has already run.
    """
    registry = self._prepared_registry()
    for exec_node_id, source_node_id in self.prepared_source_mapping.items():
        metadata = registry.get_metadata(exec_node_id)
        metadata.source_node_id = source_node_id
        # Derive the state: nodes not yet executed are "ready" once all of
        # their inputs are satisfied, otherwise "pending"; executed nodes
        # either produced a result ("executed") or were bypassed ("skipped").
        if exec_node_id not in self.executed:
            metadata.state = "ready" if self.indegree.get(exec_node_id) == 0 else "pending"
        elif exec_node_id in self.results:
            metadata.state = "executed"
        else:
            metadata.state = "skipped"
1846+
1847+
def _rehydrate_resolved_if_exec_branches(self) -> None:
    """Re-resolve which branch each ``IfInvocation`` took before serialization.

    For every If node whose condition inputs have all executed, re-apply the
    stored condition outputs to the node and record the chosen branch in
    ``self._resolved_if_exec_branches`` so resumed scheduling matches the
    pre-serialization decision.
    """
    # Local import keeps this rehydration helper self-contained.
    from copy import deepcopy

    for exec_node_id, node in self.execution_graph.nodes.items():
        if not isinstance(node, IfInvocation):
            continue

        condition_edges = self.execution_graph._get_input_edges(exec_node_id, "condition")
        # The branch can only be re-resolved once every condition source has run.
        if any(edge.source.node_id not in self.executed for edge in condition_edges):
            continue

        # Re-apply the already-computed condition values to the node.
        # BUG FIX: the original called `copydeep(...)`, which is not a defined
        # name and would raise NameError; `copy.deepcopy` is intended. The copy
        # prevents the node from aliasing objects stored in self.results.
        for edge in condition_edges:
            setattr(
                node,
                edge.destination.field,
                deepcopy(getattr(self.results[edge.source.node_id], edge.source.field)),
            )

        self._resolved_if_exec_branches[exec_node_id] = "true_input" if node.condition else "false_input"
1864+
1865+
def _rehydrate_ready_queues(self) -> None:
    """Repopulate the ready queues for nodes that can run immediately.

    Walks the flattened execution graph in dependency order and enqueues
    every not-yet-executed node whose inputs are all satisfied.
    """
    flat_graph = self.execution_graph.nx_graph_flat()
    runnable = (
        node_id
        for node_id in nx.topological_sort(flat_graph)
        if node_id not in self.executed and self.indegree.get(node_id) == 0
    )
    for node_id in runnable:
        self._enqueue_if_ready(node_id)
1873+
1874+
def _rehydrate_runtime_state(self) -> None:
    """Rebuild all derived runtime state after deserialization.

    Order matters: caches are cleared first, then prepared-node metadata,
    then resolved If-branches, and finally the ready queues that depend on
    all of the above.
    """
    rehydration_steps = (
        self._reset_runtime_caches,
        self._rehydrate_prepared_exec_metadata,
        self._rehydrate_resolved_if_exec_branches,
        self._rehydrate_ready_queues,
    )
    for step in rehydration_steps:
        step()
1879+
1880+
def model_post_init(self, __context: Any) -> None:
    """Pydantic v2 post-init hook: rebuild runtime-only state.

    Runs after model construction/validation, including when a
    GraphExecutionState is restored from JSON, so a resumed session starts
    with consistent in-memory caches and ready queues.
    """
    self._rehydrate_runtime_state()
1882+
18221883
model_config = ConfigDict(
18231884
json_schema_extra={
18241885
"required": [

tests/test_graph_execution_state.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest.mock import Mock
33

44
import pytest
5+
from pydantic import TypeAdapter
56

67
from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
78
from invokeai.app.invocations.collections import RangeInvocation
@@ -137,6 +138,37 @@ def test_graph_state_collects():
137138
assert sorted(g.results[n6[0].id].collection) == sorted(test_prompts)
138139

139140

141+
def test_graph_state_resumes_partially_executed_session_after_json_round_trip():
    """A partially-executed session must finish correctly after being
    serialized to JSON and validated back into a GraphExecutionState."""
    g = Graph()
    g.add_node(RangeInvocation(id="c", start=1, stop=5, step=1))
    g.add_node(IterateInvocation(id="iter"))
    g.add_node(AddInvocation(id="add", b=1))
    g.add_node(CollectInvocation(id="collect"))

    for edge in (
        create_edge("c", "collection", "iter", "collection"),
        create_edge("iter", "item", "add", "a"),
        create_edge("add", "value", "collect", "item"),
    ):
        g.add_edge(edge)

    state = GraphExecutionState(graph=g)

    # Run only the first few invocations, leaving the session unfinished.
    for _ in range(4):
        node, result = invoke_next(state)
        assert node is not None
        assert result is not None

    # Round-trip through JSON exactly as a persisted session would be.
    serialized = state.model_dump_json(warnings=False, exclude_none=True)
    resumed = TypeAdapter(GraphExecutionState).validate_json(serialized, strict=False)

    finished_sources = execute_all_nodes(resumed)

    assert finished_sources
    assert "add" in finished_sources
    assert "collect" in resumed.source_prepared_mapping

    collect_exec_id = next(iter(resumed.source_prepared_mapping["collect"]))
    assert resumed.results[collect_exec_id].collection == [2, 3, 4, 5]
170+
171+
140172
def test_graph_state_prepares_eagerly():
141173
"""Tests that all prepareable nodes are prepared"""
142174
graph = Graph()

0 commit comments

Comments
 (0)