Skip to content

Commit bfe1cb0

Browse files
authored
Merge branch 'invoke-ai:main' into dype-paper
2 parents fdc2889 + f7aa5fc commit bfe1cb0

13 files changed

Lines changed: 442 additions & 72 deletions

File tree

invokeai/app/services/shared/graph.py

Lines changed: 150 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,23 @@ def is_any(t: Any) -> bool:
123123
return t == Any or Any in get_args(t)
124124

125125

126+
def extract_collection_item_types(t: Any) -> set[Any]:
127+
"""Extracts list item types from a collection annotation, including unions containing list branches."""
128+
if is_any(t):
129+
return {Any}
130+
131+
if get_origin(t) is list:
132+
return {arg for arg in get_args(t) if arg != NoneType}
133+
134+
item_types: set[Any] = set()
135+
for arg in get_args(t):
136+
if is_any(arg):
137+
item_types.add(Any)
138+
elif get_origin(arg) is list:
139+
item_types.update(item_arg for item_arg in get_args(arg) if item_arg != NoneType)
140+
return item_types
141+
142+
126143
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
127144
if not from_type or not to_type:
128145
return False
@@ -280,7 +297,7 @@ class CollectInvocationOutput(BaseInvocationOutput):
280297
)
281298

282299

283-
@invocation("collect", version="1.0.0")
300+
@invocation("collect", version="1.1.0")
284301
class CollectInvocation(BaseInvocation):
285302
"""Collects values into a collection"""
286303

@@ -292,7 +309,10 @@ class CollectInvocation(BaseInvocation):
292309
input=Input.Connection,
293310
)
294311
collection: list[Any] = InputField(
295-
description="The collection, will be provided on execution", default=[], ui_hidden=True
312+
description="An optional collection to append to",
313+
default=[],
314+
ui_type=UIType._Collection,
315+
input=Input.Connection,
296316
)
297317

298318
def invoke(self, context: InvocationContext) -> CollectInvocationOutput:
@@ -520,7 +540,9 @@ def _validate_edge(self, edge: Edge):
520540

521541
# Validate that an edge to this node+field doesn't already exist
522542
input_edges = self._get_input_edges(edge.destination.node_id, edge.destination.field)
523-
if len(input_edges) > 0 and not isinstance(to_node, CollectInvocation):
543+
if len(input_edges) > 0 and (
544+
not isinstance(to_node, CollectInvocation) or edge.destination.field != ITEM_FIELD
545+
):
524546
raise InvalidEdgeError(f"Edge already exists ({edge})")
525547

526548
# Validate that no cycles would be created
@@ -546,8 +568,10 @@ def _validate_edge(self, edge: Edge):
546568
raise InvalidEdgeError(f"Iterator output type does not match iterator input type ({edge}): {err}")
547569

548570
# Validate if collector input type matches output type (if this edge results in both being set)
549-
if isinstance(to_node, CollectInvocation) and edge.destination.field == ITEM_FIELD:
550-
err = self._is_collector_connection_valid(edge.destination.node_id, new_input=edge.source)
571+
if isinstance(to_node, CollectInvocation) and edge.destination.field in (ITEM_FIELD, COLLECTION_FIELD):
572+
err = self._is_collector_connection_valid(
573+
edge.destination.node_id, new_input=edge.source, new_input_field=edge.destination.field
574+
)
551575
if err is not None:
552576
raise InvalidEdgeError(f"Collector output type does not match collector input type ({edge}): {err}")
553577

