Skip to content

Commit f7aa5fc

Browse files
JPPhotolstein
andauthored
Add chaining to Collect node (#8933)
* Add chained collect node * test(frontend): align parseSchema fixtures with collect v1.1 and normalize undefined fields in assertions * fix(nodes): block collect-to-collect links when inferred item types differ --------- Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
1 parent 438515b commit f7aa5fc

11 files changed

Lines changed: 438 additions & 68 deletions

File tree

invokeai/app/services/shared/graph.py

Lines changed: 150 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,23 @@ def is_any(t: Any) -> bool:
123123
return t == Any or Any in get_args(t)
124124

125125

126+
def extract_collection_item_types(t: Any) -> set[Any]:
127+
"""Extracts list item types from a collection annotation, including unions containing list branches."""
128+
if is_any(t):
129+
return {Any}
130+
131+
if get_origin(t) is list:
132+
return {arg for arg in get_args(t) if arg != NoneType}
133+
134+
item_types: set[Any] = set()
135+
for arg in get_args(t):
136+
if is_any(arg):
137+
item_types.add(Any)
138+
elif get_origin(arg) is list:
139+
item_types.update(item_arg for item_arg in get_args(arg) if item_arg != NoneType)
140+
return item_types
141+
142+
126143
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
127144
if not from_type or not to_type:
128145
return False
@@ -280,7 +297,7 @@ class CollectInvocationOutput(BaseInvocationOutput):
280297
)
281298

282299

283-
@invocation("collect", version="1.0.0")
300+
@invocation("collect", version="1.1.0")
284301
class CollectInvocation(BaseInvocation):
285302
"""Collects values into a collection"""
286303

@@ -292,7 +309,10 @@ class CollectInvocation(BaseInvocation):
292309
input=Input.Connection,
293310
)
294311
collection: list[Any] = InputField(
295-
description="The collection, will be provided on execution", default=[], ui_hidden=True
312+
description="An optional collection to append to",
313+
default=[],
314+
ui_type=UIType._Collection,
315+
input=Input.Connection,
296316
)
297317

298318
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
@@ -520,7 +540,9 @@ def _validate_edge(self, edge: Edge):
520540

521541
# Validate that an edge to this node+field doesn't already exist
522542
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
523-
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
543+
if len(input_edges) > 0 and (
544+
not isinstance(to_node, CollectInvocation) or edge.destination.field != ITEM_FIELD
545+
):
524546
raise InvalidEdgeError(f"Edge already exists ({edge})")
525547

526548
# Validate that no cycles would be created
@@ -546,8 +568,10 @@ def _validate_edge(self, edge: Edge):
546568
raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}")
547569

548570
# Validate if collector input type matches output type (if this edge results in both being set)
549-
if isinstance(to_node, CollectInvocation) and edge.destination.field == ITEM_FIELD:
550-
err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source)
571+
if isinstance(to_node, CollectInvocation) and edge.destination.field in (ITEM_FIELD, COLLECTION_FIELD):
572+
err = self._is_collector_connection_valid(
573+
edge.destination.node_id, new_input=edge.source, new_input_field=edge.destination.field
574+
)
551575
if err is not None:
552576
raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")
553577

@@ -676,76 +700,152 @@ def _is_iterator_connection_valid(
676700

677701
# Collector input type must match all iterator output types
678702
if isinstance(input_node, CollectInvocation):
679-
collector_inputs = self._get_input_edges(input_node.id, ITEM_FIELD)
680-
if len(collector_inputs) == 0:
681-
return "Iterator input collector must have at least one item input edge"
682-
683-
# Traverse the graph to find the first collector input edge. Collectors validate that their collection
684-
# inputs are all of the same type, so we can use the first input edge to determine the collector's type
685-
first_collector_input_edge = collector_inputs[0]
686-
first_collector_input_type = get_output_field_type(
687-
self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
688-
)
689-
resolved_collector_type = (
690-
first_collector_input_type
691-
if get_origin(first_collector_input_type) is None
692-
else get_args(first_collector_input_type)
693-
)
694-
if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)):
703+
input_root_type = self._get_collector_input_root_type(input_node.id)
704+
if input_root_type is None:
705+
return "Iterator input collector must have at least one item or collection input edge"
706+
if not all((are_connection_types_compatible(input_root_type, t) for t in output_field_types)):
695707
return "Iterator collection type must match all iterator output types"
696708

