@@ -824,6 +824,123 @@ def noop(x):
824824 self .assertEqual (len (flat .transforms ), 3 )
825825 self .assertIsInstance (flat .transforms [1 ], mt .Compose )
826826
827+ def test_multiple_children_with_mixed_map_items (self ):
828+ """Multiple internal Composes with different map_items should be handled correctly."""
829+
830+ def add_one (items ):
831+ if isinstance (items , list ):
832+ return [x + 1 for x in items ]
833+ return items + 1
834+
835+ def multiply_two (items ):
836+ if isinstance (items , list ):
837+ return [x * 2 for x in items ]
838+ return items * 2
839+
840+ # Parent with map_items=False processes the entire input as one unit
841+ # Child 1 (map_items=True) will map over each item in what it receives
842+ # Child 2 (map_items=False) will process the entire thing
843+ pipeline = mt .Compose (
844+ [
845+ mt .Compose ([add_one ], map_items = True ),
846+ mt .Compose ([multiply_two ], map_items = False ),
847+ ],
848+ map_items = False ,
849+ )
850+
851+ # Input [1, 2, 3]
852+ # First child with map_items=True maps add_one over [1,2,3]: [2, 3, 4]
853+ # Second child with map_items=False receives [2,3,4] and applies multiply_two: [4, 6, 8]
854+ result = pipeline ([1 , 2 , 3 ])
855+ self .assertEqual (result , [4 , 6 , 8 ])
856+
857+ def test_flatten_with_multiple_children_preserves_both (self ):
858+ """flatten() should preserve child with different map_items but flatten child with same."""
859+
860+ def noop (x ):
861+ return x
862+
863+ parent = mt .Compose (
864+ [
865+ noop ,
866+ mt .Compose ([noop , noop ], map_items = True ), # Same as parent, will be flattened
867+ mt .Compose ([noop , noop ], map_items = False ), # Different, will be preserved
868+ noop ,
869+ ]
870+ )
871+ flat = parent .flatten ()
872+ # First nested Compose(map_items=True) will be flattened into parent
873+ # Second nested Compose(map_items=False) will be preserved
874+ # Result: noop + noop + noop + Compose([noop, noop]) + noop = 5 transforms
875+ self .assertEqual (len (flat .transforms ), 5 )
876+ # Check that the preserved one is at the correct position
877+ self .assertIsInstance (flat .transforms [3 ], mt .Compose )
878+ self .assertEqual (flat .transforms [3 ].map_items , False )
879+
880+ def test_three_level_nesting_respects_different_map_items (self ):
881+ """Three-level nesting with different map_items at each level."""
882+
883+ def add_one (x ):
884+ return x + 1
885+
886+ # Level 1 (outermost): map_items=True (default)
887+ # Level 2: map_items=False
888+ # Level 3: map_items=True (same as level 2, so will be flattened into level 2)
889+ innermost = mt .Compose ([add_one ], map_items = True )
890+ middle = mt .Compose ([add_one , innermost ], map_items = False )
891+ outer = mt .Compose ([middle ])
892+
893+ # Test with a simple value
894+ # outer has map_items=True (default), middle has map_items=False
895+ # So middle should be preserved and receive the input as-is
896+ result = outer (5 )
897+ # outer(5) -> maps to middle -> middle(5) with map_items=False
898+ # middle(5) -> add_one(5) = 6, then innermost(6) with map_items=True
899+ # innermost(6) -> add_one(6) = 7
900+ self .assertEqual (result , 7 )
901+
902+ def test_inverse_with_multiple_children_different_map_items (self ):
903+ """Inverse should work correctly with multiple children having different map_items."""
904+ pipeline = mt .Compose (
905+ [
906+ mt .Flip (0 ),
907+ mt .Compose ([mt .Flip (1 )], map_items = False ),
908+ mt .Compose ([mt .Flip (0 )], map_items = True ),
909+ ]
910+ )
911+ data = torch .randn (2 , 4 , 4 )
912+ result = pipeline (data )
913+ restored = pipeline .inverse (result )
914+ torch .testing .assert_close (data , restored )
915+
916+ def test_flatten_with_mixed_same_and_different_map_items (self ):
917+ """flatten() should merge children with same map_items but preserve those with different."""
918+
919+ def noop (x ):
920+ return x
921+
922+ # Parent has map_items=True (default)
923+ # Child 1 has map_items=True (same as parent) -> should be flattened
924+ # Child 2 has map_items=False (different from parent) -> should NOT be flattened
925+ parent = mt .Compose (
926+ [
927+ noop ,
928+ mt .Compose ([noop , noop ], map_items = True ), # Same as parent, will be flattened
929+ mt .Compose ([noop , noop ], map_items = False ), # Different from parent, will be preserved
930+ noop ,
931+ ]
932+ )
933+ flat = parent .flatten ()
934+ # After flatten:
935+ # - noop (preserved)
936+ # - 2 noops from first Compose (flattened because map_items=True matches parent)
937+ # - Compose([noop, noop], map_items=False) (preserved because different)
938+ # - noop (preserved)
939+ # Total: 5 transforms
940+ self .assertEqual (len (flat .transforms ), 5 )
941+ self .assertIsInstance (flat .transforms [3 ], mt .Compose )
942+ self .assertEqual (flat .transforms [3 ].map_items , False )
943+
827944
828945class TestComposeCallableInput (unittest .TestCase ):
829946
0 commit comments