@@ -676,76 +700,152 @@ def _is_iterator_connection_valid(
676700

677701
# Collector input type must match all iterator output types
678702
if isinstance(input_node, CollectInvocation):
679-
collector_inputs = self._get_input_edges(input_node.id, ITEM_FIELD)
680-
if len(collector_inputs) == 0:
681-
return "Iterator input collector must have at least one item input edge"
682-
683-
# Traverse the graph to find the first collector input edge. Collectors validate that their collection
684-
# inputs are all of the same type, so we can use the first input edge to determine the collector's type
685-
first_collector_input_edge = collector_inputs[0]
686-
first_collector_input_type = get_output_field_type(
687-
self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
688-
)
689-
resolved_collector_type = (
690-
first_collector_input_type
691-
if get_origin(first_collector_input_type) is None
692-
else get_args(first_collector_input_type)
693-
)
694-
if not all((are_connection_types_compatible(resolved_collector_type, t) for t in output_field_types)):
703+
input_root_type = self._get_collector_input_root_type(input_node.id)
704+
if input_root_type is None:
705+
return "Iterator input collector must have at least one item or collection input edge"
706+
if not all((are_connection_types_compatible(input_root_type, t) for t in output_field_types)):
695707
return "Iterator collection type must match all iterator output types"
696708

697709
return None
698710

711+
def _resolve_collector_input_types(self, node_id: str, visited: Optional[set[str]] = None) -> set[Any]:
712+
"""Resolves possible item types for a collector's inputs, recursively following chained collectors."""
713+
visited = visited or set()
714+
if node_id in visited:
715+
return set()
716+
visited.add(node_id)
717+
718+
input_types: set[Any] = set()
719+
720+
for edge in self._get_input_edges(node_id, ITEM_FIELD):
721+
input_field_type = get_output_field_type(self.get_node(edge.source.node_id), edge.source.field)
722+
resolved_types = [input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type)
723+
input_types.update(t for t in resolved_types if t != NoneType)
724+
725+
for edge in self._get_input_edges(node_id, COLLECTION_FIELD):
726+
source_node = self.get_node(edge.source.node_id)
727+
if isinstance(source_node, CollectInvocation) and edge.source.field == COLLECTION_FIELD:
728+
input_types.update(self._resolve_collector_input_types(source_node.id, visited.copy()))
729+
continue
730+
731+
input_field_type = get_output_field_type(source_node, edge.source.field)
732+
input_types.update(extract_collection_item_types(input_field_type))
733+
734+
return input_types
735+
736+
def _get_collector_input_root_type(self, node_id: str) -> Any | None:
737+
input_types = self._resolve_collector_input_types(node_id)
738+
non_any_input_types = {t for t in input_types if t != Any}
739+
if len(non_any_input_types) == 0 and Any in input_types:
740+
return Any
741+
if len(non_any_input_types) == 0:
742+
return None
743+
744+
type_tree = nx.DiGraph()
745+
type_tree.add_nodes_from(non_any_input_types)
746+
type_tree.add_edges_from([e for e in itertools.permutations(non_any_input_types, 2) if issubclass(e[1], e[0])])
747+
type_degrees = type_tree.in_degree(type_tree.nodes)
748+
root_types = [t[0] for t in type_degrees if t[1] == 0] # type: ignore
749+
if len(root_types) != 1:
750+
return Any
751+
return root_types[0]
752+
699753
def _is_collector_connection_valid(
700754
self,
701755
node_id: str,
702756
new_input: Optional[EdgeConnection] = None,
757+
new_input_field: Optional[str] = None,
703758
new_output: Optional[EdgeConnection] = None,
704759
) -> str | None:
705-
inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)]
760+
item_inputs = [e.source for e in self._get_input_edges(node_id, ITEM_FIELD)]
761+
collection_inputs = [e.source for e in self._get_input_edges(node_id, COLLECTION_FIELD)]
706762
outputs = [e.destination for e in self._get_output_edges(node_id, COLLECTION_FIELD)]
707763

708764
if new_input is not None:
709-
inputs.append(new_input)
765+
field = new_input_field or ITEM_FIELD
766+
if field == ITEM_FIELD:
767+
item_inputs.append(new_input)
768+
elif field == COLLECTION_FIELD:
769+
collection_inputs.append(new_input)
710770
if new_output is not None:
711771
outputs.append(new_output)
712772

713-
# Get input and output fields (the fields linked to the iterator's input/output)
714-
input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in inputs]
773+
if len(item_inputs) == 0 and len(collection_inputs) == 0:
774+
return "Collector must have at least one item or collection input edge"
775+
776+
# Get input and output fields (the fields linked to the collector's input/output)
777+
item_input_field_types = [get_output_field_type(self.get_node(e.node_id), e.field) for e in item_inputs]
778+
collection_input_field_types = [
779+
get_output_field_type(self.get_node(e.node_id), e.field) for e in collection_inputs
780+
]
715781
output_field_types = [get_input_field_type(self.get_node(e.node_id), e.field) for e in outputs]
716782

783+
if not all((is_list_or_contains_list(t) or is_any(t) for t in collection_input_field_types)):
784+
return "Collector collection input must be a collection"
785+
717786
# Validate that all inputs are derived from or match a single type
718787
input_field_types = {
719788
resolved_type
720-
for input_field_type in input_field_types
789+
for input_field_type in item_input_field_types
721790
for resolved_type in (
722791
[input_field_type] if get_origin(input_field_type) is None else get_args(input_field_type)
723792
)
724793
if resolved_type != NoneType
725794
} # Get unique types
795+
796+
for input_conn, input_field_type in zip(collection_inputs, collection_input_field_types, strict=False):
797+
source_node = self.get_node(input_conn.node_id)
798+
if isinstance(source_node, CollectInvocation) and input_conn.field == COLLECTION_FIELD:
799+
input_field_types.update(self._resolve_collector_input_types(source_node.id))
800+
continue
801+
input_field_types.update(extract_collection_item_types(input_field_type))
802+
803+
non_any_input_field_types = {t for t in input_field_types if t != Any}
726804
type_tree = nx.DiGraph()
727-
type_tree.add_nodes_from(input_field_types)
728-
type_tree.add_edges_from([e for e in itertools.permutations(input_field_types, 2) if issubclass(e[1], e[0])])
805+
type_tree.add_nodes_from(non_any_input_field_types)
806+
type_tree.add_edges_from(
807+
[e for e in itertools.permutations(non_any_input_field_types, 2) if issubclass(e[1], e[0])]
808+
)
729809
type_degrees = type_tree.in_degree(type_tree.nodes)
730-
if sum((t[1] == 0 for t in type_degrees)) != 1: # type: ignore
810+
root_types = [t[0] for t in type_degrees if t[1] == 0] # type: ignore
811+
if len(root_types) > 1:
731812
return "Collector input collection items must be of a single type"
732813

