@@ -824,6 +824,114 @@ def noop(x):
824824 self .assertEqual (len (flat .transforms ), 3 )
825825 self .assertIsInstance (flat .transforms [1 ], mt .Compose )
826826
827+ def test_multiple_children_with_mixed_map_items (self ):
828+ """Multiple internal Composes with different map_items should be handled correctly."""
829+
830+ def add_one (items ):
831+ if isinstance (items , list ):
832+ return [x + 1 for x in items ]
833+ return items + 1
834+
835+ def multiply_two (items ):
836+ if isinstance (items , list ):
837+ return [x * 2 for x in items ]
838+ return items * 2
839+
840+ # Parent with map_items=False processes the entire input as one unit
841+ # Child 1 (map_items=True) will map over each item in what it receives
842+ # Child 2 (map_items=False) will process the entire thing
843+ pipeline = mt .Compose ([
844+ mt .Compose ([add_one ], map_items = True ),
845+ mt .Compose ([multiply_two ], map_items = False ),
846+ ], map_items = False )
847+
848+ # Input [1, 2, 3]
849+ # First child with map_items=True maps add_one over [1,2,3]: [2, 3, 4]
850+ # Second child with map_items=False receives [2,3,4] and applies multiply_two: [4, 6, 8]
851+ result = pipeline ([1 , 2 , 3 ])
852+ self .assertEqual (result , [4 , 6 , 8 ])
853+
854+ def test_flatten_with_multiple_children_preserves_both (self ):
855+ """flatten() should preserve child with different map_items but flatten child with same."""
856+
857+ def noop (x ):
858+ return x
859+
860+ parent = mt .Compose ([
861+ noop ,
862+ mt .Compose ([noop , noop ], map_items = True ), # Same as parent, will be flattened
863+ mt .Compose ([noop , noop ], map_items = False ), # Different, will be preserved
864+ noop ,
865+ ])
866+ flat = parent .flatten ()
867+ # First nested Compose(map_items=True) will be flattened into parent
868+ # Second nested Compose(map_items=False) will be preserved
869+ # Result: noop + noop + noop + Compose([noop, noop]) + noop = 5 transforms
870+ self .assertEqual (len (flat .transforms ), 5 )
871+ # Check that the preserved one is at the correct position
872+ self .assertIsInstance (flat .transforms [3 ], mt .Compose )
873+ self .assertEqual (flat .transforms [3 ].map_items , False )
874+
875+ def test_three_level_nesting_respects_different_map_items (self ):
876+ """Three-level nesting with different map_items at each level."""
877+
878+ def add_one (x ):
879+ return x + 1
880+
881+ # Level 1 (outermost): map_items=True (default)
882+ # Level 2: map_items=False
883+ # Level 3: map_items=True (same as level 2, so will be flattened into level 2)
884+ innermost = mt .Compose ([add_one ], map_items = True )
885+ middle = mt .Compose ([add_one , innermost ], map_items = False )
886+ outer = mt .Compose ([middle ])
887+
888+ # Test with a simple value
889+ # outer has map_items=True (default), middle has map_items=False
890+ # So middle should be preserved and receive the input as-is
891+ result = outer (5 )
892+ # outer(5) -> maps to middle -> middle(5) with map_items=False
893+ # middle(5) -> add_one(5) = 6, then innermost(6) with map_items=True
894+ # innermost(6) -> add_one(6) = 7
895+ self .assertEqual (result , 7 )
896+
897+ def test_inverse_with_multiple_children_different_map_items (self ):
898+ """Inverse should work correctly with multiple children having different map_items."""
899+ pipeline = mt .Compose ([
900+ mt .Flip (0 ),
901+ mt .Compose ([mt .Flip (1 )], map_items = False ),
902+ mt .Compose ([mt .Flip (0 )], map_items = True ),
903+ ])
904+ data = torch .randn (2 , 4 , 4 )
905+ result = pipeline (data )
906+ restored = pipeline .inverse (result )
907+ torch .testing .assert_close (data , restored )
908+
909+ def test_flatten_with_mixed_same_and_different_map_items (self ):
910+ """flatten() should merge children with same map_items but preserve those with different."""
911+
912+ def noop (x ):
913+ return x
914+
915+ # Parent has map_items=True (default)
916+ # Child 1 has map_items=True (same as parent) -> should be flattened
917+ # Child 2 has map_items=False (different from parent) -> should NOT be flattened
918+ parent = mt .Compose ([
919+ noop ,
920+ mt .Compose ([noop , noop ], map_items = True ), # Same as parent, will be flattened
921+ mt .Compose ([noop , noop ], map_items = False ), # Different from parent, will be preserved
922+ noop ,
923+ ])
924+ flat = parent .flatten ()
925+ # After flatten:
926+ # - noop (preserved)
927+ # - 2 noops from first Compose (flattened because map_items=True matches parent)
928+ # - Compose([noop, noop], map_items=False) (preserved because different)
929+ # - noop (preserved)
930+ # Total: 5 transforms
931+ self .assertEqual (len (flat .transforms ), 5 )
932+ self .assertIsInstance (flat .transforms [3 ], mt .Compose )
933+ self .assertEqual (flat .transforms [3 ].map_items , False )
934+
827935
828936class TestComposeCallableInput (unittest .TestCase ):
829937
0 commit comments