3030POINT_3D_WORLD = torch .tensor ([[[2 , 4 , 6 ], [8 , 10 , 12 ]], [[14 , 16 , 18 ], [20 , 22 , 24 ]]])
3131POINT_3D_IMAGE = torch .tensor ([[[- 8 , 8 , 6 ], [- 2 , 14 , 12 ]], [[4 , 20 , 18 ], [10 , 26 , 24 ]]])
3232POINT_3D_IMAGE_RAS = torch .tensor ([[[- 12 , 0 , 6 ], [- 18 , - 6 , 12 ]], [[- 24 , - 12 , 18 ], [- 30 , - 18 , 24 ]]])
33+ AFFINE_1 = torch .tensor ([[2 , 0 , 0 , 0 ], [0 , 2 , 0 , 0 ], [0 , 0 , 1 , 0 ], [0 , 0 , 0 , 1 ]])
34+ AFFINE_2 = torch .tensor ([[1 , 0 , 0 , 10 ], [0 , 1 , 0 , - 4 ], [0 , 0 , 1 , 0 ], [0 , 0 , 0 , 1 ]])
3335
3436TEST_CASES = [
37+ [MetaTensor (DATA_2D , affine = AFFINE_1 ), POINT_2D_WORLD , None , True , False , POINT_2D_IMAGE ], # use image affine
38+ [None , MetaTensor (POINT_2D_IMAGE , affine = AFFINE_1 ), None , False , False , POINT_2D_WORLD ], # use point affine
39+ [None , MetaTensor (POINT_2D_IMAGE , affine = AFFINE_1 ), AFFINE_1 , False , False , POINT_2D_WORLD ], # use input affine
40+ [None , POINT_2D_WORLD , AFFINE_1 , True , False , POINT_2D_IMAGE ], # use input affine
3541 [
36- MetaTensor (DATA_2D , affine = torch . tensor ([[ 2 , 0 , 0 , 0 ], [ 0 , 2 , 0 , 0 ], [ 0 , 0 , 1 , 0 ], [ 0 , 0 , 0 , 1 ]]) ),
42+ MetaTensor (DATA_2D , affine = AFFINE_1 ),
3743 POINT_2D_WORLD ,
3844 None ,
3945 True ,
40- False ,
41- POINT_2D_IMAGE ,
42- ],
46+ True ,
47+ POINT_2D_IMAGE_RAS ,
48+ ], # test affine_lps_to_ras
49+ [MetaTensor (DATA_3D , affine = AFFINE_2 ), POINT_3D_WORLD , None , True , False , POINT_3D_IMAGE ],
50+ ["affine" , POINT_3D_WORLD , None , True , False , POINT_3D_IMAGE ], # use refer_data itself
4351 [
44- None ,
45- MetaTensor (POINT_2D_IMAGE , affine = torch . tensor ([[ 2 , 0 , 0 , 0 ], [ 0 , 2 , 0 , 0 ], [ 0 , 0 , 1 , 0 ], [ 0 , 0 , 0 , 1 ]]) ),
52+ MetaTensor ( DATA_3D , affine = AFFINE_2 ) ,
53+ MetaTensor (POINT_3D_IMAGE , affine = AFFINE_2 ),
4654 None ,
4755 False ,
4856 False ,
49- POINT_2D_WORLD ,
57+ POINT_3D_WORLD ,
5058 ],
59+ [MetaTensor (DATA_3D , affine = AFFINE_2 ), POINT_3D_WORLD , None , True , True , POINT_3D_IMAGE_RAS ],
60+ [MetaTensor (DATA_3D , affine = AFFINE_2 ), POINT_3D_WORLD , None , True , True , POINT_3D_IMAGE_RAS ],
61+ ]
62+ TEST_CASES_SEQUENCE = [
5163 [
64+ (MetaTensor (DATA_2D , affine = AFFINE_1 ), MetaTensor (DATA_3D , affine = AFFINE_2 )),
65+ [POINT_2D_WORLD , POINT_3D_WORLD ],
5266 None ,
53- MetaTensor (POINT_2D_IMAGE , affine = torch .tensor ([[2 , 0 , 0 , 0 ], [0 , 2 , 0 , 0 ], [0 , 0 , 1 , 0 ], [0 , 0 , 0 , 1 ]])),
54- torch .tensor ([[2 , 0 , 0 , 0 ], [0 , 2 , 0 , 0 ], [0 , 0 , 1 , 0 ], [0 , 0 , 0 , 1 ]]),
55- False ,
67+ True ,
5668 False ,
57- POINT_2D_WORLD ,
58- ],
69+ ["image_1" , "image_2" ],
70+ [POINT_2D_IMAGE , POINT_3D_IMAGE ],
71+ ], # use image affine
5972 [
60- MetaTensor (DATA_2D , affine = torch . tensor ([[ 2 , 0 , 0 , 0 ], [ 0 , 2 , 0 , 0 ], [ 0 , 0 , 1 , 0 ], [ 0 , 0 , 0 , 1 ]] )),
61- POINT_2D_WORLD ,
73+ ( MetaTensor (DATA_2D , affine = AFFINE_1 ), MetaTensor ( DATA_3D , affine = AFFINE_2 )),
74+ [ POINT_2D_WORLD , POINT_3D_WORLD ] ,
6275 None ,
6376 True ,
6477 True ,
65- POINT_2D_IMAGE_RAS ,
66- ],
78+ ["image_1" , "image_2" ],
79+ [POINT_2D_IMAGE_RAS , POINT_3D_IMAGE_RAS ],
80+ ], # test affine_lps_to_ras
6781 [
68- MetaTensor ( DATA_3D , affine = torch . tensor ([[ 1 , 0 , 0 , 10 ], [ 0 , 1 , 0 , - 4 ], [ 0 , 0 , 1 , 0 ], [ 0 , 0 , 0 , 1 ]]) ),
69- POINT_3D_WORLD ,
82+ ( None , None ),
83+ [ MetaTensor ( POINT_2D_IMAGE , affine = AFFINE_1 ), MetaTensor ( POINT_3D_IMAGE , affine = AFFINE_2 )] ,
7084 None ,
85+ False ,
86+ False ,
87+ None ,
88+ [POINT_2D_WORLD , POINT_3D_WORLD ],
89+ ], # use point affine
90+ [
91+ (None , None ),
92+ [POINT_2D_WORLD , POINT_2D_WORLD ],
93+ AFFINE_1 ,
7194 True ,
7295 False ,
73- POINT_3D_IMAGE ,
74- ],
75- [ "affine" , POINT_3D_WORLD , None , True , False , POINT_3D_IMAGE ],
96+ None ,
97+ [ POINT_2D_IMAGE , POINT_2D_IMAGE ],
98+ ], # use input affine
7699 [
77- MetaTensor (DATA_3D , affine = torch . tensor ([[ 1 , 0 , 0 , 10 ], [ 0 , 1 , 0 , - 4 ], [ 0 , 0 , 1 , 0 ], [ 0 , 0 , 0 , 1 ]] )),
78- MetaTensor (POINT_3D_IMAGE , affine = torch . tensor ([[ 1 , 0 , 0 , 10 ], [ 0 , 1 , 0 , - 4 ], [ 0 , 0 , 1 , 0 ], [ 0 , 0 , 0 , 1 ]])) ,
100+ ( MetaTensor (DATA_2D , affine = AFFINE_1 ), MetaTensor ( DATA_3D , affine = AFFINE_2 )),
101+ [ MetaTensor (POINT_2D_IMAGE , affine = AFFINE_1 ), MetaTensor ( POINT_3D_IMAGE , affine = AFFINE_2 )] ,
79102 None ,
80103 False ,
81104 False ,
82- POINT_3D_WORLD ,
83- ],
84- [
85- MetaTensor (DATA_3D , affine = torch .tensor ([[1 , 0 , 0 , 10 ], [0 , 1 , 0 , - 4 ], [0 , 0 , 1 , 0 ], [0 , 0 , 0 , 1 ]])),
86- POINT_3D_WORLD ,
87- None ,
88- True ,
89- True ,
90- POINT_3D_IMAGE_RAS ,
105+ ["image_1" , "image_2" ],
106+ [POINT_2D_WORLD , POINT_3D_WORLD ],
91107 ],
92108]
93109
94110TEST_CASES_WRONG = [
95- [POINT_2D_WORLD , True , None ],
96- [POINT_2D_WORLD .unsqueeze (0 ), False , None ],
97- [POINT_3D_WORLD [..., 0 :1 ], False , None ],
98- [POINT_3D_WORLD , False , torch .tensor ([[[1 , 0 , 0 , 10 ], [0 , 1 , 0 , - 4 ], [0 , 0 , 1 , 0 ], [0 , 0 , 0 , 1 ]]])],
111+ [POINT_2D_WORLD , True , None , None ],
112+ [POINT_2D_WORLD .unsqueeze (0 ), False , None , None ],
113+ [POINT_3D_WORLD [..., 0 :1 ], False , None , None ],
114+ [POINT_3D_WORLD , False , torch .tensor ([[[1 , 0 , 0 , 10 ], [0 , 1 , 0 , - 4 ], [0 , 0 , 1 , 0 ], [0 , 0 , 0 , 1 ]]]), None ],
115+ [POINT_3D_WORLD , False , None , "image" ],
116+ [POINT_3D_WORLD , False , None , []],
99117]
100118
101119
@@ -107,10 +125,10 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
107125 "point" : points ,
108126 "affine" : torch .tensor ([[1 , 0 , 0 , 10 ], [0 , 1 , 0 , - 4 ], [0 , 0 , 1 , 0 ], [0 , 0 , 0 , 1 ]]),
109127 }
110- refer_key = "image" if (image is not None and image != "affine" ) else image
128+ refer_keys = "image" if (image is not None and image != "affine" ) else image
111129 transform = ApplyTransformToPointsd (
112130 keys = "point" ,
113- refer_key = refer_key ,
131+ refer_keys = refer_keys ,
114132 dtype = torch .int64 ,
115133 affine = affine ,
116134 invert_affine = invert_affine ,
@@ -122,11 +140,45 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
122140 invert_out = transform .inverse (output )
123141 self .assertTrue (torch .allclose (invert_out ["point" ], points ))
124142
143+ @parameterized .expand (TEST_CASES_SEQUENCE )
144+ def test_transform_coordinates_sequences (
145+ self , image , points , affine , invert_affine , affine_lps_to_ras , refer_keys , expected_output
146+ ):
147+ data = {"image_1" : image [0 ], "image_2" : image [1 ], "point_1" : points [0 ], "point_2" : points [1 ]}
148+ keys = ["point_1" , "point_2" ]
149+ transform = ApplyTransformToPointsd (
150+ keys = keys ,
151+ refer_keys = refer_keys ,
152+ dtype = torch .int64 ,
153+ affine = affine ,
154+ invert_affine = invert_affine ,
155+ affine_lps_to_ras = affine_lps_to_ras ,
156+ )
157+ output = transform (data )
158+
159+ self .assertTrue (torch .allclose (output ["point_1" ], expected_output [0 ]))
160+ self .assertTrue (torch .allclose (output ["point_2" ], expected_output [1 ]))
161+ invert_out = transform .inverse (output )
162+ self .assertTrue (torch .allclose (invert_out ["point_1" ], points [0 ]))
163+
125164 @parameterized .expand (TEST_CASES_WRONG )
126- def test_wrong_input (self , input , invert_affine , affine ):
127- transform = ApplyTransformToPointsd (keys = "point" , dtype = torch .int64 , invert_affine = invert_affine , affine = affine )
128- with self .assertRaises (ValueError ):
129- transform ({"point" : input })
165+ def test_wrong_input (self , input , invert_affine , affine , refer_keys ):
166+ if refer_keys == []:
167+ with self .assertRaises (ValueError ):
168+ ApplyTransformToPointsd (
169+ keys = "point" , dtype = torch .int64 , invert_affine = invert_affine , affine = affine , refer_keys = refer_keys
170+ )
171+ else :
172+ transform = ApplyTransformToPointsd (
173+ keys = "point" , dtype = torch .int64 , invert_affine = invert_affine , affine = affine , refer_keys = refer_keys
174+ )
175+ data = {"point" : input }
176+ if refer_keys == "image" :
177+ with self .assertRaises (KeyError ):
178+ transform (data )
179+ else :
180+ with self .assertRaises (ValueError ):
181+ transform (data )
130182
131183
132184if __name__ == "__main__" :
0 commit comments