733-
# Get the input root type
734-
input_root_type = next(t[0] for t in type_degrees if t[1] == 0) # type: ignore
814+
# Get the input root type (if known)
815+
input_root_type = root_types[0] if len(root_types) == 1 else None
735816

736817
# Verify that all outputs are lists
737818
if not all(is_list_or_contains_list(t) or is_any(t) for t in output_field_types):
738819
return "Collector output must connect to a collection input"
739820

740821
# Verify that all outputs match the input type (are a base class or the same class)
741-
if not all(
742-
is_any(t)
743-
or is_union_subtype(input_root_type, get_args(t)[0])
744-
or issubclass(input_root_type, get_args(t)[0])
745-
for t in output_field_types
746-
):
822+
if input_root_type is not None:
823+
if not all(
824+
is_any(t)
825+
or is_union_subtype(input_root_type, get_args(t)[0])
826+
or issubclass(input_root_type, get_args(t)[0])
827+
for t in output_field_types
828+
):
829+
return "Collector outputs must connect to a collection input with a matching type"
830+
elif any(not is_any(t) and get_args(t)[0] != Any for t in output_field_types):
747831
return "Collector outputs must connect to a collection input with a matching type"
748832

833+
# If this collector outputs to another collector's collection input, validate against the downstream
834+
# collector's resolved input type (if available).
835+
for output in outputs:
836+
output_node = self.get_node(output.node_id)
837+
if not isinstance(output_node, CollectInvocation) or output.field != COLLECTION_FIELD:
838+
continue
839+
output_root_type = self._get_collector_input_root_type(output_node.id)
840+
if output_root_type is None:
841+
continue
842+
if input_root_type is None:
843+
if output_root_type != Any:
844+
return "Collector outputs must connect to a collection input with a matching type"
845+
continue
846+
if not are_connection_types_compatible(input_root_type, output_root_type):
847+
return "Collector outputs must connect to a collection input with a matching type"
848+
749849
return None
750850

751851
def nx_graph(self) -> nx.DiGraph:
@@ -1211,8 +1311,19 @@ def _prepare_inputs(self, node: BaseInvocation):
12111311
if isinstance(node, CollectInvocation):
12121312
item_edges = [e for e in input_edges if e.destination.field == ITEM_FIELD]
12131313
item_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id))
1214-
1215-
output_collection = [copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges]
1314+
collection_edges = [e for e in input_edges if e.destination.field == COLLECTION_FIELD]
1315+
collection_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id))
1316+
1317+
output_collection = []
1318+
for edge in collection_edges:
1319+
source_value = copydeep(getattr(self.results[edge.source.node_id], edge.source.field))
1320+
if isinstance(source_value, list):
1321+
output_collection.extend(source_value)
1322+
else:
1323+
output_collection.append(source_value)
1324+
output_collection.extend(
1325+
copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges
1326+
)
12161327
node.collection = output_collection
12171328
else:
12181329
for edge in input_edges:

invokeai/frontend/web/public/locales/en.json

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3107,9 +3107,9 @@
31073107
"whatsNew": {
31083108
"whatsNewInInvoke": "What's New in Invoke",
31093109
"items": [
3110-
"FLUX.2 Klein Support: InvokeAI now supports the new FLUX.2 Klein models (4B and 9B variants) with GGUF, FP8, and Diffusers formats. Features include txt2img, img2img, inpainting, and outpainting. See 'Starter Models' to get started.",
3111-
"DyPE support for FLUX models improves high-resolution (>1536 px up to 4K) images. Go to the 'Advanced Options' section to activate.",
3112-
"Z-Image Turbo diversity: Active 'Seed Variance Enhancer' under 'Advanced Options' to add diversitiy to your ZiT gens."
3110+
"Multi-user mode supports multiple isolated users on the same server.",
3111+
"Enhanced support for Z-Image and FLUX.2 Models.",
3112+
"Multiple user interface enhancements and new canvas features."
31133113
],
31143114
"takeUserSurvey": "📣 Let us know how you like InvokeAI. Take our User Experience Survey!",
31153115
"readReleaseNotes": "Read Release Notes",

invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.test.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,14 @@ describe(getCollectItemType.name, () => {
4141
const result = getCollectItemType({ add: addWithoutOutputValue, collect }, [n2, n1], [e1], n1.id);
4242
expect(result).toBeNull();
4343
});
44+
45+
it('should return the upstream collect item type for chained collects', () => {
46+
const n1 = buildNode(collect);
47+
const n2 = buildNode(collect);
48+
const n3 = buildNode(add);
49+
const e1 = buildEdge(n3.id, 'value', n1.id, 'item');
50+
const e2 = buildEdge(n1.id, 'collection', n2.id, 'collection');
51+
const result = getCollectItemType(templates, [n1, n2, n3], [e1, e2], n2.id);
52+
expect(result).toEqual<FieldType>({ name: 'IntegerField', cardinality: 'SINGLE', batch: false });
53+
});
4454
});