697709
return None
698710

711+
def _resolve_collector_input_types(self, node_id: str, visited: Optional[set[str]] = None) -> set[Any]:
712+
"""Resolves possible item types for a collector's inputs, recursively following chained collectors."""
713+
visited = visited or set()
714+
if node_id in visited:
715+
return set()
716+
visited.add(node_id)
717+
718+
input_types: set[Any] = set()
719+
720+
for edge in self._get_input_edges(node_id, ITEM_FIELD):
721+
input_field_type = get_output_field_type(self.get_node(edge.source.node_id), edge.source.field)
722+
resolved_types = [input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type)
723+
input_types.update(t for t in resolved_types if t != NoneType)
724+
725+
for edge in self._get_input_edges(node_id, COLLECTION_FIELD):
726+
source_node = self.get_node(edge.source.node_id)
727+
if isinstance(source_node, CollectInvocation) and edge.source.field == COLLECTION_FIELD:
728+
input_types.update(self._resolve_collector_input_types(source_node.id, visited.copy()))
729+
continue
730+
731+
input_field_type = get_output_field_type(source_node, edge.source.field)
732+
input_types.update(extract_collection_item_types(input_field_type))
733+
734+
return input_types
735+
736+
def _get_collector_input_root_type(self, node_id: str) -> Any | None:
737+
input_types = self._resolve_collector_input_types(node_id)
738+
non_any_input_types = {t for t in input_types if t != Any}
739+
if len(non_any_input_types) == 0 and Any in input_types:
740+
return Any
741+
if len(non_any_input_types) == 0:
742+
return None
743+
744+
type_tree = nx.DiGraph()
745+
type_tree.add_nodes_from(non_any_input_types)
746+
type_tree.add_edges_from([e for e in itertools.permutations(non_any_input_types, 2) if issubclass(e[1], e[0])])
747+
type_degrees = type_tree.in_degree(type_tree.nodes)
748+
root_types = [t[0] for t in type_degrees if t[1] == 0] # type: ignore
749+
if len(root_types) != 1:
750+
return Any
751+
return root_types[0]
752+
699753
def _is_collector_connection_valid(
700754
self,
701755
node_id: str,
702756
new_input: Optional[EdgeConnection] = None,
757+
new_input_field: Optional[str] = None,
703758
new_output: Optional[EdgeConnection] = None,
704759
) -> str | None:
705-
inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)]
760+
item_inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)]
761+
collection_inputs = [e.source for e in self._get_input_edges(node_id, COLLECTION_FIELD)]
706762
outputs = [e.destination for e in self._get_output_edges(node_id, COLLECTION_FIELD)]
707763

708764
if new_input is not None:
709-
inputs.append(new_input)
765+
field = new_input_field or ITEM_FIELD
766+
if field == ITEM_FIELD:
767+
item_inputs.append(new_input)
768+
elif field == COLLECTION_FIELD:
769+
collection_inputs.append(new_input)
710770
if new_output is not None:
711771
outputs.append(new_output)
712772

713-
# Get input and output fields (the fields linked to the iterator's input/output)
714-
input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in inputs]
773+
if len(item_inputs) == 0 and len(collection_inputs) == 0:
774+
return "Collector must have at least one item or collection input edge"
775+
776+
# Get input and output fields (the fields linked to the collector's input/output)
777+
item_input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in item_inputs]
778+
collection_input_field_types = [
779+
get_output_field_type(self.get_node(e.node_id), e.field) for e in collection_inputs
780+
]
715781
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]
716782

