Skip to content

Commit bf8e78e

Browse files
committed
test: add nested Compose map_items coverage and preserve flatten() attrs
Expand test coverage for nested Compose transforms with different map_items configurations, and forward map_items, unpack_items, log_stats, lazy, and overrides through Compose.flatten() so the flattened pipeline is equivalent to the original (mirroring OneOf.flatten behavior). Signed-off-by: Soumya Snigdha Kundu <soumya_snigdha.kundu@kcl.ac.uk>
1 parent 593657f commit bf8e78e

2 files changed

Lines changed: 116 additions & 1 deletion

File tree

monai/transforms/compose.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,14 @@ def flatten(self):
358358
else:
359359
new_transforms.append(t)
360360

361-
return Compose(new_transforms)
361+
return Compose(
362+
new_transforms,
363+
map_items=self.map_items,
364+
unpack_items=self.unpack_items,
365+
log_stats=self.log_stats,
366+
lazy=self._lazy,
367+
overrides=self.overrides,
368+
)
362369

363370
def __len__(self):
364371
"""Return number of transformations."""

tests/transforms/compose/test_compose.py

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,6 +824,114 @@ def noop(x):
824824
self.assertEqual(len(flat.transforms), 3)
825825
self.assertIsInstance(flat.transforms[1], mt.Compose)
826826

827+
def test_multiple_children_with_mixed_map_items(self):
828+
"""Multiple internal Composes with different map_items should be handled correctly."""
829+
830+
def add_one(items):
831+
if isinstance(items, list):
832+
return [x + 1 for x in items]
833+
return items + 1
834+
835+
def multiply_two(items):
836+
if isinstance(items, list):
837+
return [x * 2 for x in items]
838+
return items * 2
839+
840+
# Parent with map_items=False processes the entire input as one unit
841+
# Child 1 (map_items=True) will map over each item in what it receives
842+
# Child 2 (map_items=False) will process the entire thing
843+
pipeline = mt.Compose([
844+
mt.Compose([add_one], map_items=True),
845+
mt.Compose([multiply_two], map_items=False),
846+
], map_items=False)
847+
848+
# Input [1, 2, 3]
849+
# First child with map_items=True maps add_one over [1,2,3]: [2, 3, 4]
850+
# Second child with map_items=False receives [2,3,4] and applies multiply_two: [4, 6, 8]
851+
result = pipeline([1, 2, 3])
852+
self.assertEqual(result, [4, 6, 8])
853+
854+
def test_flatten_with_multiple_children_preserves_both(self):
855+
"""flatten() should preserve child with different map_items but flatten child with same."""
856+
857+
def noop(x):
858+
return x
859+
860+
parent = mt.Compose([
861+
noop,
862+
mt.Compose([noop, noop], map_items=True), # Same as parent, will be flattened
863+
mt.Compose([noop, noop], map_items=False), # Different, will be preserved
864+
noop,
865+
])
866+
flat = parent.flatten()
867+
# First nested Compose(map_items=True) will be flattened into parent
868+
# Second nested Compose(map_items=False) will be preserved
869+
# Result: noop + noop + noop + Compose([noop, noop]) + noop = 5 transforms
870+
self.assertEqual(len(flat.transforms), 5)
871+
# Check that the preserved one is at the correct position
872+
self.assertIsInstance(flat.transforms[3], mt.Compose)
873+
self.assertEqual(flat.transforms[3].map_items, False)
874+
875+
def test_three_level_nesting_respects_different_map_items(self):
876+
"""Three-level nesting with different map_items at each level."""
877+
878+
def add_one(x):
879+
return x + 1
880+
881+
# Level 1 (outermost): map_items=True (default)
882+
# Level 2: map_items=False
883+
# Level 3: map_items=True (same as level 2, so will be flattened into level 2)
884+
innermost = mt.Compose([add_one], map_items=True)
885+
middle = mt.Compose([add_one, innermost], map_items=False)
886+
outer = mt.Compose([middle])
887+
888+
# Test with a simple value
889+
# outer has map_items=True (default), middle has map_items=False
890+
# So middle should be preserved and receive the input as-is
891+
result = outer(5)
892+
# outer(5) -> maps to middle -> middle(5) with map_items=False
893+
# middle(5) -> add_one(5) = 6, then innermost(6) with map_items=True
894+
# innermost(6) -> add_one(6) = 7
895+
self.assertEqual(result, 7)
896+
897+
def test_inverse_with_multiple_children_different_map_items(self):
898+
"""Inverse should work correctly with multiple children having different map_items."""
899+
pipeline = mt.Compose([
900+
mt.Flip(0),
901+
mt.Compose([mt.Flip(1)], map_items=False),
902+
mt.Compose([mt.Flip(0)], map_items=True),
903+
])
904+
data = torch.randn(2, 4, 4)
905+
result = pipeline(data)
906+
restored = pipeline.inverse(result)
907+
torch.testing.assert_close(data, restored)
908+
909+
def test_flatten_with_mixed_same_and_different_map_items(self):
910+
"""flatten() should merge children with same map_items but preserve those with different."""
911+
912+
def noop(x):
913+
return x
914+
915+
# Parent has map_items=True (default)
916+
# Child 1 has map_items=True (same as parent) -> should be flattened
917+
# Child 2 has map_items=False (different from parent) -> should NOT be flattened
918+
parent = mt.Compose([
919+
noop,
920+
mt.Compose([noop, noop], map_items=True), # Same as parent, will be flattened
921+
mt.Compose([noop, noop], map_items=False), # Different from parent, will be preserved
922+
noop,
923+
])
924+
flat = parent.flatten()
925+
# After flatten:
926+
# - noop (preserved)
927+
# - 2 noops from first Compose (flattened because map_items=True matches parent)
928+
# - Compose([noop, noop], map_items=False) (preserved because different)
929+
# - noop (preserved)
930+
# Total: 5 transforms
931+
self.assertEqual(len(flat.transforms), 5)
932+
self.assertIsInstance(flat.transforms[3], mt.Compose)
933+
self.assertEqual(flat.transforms[3].map_items, False)
934+
827935

828936
class TestComposeCallableInput(unittest.TestCase):
829937

0 commit comments

Comments
 (0)