Skip to content

Commit bf94b6f

Browse files
authored
Merge branch 'dev' into fix/gaussian-kernel-truncated-8780
2 parents 3289fe6 + 24f4924 commit bf94b6f

1 file changed

Lines changed: 35 additions & 2 deletions

File tree

tests/data/utils/test_compute_shape_offset.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from monai.data.utils import compute_shape_offset
2020

2121

22-
class TestComputeShapeOffsetRegression(unittest.TestCase):
23-
"""Regression tests for `compute_shape_offset` input-shape handling."""
22+
class TestComputeShapeOffset(unittest.TestCase):
23+
"""Unit tests for :func:`monai.data.utils.compute_shape_offset`."""
2424

2525
def test_pytorch_size_input(self):
2626
"""Validate `torch.Size` input produces expected shape and offset.
@@ -42,6 +42,39 @@ def test_pytorch_size_input(self):
4242
# 3. Prove it successfully processed the shape by checking its length
4343
self.assertEqual(len(shape), 3)
4444

45+
def setUp(self):
46+
"""Set up a 4x4 identity affine used across all test cases."""
47+
self.affine = np.eye(4)
48+
49+
def test_numpy_array_input(self):
50+
"""Verify compute_shape_offset accepts a numpy array as spatial_shape."""
51+
shape = np.array([64, 64, 64])
52+
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
53+
self.assertEqual(len(out_shape), 3)
54+
55+
def test_list_input(self):
56+
"""Verify compute_shape_offset accepts a plain list as spatial_shape."""
57+
shape = [64, 64, 64]
58+
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
59+
self.assertEqual(len(out_shape), 3)
60+
61+
def test_torch_tensor_input(self):
62+
"""Verify compute_shape_offset accepts a torch.Tensor as spatial_shape.
63+
64+
This path broke in PyTorch >= 2.9 because np.array() relied on the
65+
non-tuple sequence indexing protocol that PyTorch removed. Wrapping with
66+
tuple() fixes it.
67+
"""
68+
shape = torch.tensor([64, 64, 64])
69+
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
70+
self.assertEqual(len(out_shape), 3)
71+
72+
def test_identity_affines_preserve_shape(self):
73+
"""Verify that identity in/out affines produce an output shape matching the input."""
74+
shape = torch.tensor([32, 48, 16])
75+
out_shape, _ = compute_shape_offset(shape, self.affine, self.affine)
76+
np.testing.assert_allclose(np.array(out_shape, dtype=float), shape.numpy().astype(float), atol=1e-5)
77+
4578

4679
if __name__ == "__main__":
4780
unittest.main()

0 commit comments

Comments
 (0)