783+
if not all((is_list_or_contains_list(t) or is_any(t) for t in collection_input_field_types)):
784+
return "Collector collection input must be a collection"
785+
717786
# Validate that all inputs are derived from or match a single type
718787
input_field_types = {
719788
resolved_type
720-
for input_field_type in input_field_types
789+
for input_field_type in item_input_field_types
721790
for resolved_type in (
722791
[input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type)
723792
)
724793
if resolved_type != NoneType
725794
} # Get unique types
795+
796+
for input_conn, input_field_type in zip(collection_inputs, collection_input_field_types, strict=False):
797+
source_node = self.get_node(input_conn.node_id)
798+
if isinstance(source_node, CollectInvocation) and input_conn.field == COLLECTION_FIELD:
799+
input_field_types.update(self._resolve_collector_input_types(source_node.id))
800+
continue
801+
input_field_types.update(extract_collection_item_types(input_field_type))
802+
803+
non_any_input_field_types = {t for t in input_field_types if t != Any}
726804
type_tree = nx.DiGraph()
727-
type_tree.add_nodes_from(input_field_types)
728-
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
805+
type_tree.add_nodes_from(non_any_input_field_types)
806+
type_tree.add_edges_from(
807+
[e for e in itertools.permutations(non_any_input_field_types, 2) if issubclass(e[1], e[0])]
808+
)
729809
type_degrees = type_tree.in_degree(type_tree.nodes)
730-
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
810+
root_types = [t[0] for t in type_degrees if t[1] == 0] # type: ignore
811+
if len(root_types) > 1:
731812
return "Collector input collection items must be of a single type"
732813

733-
# Get the input root type
734-
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
814+
# Get the input root type (if known)
815+
input_root_type = root_types[0] if len(root_types) == 1 else None
735816

736817
# Verify that all outputs are lists
737818
if not all(is_list_or_contains_list(t) or is_any(t) for t in output_field_types):
738819
return "Collector output must connect to a collection input"
739820

740821
# Verify that all outputs match the input type (are a base class or the same class)
741-
if not all(
742-
is_any(t)
743-
or is_union_subtype(input_root_type, get_args(t)[0])
744-
or issubclass(input_root_type, get_args(t)[0])
745-
for t in output_field_types
746-
):
822+
if input_root_type is not None:
823+
if not all(
824+
is_any(t)
825+
or is_union_subtype(input_root_type, get_args(t)[0])
826+
or issubclass(input_root_type, get_args(t)[0])
827+
for t in output_field_types
828+
):
829+
return "Collector outputs must connect to a collection input with a matching type"
830+
elif any(not is_any(t) and get_args(t)[0] != Any for t in output_field_types):
747831
return "Collector outputs must connect to a collection input with a matching type"
748832

833+
# If this collector outputs to another collector's collection input, validate against the downstream
834+
# collector's resolved input type (if available).
835+
for output in outputs:
836+
output_node = self.get_node(output.node_id)
837+
if not isinstance(output_node, CollectInvocation) or output.field != COLLECTION_FIELD:
838+
continue
839+
output_root_type = self._get_collector_input_root_type(output_node.id)
840+
if output_root_type is None:
841+
continue
842+
if input_root_type is None:
843+
if output_root_type != Any:
844+
return "Collector outputs must connect to a collection input with a matching type"
845+
continue
846+
if not are_connection_types_compatible(input_root_type, output_root_type):
847+
return "Collector outputs must connect to a collection input with a matching type"
848+
749849
return None
750850

751851
def nx_graph(self) -> nx.DiGraph:
@@ -1211,8 +1311,19 @@ def _prepare_inputs(self, node: BaseInvocation):
12111311
if isinstance(node, CollectInvocation):
12121312
item_edges = [e for e in input_edges if e.destination.field == ITEM_FIELD]
12131313
item_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id))
1214-
1215-
output_collection = [copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges]
1314+
collection_edges = [e for e in input_edges if e.destination.field == COLLECTION_FIELD]
1315+
collection_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id))
1316+
1317+
output_collection = []
1318+
for edge in collection_edges:
1319+
source_value = copydeep(getattr(self.results[edge.source.node_id], edge.source.field))
1320+
if isinstance(source_value, list):
1321+
output_collection.extend(source_value)
1322+
else:
1323+
output_collection.append(source_value)
1324+
output_collection.extend(
1325+
copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges
1326+
)
12161327
node.collection = output_collection
12171328
else:
12181329
for edge in input_edges:

invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,14 @@ describe(getCollectItemType.name, () => {
4141
const result = getCollectItemType({ add: addWithoutOutputValue, collect }, [n2, n1], [e1], n1.id);
4242
expect(result).toBeNull();
4343
});
44+
45+
it('should return the upstream collect item type for chained collects', () => {
46+
const n1 = buildNode(collect);
47+
const n2 = buildNode(collect);
48+
const n3 = buildNode(add);
49+
const e1 = buildEdge(n3.id, 'value', n1.id, 'item');
50+
const e2 = buildEdge(n1.id, 'collection', n2.id, 'collection');
51+
const result = getCollectItemType(templates, [n1, n2, n3], [e1, e2], n2.id);
52+
expect(result).toEqual<FieldType>({ name: 'IntegerField', cardinality: 'SINGLE', batch: false });
53+
});
4454
});

invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@ import type { Templates } from 'features/nodes/store/types';
22
import type { FieldType } from 'features/nodes/types/field';
33
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
44

5+
const toItemType = (fieldType: FieldType): FieldType | null => {
6+
if (fieldType.name === 'CollectionField') {
7+
return null;
8+
}
9+
if (fieldType.cardinality === 'COLLECTION' || fieldType.cardinality === 'SINGLE_OR_COLLECTION') {
10+
return { ...fieldType, cardinality: 'SINGLE' };
11+
}
12+
return fieldType;
13+
};
14+
515
/**
616
* Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and
717
* field connected to the collector's `item` input. The field type of that field is returned, else null if there is no
@@ -18,21 +28,56 @@ export const getCollectItemType = (
1828
edges: AnyEdge[],
1929
nodeId: string
2030
): FieldType | null => {
21-
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
22-
if (!firstEdgeToCollect?.sourceHandle) {
23-
return null;
24-
}
25-
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
26-
if (!node) {
27-
return null;
28-
}
29-
const template = templates[node.data.type];
30-
if (!template) {
31-
return null;
32-
}
33-
const fieldTemplate = template.outputs[firstEdgeToCollect.sourceHandle];
34-
if (!fieldTemplate) {
31+
const getCollectItemTypeInternal = (currentNodeId: string, visited: Set<string>): FieldType | null => {
32+
if (visited.has(currentNodeId)) {
33+
return null;
34+
}
35+
visited.add(currentNodeId);
36+
37+
const firstItemEdgeToCollect = edges.find((edge) => edge.target === currentNodeId && edge.targetHandle === 'item');
38+
if (firstItemEdgeToCollect?.sourceHandle) {
39+
const node = nodes.find((n) => n.id === firstItemEdgeToCollect.source);
40+
if (!node) {
41+
return null;
42+
}
43+
const template = templates[node.data.type];
44+
if (!template) {
45+
return null;
46+
}
47+
const fieldTemplate = template.outputs[firstItemEdgeToCollect.sourceHandle];
48+
if (!fieldTemplate) {
49+
return null;
50+
}
51+
return toItemType(fieldTemplate.type);
52+
}
53+
54+
const firstCollectionEdgeToCollect = edges.find(
55+
(edge) => edge.target === currentNodeId && edge.targetHandle === 'collection'
56+
);
57+
if (!firstCollectionEdgeToCollect?.sourceHandle) {
58+
return null;
59+
}
60+
const sourceNode = nodes.find((n) => n.id === firstCollectionEdgeToCollect.source);
61+
if (!sourceNode) {
62+
return null;
63+
}
64+
if (sourceNode.data.type === 'collect' && firstCollectionEdgeToCollect.sourceHandle === 'collection') {
65+
return getCollectItemTypeInternal(sourceNode.id, visited);
66+
}
67+
const sourceTemplate = templates[sourceNode.data.type];
68+
if (!sourceTemplate) {
69+
return null;
70+
}
71+
const sourceFieldTemplate = sourceTemplate.outputs[firstCollectionEdgeToCollect.sourceHandle];
72+
if (!sourceFieldTemplate) {
73+
return null;
74+
}
75+
return toItemType(sourceFieldTemplate.type);
76+
};
77+
78+
const itemType = getCollectItemTypeInternal(nodeId, new Set());
79+
if (!itemType) {
3580
return null;
3681
}
37-
return fieldTemplate.type;
82+
return itemType;
3883
};

0 commit comments

Comments
 (0)