invokeai/frontend/web/src/features/nodes/store/util/getCollectItemType.ts

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@ import type { Templates } from 'features/nodes/store/types';
22
import type { FieldType } from 'features/nodes/types/field';
33
import type { AnyEdge, AnyNode } from 'features/nodes/types/invocation';
44

5+
const toItemType = (fieldType: FieldType): FieldType | null => {
6+
if (fieldType.name === 'CollectionField') {
7+
return null;
8+
}
9+
if (fieldType.cardinality === 'COLLECTION' || fieldType.cardinality === 'SINGLE_OR_COLLECTION') {
10+
return { ...fieldType, cardinality: 'SINGLE' };
11+
}
12+
return fieldType;
13+
};
14+
515
/**
616
* Given a collect node, return the type of the items it collects. The graph is traversed to find the first node and
717
* field connected to the collector's `item` input. The field type of that field is returned, else null if there is no
@@ -18,21 +28,56 @@ export const getCollectItemType = (
1828
edges: AnyEdge[],
1929
nodeId: string
2030
): FieldType | null => {
21-
const firstEdgeToCollect = edges.find((edge) => edge.target === nodeId && edge.targetHandle === 'item');
22-
if (!firstEdgeToCollect?.sourceHandle) {
23-
return null;
24-
}
25-
const node = nodes.find((n) => n.id === firstEdgeToCollect.source);
26-
if (!node) {
27-
return null;
28-
}
29-
const template = templates[node.data.type];
30-
if (!template) {
31-
return null;
32-
}
33-
const fieldTemplate = template.outputs[firstEdgeToCollect.sourceHandle];
34-
if (!fieldTemplate) {
31+
const getCollectItemTypeInternal = (currentNodeId: string, visited: Set<string>): FieldType | null => {
32+
if (visited.has(currentNodeId)) {
33+
return null;
34+
}
35+
visited.add(currentNodeId);
36+
37+
const firstItemEdgeToCollect = edges.find((edge) => edge.target === currentNodeId && edge.targetHandle === 'item');
38+
if (firstItemEdgeToCollect?.sourceHandle) {
39+
const node = nodes.find((n) => n.id === firstItemEdgeToCollect.source);
40+
if (!node) {
41+
return null;
42+
}
43+
const template = templates[node.data.type];
44+
if (!template) {
45+
return null;
46+
}
47+
const fieldTemplate = template.outputs[firstItemEdgeToCollect.sourceHandle];
48+
if (!fieldTemplate) {
49+
return null;
50+
}
51+
return toItemType(fieldTemplate.type);
52+
}
53+
54+
const firstCollectionEdgeToCollect = edges.find(
55+
(edge) => edge.target === currentNodeId && edge.targetHandle === 'collection'
56+
);
57+
if (!firstCollectionEdgeToCollect?.sourceHandle) {
58+
return null;
59+
}
60+
const sourceNode = nodes.find((n) => n.id === firstCollectionEdgeToCollect.source);
61+
if (!sourceNode) {
62+
return null;
63+
}
64+
if (sourceNode.data.type === 'collect' && firstCollectionEdgeToCollect.sourceHandle === 'collection') {
65+
return getCollectItemTypeInternal(sourceNode.id, visited);
66+
}
67+
const sourceTemplate = templates[sourceNode.data.type];
68+
if (!sourceTemplate) {
69+
return null;
70+
}
71+
const sourceFieldTemplate = sourceTemplate.outputs[firstCollectionEdgeToCollect.sourceHandle];
72+
if (!sourceFieldTemplate) {
73+
return null;
74+
}
75+
return toItemType(sourceFieldTemplate.type);
76+
};
77+
78+
const itemType = getCollectItemTypeInternal(nodeId, new Set());
79+
if (!itemType) {
3580
return null;
3681
}
37-
return fieldTemplate.type;
82+
return itemType;
3883
};

0 commit comments

Comments
 (0)