@@ -123,6 +123,23 @@ def is_any(t: Any) -> bool:
123123 return t == Any or Any in get_args (t )
124124
125125
126+ def extract_collection_item_types (t : Any ) -> set [Any ]:
127+ """Extracts list item types from a collection annotation, including unions containing list branches."""
128+ if is_any (t ):
129+ return {Any }
130+
131+ if get_origin (t ) is list :
132+ return {arg for arg in get_args (t ) if arg != NoneType }
133+
134+ item_types : set [Any ] = set ()
135+ for arg in get_args (t ):
136+ if is_any (arg ):
137+ item_types .add (Any )
138+ elif get_origin (arg ) is list :
139+ item_types .update (item_arg for item_arg in get_args (arg ) if item_arg != NoneType )
140+ return item_types
141+
142+
126143def are_connection_types_compatible (from_type : Any , to_type : Any ) -> bool :
127144 if not from_type or not to_type :
128145 return False
@@ -280,7 +297,7 @@ class CollectInvocationOutput(BaseInvocationOutput):
280297 )
281298
282299
283- @invocation ("collect" , version = "1.0 .0" )
300+ @invocation ("collect" , version = "1.1 .0" )
284301class CollectInvocation (BaseInvocation ):
285302 """Collects values into a collection"""
286303
@@ -292,7 +309,10 @@ class CollectInvocation(BaseInvocation):
292309 input = Input .Connection ,
293310 )
294311 collection : list [Any ] = InputField (
295- description = "The collection, will be provided on execution" , default = [], ui_hidden = True
312+ description = "An optional collection to append to" ,
313+ default = [],
314+ ui_type = UIType ._Collection ,
315+ input = Input .Connection ,
296316 )
297317
298318 def invoke (self , context : InvocationContext ) -> CollectInvocationOutput :
@@ -520,7 +540,9 @@ def _validate_edge(self, edge: Edge):
520540
521541 # Validate that an edge to this node+field doesn't already exist
522542 input_edges = self ._get_input_edges (edge .destination .node_id , edge .destination .field )
523- if len (input_edges ) > 0 and not isinstance (to_node , CollectInvocation ):
543+ if len (input_edges ) > 0 and (
544+ not isinstance (to_node , CollectInvocation ) or edge .destination .field != ITEM_FIELD
545+ ):
524546 raise InvalidEdgeError (f"Edge already exists ({ edge } )" )
525547
526548 # Validate that no cycles would be created
@@ -546,8 +568,10 @@ def _validate_edge(self, edge: Edge):
546568 raise InvalidEdgeError (f"Iterator output type does not match iterator input type ({ edge } ): { err } " )
547569
548570 # Validate if collector input type matches output type (if this edge results in both being set)
549- if isinstance (to_node , CollectInvocation ) and edge .destination .field == ITEM_FIELD :
550- err = self ._is_collector_connection_valid (edge .destination .node_id , new_input = edge .source )
571+ if isinstance (to_node , CollectInvocation ) and edge .destination .field in (ITEM_FIELD , COLLECTION_FIELD ):
572+ err = self ._is_collector_connection_valid (
573+ edge .destination .node_id , new_input = edge .source , new_input_field = edge .destination .field
574+ )
551575 if err is not None :
552576 raise InvalidEdgeError (f"Collector output type does not match collector input type ({ edge } ): { err } " )
553577
@@ -676,76 +700,152 @@ def _is_iterator_connection_valid(
676700
677701 # Collector input type must match all iterator output types
678702 if isinstance (input_node , CollectInvocation ):
679- collector_inputs = self ._get_input_edges (input_node .id , ITEM_FIELD )
680- if len (collector_inputs ) == 0 :
681- return "Iterator input collector must have at least one item input edge"
682-
683- # Traverse the graph to find the first collector input edge. Collectors validate that their collection
684- # inputs are all of the same type, so we can use the first input edge to determine the collector's type
685- first_collector_input_edge = collector_inputs [0 ]
686- first_collector_input_type = get_output_field_type (
687- self .get_node (first_collector_input_edge .source .node_id ), first_collector_input_edge .source .field
688- )
689- resolved_collector_type = (
690- first_collector_input_type
691- if get_origin (first_collector_input_type ) is None
692- else get_args (first_collector_input_type )
693- )
694- if not all ((are_connection_types_compatible (resolved_collector_type , t ) for t in output_field_types )):
703+ input_root_type = self ._get_collector_input_root_type (input_node .id )
704+ if input_root_type is None :
705+ return "Iterator input collector must have at least one item or collection input edge"
706+ if not all ((are_connection_types_compatible (input_root_type , t ) for t in output_field_types )):
695707 return "Iterator collection type must match all iterator output types"
696708
697709 return None
698710
711+ def _resolve_collector_input_types (self , node_id : str , visited : Optional [set [str ]] = None ) -> set [Any ]:
712+ """Resolves possible item types for a collector's inputs, recursively following chained collectors."""
713+ visited = visited or set ()
714+ if node_id in visited :
715+ return set ()
716+ visited .add (node_id )
717+
718+ input_types : set [Any ] = set ()
719+
720+ for edge in self ._get_input_edges (node_id , ITEM_FIELD ):
721+ input_field_type = get_output_field_type (self .get_node (edge .source .node_id ), edge .source .field )
722+ resolved_types = [input_field_type ] if get_origin (input_field_type ) is None else get_args (input_field_type )
723+ input_types .update (t for t in resolved_types if t != NoneType )
724+
725+ for edge in self ._get_input_edges (node_id , COLLECTION_FIELD ):
726+ source_node = self .get_node (edge .source .node_id )
727+ if isinstance (source_node , CollectInvocation ) and edge .source .field == COLLECTION_FIELD :
728+ input_types .update (self ._resolve_collector_input_types (source_node .id , visited .copy ()))
729+ continue
730+
731+ input_field_type = get_output_field_type (source_node , edge .source .field )
732+ input_types .update (extract_collection_item_types (input_field_type ))
733+
734+ return input_types
735+
736+ def _get_collector_input_root_type (self , node_id : str ) -> Any | None :
737+ input_types = self ._resolve_collector_input_types (node_id )
738+ non_any_input_types = {t for t in input_types if t != Any }
739+ if len (non_any_input_types ) == 0 and Any in input_types :
740+ return Any
741+ if len (non_any_input_types ) == 0 :
742+ return None
743+
744+ type_tree = nx .DiGraph ()
745+ type_tree .add_nodes_from (non_any_input_types )
746+ type_tree .add_edges_from ([e for e in itertools .permutations (non_any_input_types , 2 ) if issubclass (e [1 ], e [0 ])])
747+ type_degrees = type_tree .in_degree (type_tree .nodes )
748+ root_types = [t [0 ] for t in type_degrees if t [1 ] == 0 ] # type: ignore
749+ if len (root_types ) != 1 :
750+ return Any
751+ return root_types [0 ]
752+
699753 def _is_collector_connection_valid (
700754 self ,
701755 node_id : str ,
702756 new_input : Optional [EdgeConnection ] = None ,
757+ new_input_field : Optional [str ] = None ,
703758 new_output : Optional [EdgeConnection ] = None ,
704759 ) -> str | None :
705- inputs = [e .source for e in self ._get_input_edges (node_id , ITEM_FIELD )]
760+ item_inputs = [e .source for e in self ._get_input_edges (node_id , ITEM_FIELD )]
761+ collection_inputs = [e .source for e in self ._get_input_edges (node_id , COLLECTION_FIELD )]
706762 outputs = [e .destination for e in self ._get_output_edges (node_id , COLLECTION_FIELD )]
707763
708764 if new_input is not None :
709- inputs .append (new_input )
765+ field = new_input_field or ITEM_FIELD
766+ if field == ITEM_FIELD :
767+ item_inputs .append (new_input )
768+ elif field == COLLECTION_FIELD :
769+ collection_inputs .append (new_input )
710770 if new_output is not None :
711771 outputs .append (new_output )
712772
713- # Get input and output fields (the fields linked to the iterator's input/output)
714- input_field_types = [get_output_field_type (self .get_node (e .node_id ), e .field ) for e in inputs ]
773+ if len (item_inputs ) == 0 and len (collection_inputs ) == 0 :
774+ return "Collector must have at least one item or collection input edge"
775+
776+ # Get input and output fields (the fields linked to the collector's input/output)
777+ item_input_field_types = [get_output_field_type (self .get_node (e .node_id ), e .field ) for e in item_inputs ]
778+ collection_input_field_types = [
779+ get_output_field_type (self .get_node (e .node_id ), e .field ) for e in collection_inputs
780+ ]
715781 output_field_types = [get_input_field_type (self .get_node (e .node_id ), e .field ) for e in outputs ]
716782
783+ if not all ((is_list_or_contains_list (t ) or is_any (t ) for t in collection_input_field_types )):
784+ return "Collector collection input must be a collection"
785+
717786 # Validate that all inputs are derived from or match a single type
718787 input_field_types = {
719788 resolved_type
720- for input_field_type in input_field_types
789+ for input_field_type in item_input_field_types
721790 for resolved_type in (
722791 [input_field_type ] if get_origin (input_field_type ) is None else get_args (input_field_type )
723792 )
724793 if resolved_type != NoneType
725794 } # Get unique types
795+
796+ for input_conn , input_field_type in zip (collection_inputs , collection_input_field_types , strict = False ):
797+ source_node = self .get_node (input_conn .node_id )
798+ if isinstance (source_node , CollectInvocation ) and input_conn .field == COLLECTION_FIELD :
799+ input_field_types .update (self ._resolve_collector_input_types (source_node .id ))
800+ continue
801+ input_field_types .update (extract_collection_item_types (input_field_type ))
802+
803+ non_any_input_field_types = {t for t in input_field_types if t != Any }
726804 type_tree = nx .DiGraph ()
727- type_tree .add_nodes_from (input_field_types )
728- type_tree .add_edges_from ([e for e in itertools .permutations (input_field_types , 2 ) if issubclass (e [1 ], e [0 ])])
805+ type_tree .add_nodes_from (non_any_input_field_types )
806+ type_tree .add_edges_from (
807+ [e for e in itertools .permutations (non_any_input_field_types , 2 ) if issubclass (e [1 ], e [0 ])]
808+ )
729809 type_degrees = type_tree .in_degree (type_tree .nodes )
730- if sum ((t [1 ] == 0 for t in type_degrees )) != 1 : # type: ignore
810+ root_types = [t [0 ] for t in type_degrees if t [1 ] == 0 ] # type: ignore
811+ if len (root_types ) > 1 :
731812 return "Collector input collection items must be of a single type"
732813
733- # Get the input root type
734- input_root_type = next ( t [0 ] for t in type_degrees if t [ 1 ] == 0 ) # type: ignore
814+ # Get the input root type (if known)
815+ input_root_type = root_types [0 ] if len ( root_types ) == 1 else None
735816
736817 # Verify that all outputs are lists
737818 if not all (is_list_or_contains_list (t ) or is_any (t ) for t in output_field_types ):
738819 return "Collector output must connect to a collection input"
739820
740821 # Verify that all outputs match the input type (are a base class or the same class)
741- if not all (
742- is_any (t )
743- or is_union_subtype (input_root_type , get_args (t )[0 ])
744- or issubclass (input_root_type , get_args (t )[0 ])
745- for t in output_field_types
746- ):
822+ if input_root_type is not None :
823+ if not all (
824+ is_any (t )
825+ or is_union_subtype (input_root_type , get_args (t )[0 ])
826+ or issubclass (input_root_type , get_args (t )[0 ])
827+ for t in output_field_types
828+ ):
829+ return "Collector outputs must connect to a collection input with a matching type"
830+ elif any (not is_any (t ) and get_args (t )[0 ] != Any for t in output_field_types ):
747831 return "Collector outputs must connect to a collection input with a matching type"
748832
833+ # If this collector outputs to another collector's collection input, validate against the downstream
834+ # collector's resolved input type (if available).
835+ for output in outputs :
836+ output_node = self .get_node (output .node_id )
837+ if not isinstance (output_node , CollectInvocation ) or output .field != COLLECTION_FIELD :
838+ continue
839+ output_root_type = self ._get_collector_input_root_type (output_node .id )
840+ if output_root_type is None :
841+ continue
842+ if input_root_type is None :
843+ if output_root_type != Any :
844+ return "Collector outputs must connect to a collection input with a matching type"
845+ continue
846+ if not are_connection_types_compatible (input_root_type , output_root_type ):
847+ return "Collector outputs must connect to a collection input with a matching type"
848+
749849 return None
750850
751851 def nx_graph (self ) -> nx .DiGraph :
@@ -1211,8 +1311,19 @@ def _prepare_inputs(self, node: BaseInvocation):
12111311 if isinstance (node , CollectInvocation ):
12121312 item_edges = [e for e in input_edges if e .destination .field == ITEM_FIELD ]
12131313 item_edges .sort (key = lambda e : (self ._get_iteration_path (e .source .node_id ), e .source .node_id ))
1214-
1215- output_collection = [copydeep (getattr (self .results [e .source .node_id ], e .source .field )) for e in item_edges ]
1314+ collection_edges = [e for e in input_edges if e .destination .field == COLLECTION_FIELD ]
1315+ collection_edges .sort (key = lambda e : (self ._get_iteration_path (e .source .node_id ), e .source .node_id ))
1316+
1317+ output_collection = []
1318+ for edge in collection_edges :
1319+ source_value = copydeep (getattr (self .results [edge .source .node_id ], edge .source .field ))
1320+ if isinstance (source_value , list ):
1321+ output_collection .extend (source_value )
1322+ else :
1323+ output_collection .append (source_value )
1324+ output_collection .extend (
1325+ copydeep (getattr (self .results [e .source .node_id ], e .source .field )) for e in item_edges
1326+ )
12161327 node .collection = output_collection
12171328 else :
12181329 for edge in input_edges :
0 commit comments