Skip to content

Commit f89ab6f

Browse files
committed
test: add nested Compose map_items coverage and preserve flatten() attrs
Expand test coverage for nested Compose transforms with different map_items configurations, and forward map_items, unpack_items, log_stats, lazy, and overrides through Compose.flatten() so the flattened pipeline is equivalent to the original (mirroring OneOf.flatten behavior). Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent 593657f commit f89ab6f

2 files changed

Lines changed: 125 additions & 1 deletion

File tree

monai/transforms/compose.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,14 @@ def flatten(self):
358358
else:
359359
new_transforms.append(t)
360360

361-
return Compose(new_transforms)
361+
return Compose(
362+
new_transforms,
363+
map_items=self.map_items,
364+
unpack_items=self.unpack_items,
365+
log_stats=self.log_stats,
366+
lazy=self._lazy,
367+
overrides=self.overrides,
368+
)
362369

363370
def __len__(self):
364371
"""Return number of transformations."""

tests/transforms/compose/test_compose.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,123 @@ def noop(x):
824824
self.assertEqual(len(flat.transforms), 3)
825825
self.assertIsInstance(flat.transforms[1], mt.Compose)
826826

827+
def test_multiple_children_with_mixed_map_items(self):
828+
"""Multiple internal Composes with different map_items should be handled correctly."""
829+
830+
def add_one(items):
831+
if isinstance(items, list):
832+
return [x + 1 for x in items]
833+
return items + 1
834+
835+
def multiply_two(items):
836+
if isinstance(items, list):
837+
return [x * 2 for x in items]
838+
return items * 2
839+
840+
# Parent with map_items=False processes the entire input as one unit
841+
# Child 1 (map_items=True) will map over each item in what it receives
842+
# Child 2 (map_items=False) will process the entire thing
843+
pipeline = mt.Compose(
844+
[
845+
mt.Compose([add_one], map_items=True),
846+
mt.Compose([multiply_two], map_items=False),
847+
],
848+
map_items=False,
849+
)
850+
851+
# Input [1, 2, 3]
852+
# First child with map_items=True maps add_one over [1,2,3]: [2, 3, 4]
853+
# Second child with map_items=False receives [2,3,4] and applies multiply_two: [4, 6, 8]
854+
result = pipeline([1, 2, 3])
855+
self.assertEqual(result, [4, 6, 8])
856+
857+
def test_flatten_with_multiple_children_preserves_both(self):
858+
"""flatten() should preserve child with different map_items but flatten child with same."""
859+
860+
def noop(x):
861+
return x
862+
863+
parent = mt.Compose(
864+
[
865+
noop,
866+
mt.Compose([noop, noop], map_items=True), # Same as parent, will be flattened
867+
mt.Compose([noop, noop], map_items=False), # Different, will be preserved
868+
noop,
869+
]
870+
)
871+
flat = parent.flatten()
872+
# First nested Compose(map_items=True) will be flattened into parent
873+
# Second nested Compose(map_items=False) will be preserved
874+
# Result: noop + noop + noop + Compose([noop, noop]) + noop = 5 transforms
875+
self.assertEqual(len(flat.transforms), 5)
876+
# Check that the preserved one is at the correct position
877+
self.assertIsInstance(flat.transforms[3], mt.Compose)
878+
self.assertEqual(flat.transforms[3].map_items, False)
879+
880+
def test_three_level_nesting_respects_different_map_items(self):
881+
"""Three-level nesting with different map_items at each level."""
882+
883+
def add_one(x):
884+
return x + 1
885+
886+
# Level 1 (outermost): map_items=True (default)
887+
# Level 2: map_items=False
888+
# Level 3: map_items=True (same as level 2, so will be flattened into level 2)
889+
innermost = mt.Compose([add_one], map_items=True)
890+
middle = mt.Compose([add_one, innermost], map_items=False)
891+
outer = mt.Compose([middle])
892+
893+
# Test with a simple value
894+
# outer has map_items=True (default), middle has map_items=False
895+
# So middle should be preserved and receive the input as-is
896+
result = outer(5)
897+
# outer(5) -> maps to middle -> middle(5) with map_items=False
898+
# middle(5) -> add_one(5) = 6, then innermost(6) with map_items=True
899+
# innermost(6) -> add_one(6) = 7
900+
self.assertEqual(result, 7)
901+
902+
def test_inverse_with_multiple_children_different_map_items(self):
903+
"""Inverse should work correctly with multiple children having different map_items."""
904+
pipeline = mt.Compose(
905+
[
906+
mt.Flip(0),
907+
mt.Compose([mt.Flip(1)], map_items=False),
908+
mt.Compose([mt.Flip(0)], map_items=True),
909+
]
910+
)
911+
data = torch.randn(2, 4, 4)
912+
result = pipeline(data)
913+
restored = pipeline.inverse(result)
914+
torch.testing.assert_close(data, restored)
915+
916+
def test_flatten_with_mixed_same_and_different_map_items(self):
917+
"""flatten() should merge children with same map_items but preserve those with different."""
918+
919+
def noop(x):
920+
return x
921+
922+
# Parent has map_items=True (default)
923+
# Child 1 has map_items=True (same as parent) -> should be flattened
924+
# Child 2 has map_items=False (different from parent) -> should NOT be flattened
925+
parent = mt.Compose(
926+
[
927+
noop,
928+
mt.Compose([noop, noop], map_items=True), # Same as parent, will be flattened
929+
mt.Compose([noop, noop], map_items=False), # Different from parent, will be preserved
930+
noop,
931+
]
932+
)
933+
flat = parent.flatten()
934+
# After flatten:
935+
# - noop (preserved)
936+
# - 2 noops from first Compose (flattened because map_items=True matches parent)
937+
# - Compose([noop, noop], map_items=False) (preserved because different)
938+
# - noop (preserved)
939+
# Total: 5 transforms
940+
self.assertEqual(len(flat.transforms), 5)
941+
self.assertIsInstance(flat.transforms[3], mt.Compose)
942+
self.assertEqual(flat.transforms[3].map_items, False)
943+
827944

828945
class TestComposeCallableInput(unittest.TestCase):
829946

0 commit comments

Comments
 (0)