1919from monai .data .utils import compute_shape_offset
2020
2121
22- class TestComputeShapeOffsetRegression (unittest .TestCase ):
23- """Regression tests for ` compute_shape_offset` input-shape handling ."""
22+ class TestComputeShapeOffset (unittest .TestCase ):
23+ """Unit tests for :func:`monai.data.utils. compute_shape_offset`."""
2424
2525 def test_pytorch_size_input (self ):
2626 """Validate `torch.Size` input produces expected shape and offset.
@@ -42,6 +42,39 @@ def test_pytorch_size_input(self):
4242 # 3. Prove it successfully processed the shape by checking its length
4343 self .assertEqual (len (shape ), 3 )
4444
45+ def setUp (self ):
46+ """Set up a 4x4 identity affine used across all test cases."""
47+ self .affine = np .eye (4 )
48+
49+ def test_numpy_array_input (self ):
50+ """Verify compute_shape_offset accepts a numpy array as spatial_shape."""
51+ shape = np .array ([64 , 64 , 64 ])
52+ out_shape , _ = compute_shape_offset (shape , self .affine , self .affine )
53+ self .assertEqual (len (out_shape ), 3 )
54+
55+ def test_list_input (self ):
56+ """Verify compute_shape_offset accepts a plain list as spatial_shape."""
57+ shape = [64 , 64 , 64 ]
58+ out_shape , _ = compute_shape_offset (shape , self .affine , self .affine )
59+ self .assertEqual (len (out_shape ), 3 )
60+
61+ def test_torch_tensor_input (self ):
62+ """Verify compute_shape_offset accepts a torch.Tensor as spatial_shape.
63+
64+ This path broke in PyTorch >= 2.9 because np.array() relied on the
65+ non-tuple sequence indexing protocol that PyTorch removed. Wrapping with
66+ tuple() fixes it.
67+ """
68+ shape = torch .tensor ([64 , 64 , 64 ])
69+ out_shape , _ = compute_shape_offset (shape , self .affine , self .affine )
70+ self .assertEqual (len (out_shape ), 3 )
71+
72+ def test_identity_affines_preserve_shape (self ):
73+ """Verify that identity in/out affines produce an output shape matching the input."""
74+ shape = torch .tensor ([32 , 48 , 16 ])
75+ out_shape , _ = compute_shape_offset (shape , self .affine , self .affine )
76+ np .testing .assert_allclose (np .array (out_shape , dtype = float ), shape .numpy ().astype (float ), atol = 1e-5 )
77+
4578
4679if __name__ == "__main__" :
4780 unittest .main ()
0 commit comments