Skip to content

Commit 4e70bf6

Browse files
KumoLiu, ericspod, pre-commit-ci[bot]
authored
Allow ApplyTransformToPointsd to receive a sequence of refer keys (#8063)
Enhance `ApplyTransformToPointsd` to receive a sequence of refer keys. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent aea46ff commit 4e70bf6

3 files changed

Lines changed: 146 additions & 82 deletions

File tree

monai/transforms/utility/array.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,6 +1764,30 @@ def __init__(
17641764
self.invert_affine = invert_affine
17651765
self.affine_lps_to_ras = affine_lps_to_ras
17661766

1767+
def _compute_final_affine(self, affine: torch.Tensor, applied_affine: torch.Tensor | None = None) -> torch.Tensor:
1768+
"""
1769+
Compute the final affine transformation matrix to apply to the point data.
1770+
1771+
Args:
1772+
applied_affine: affine transformation matrix that has already been applied to the point data, if any.
1773+
affine: 3x3 or 4x4 affine transformation matrix.
1774+
1775+
Returns:
1776+
Final affine transformation matrix.
1777+
"""
1778+
1779+
affine = convert_data_type(affine, dtype=torch.float64)[0]
1780+
1781+
if self.affine_lps_to_ras:
1782+
affine = orientation_ras_lps(affine)
1783+
1784+
if self.invert_affine:
1785+
affine = linalg_inv(affine)
1786+
if applied_affine is not None:
1787+
affine = affine @ applied_affine
1788+
1789+
return affine
1790+
17671791
def transform_coordinates(
17681792
self, data: torch.Tensor, affine: torch.Tensor | None = None
17691793
) -> tuple[torch.Tensor, dict]:
@@ -1780,35 +1804,25 @@ def transform_coordinates(
17801804
Transformed coordinates.
17811805
"""
17821806
data = convert_to_tensor(data, track_meta=get_track_meta())
1783-
# applied_affine is the affine transformation matrix that has already been applied to the point data
1784-
applied_affine = getattr(data, "affine", None)
1785-
17861807
if affine is None and self.invert_affine:
17871808
raise ValueError("affine must be provided when invert_affine is True.")
1788-
1809+
# applied_affine is the affine transformation matrix that has already been applied to the point data
1810+
applied_affine: torch.Tensor | None = getattr(data, "affine", None)
17891811
affine = applied_affine if affine is None else affine
1790-
affine = convert_data_type(affine, dtype=torch.float64)[0] # always convert to float64 for affine
1791-
original_affine: torch.Tensor = affine
1792-
if self.affine_lps_to_ras:
1793-
affine = orientation_ras_lps(affine)
1812+
if affine is None:
1813+
raise ValueError("affine must be provided if data does not have an affine matrix.")
17941814

1795-
# the final affine transformation matrix that will be applied to the point data
1796-
_affine: torch.Tensor = affine
1797-
if self.invert_affine:
1798-
_affine = linalg_inv(affine)
1799-
if applied_affine is not None:
1800-
# consider the affine transformation already applied to the data in the world space
1801-
# and compute delta affine
1802-
_affine = _affine @ linalg_inv(applied_affine)
1803-
out = apply_affine_to_points(data, _affine, dtype=self.dtype)
1815+
final_affine = self._compute_final_affine(affine, applied_affine)
1816+
out = apply_affine_to_points(data, final_affine, dtype=self.dtype)
18041817

18051818
extra_info = {
18061819
"invert_affine": self.invert_affine,
18071820
"dtype": get_dtype_string(self.dtype),
1808-
"image_affine": original_affine, # record for inverse operation
1821+
"image_affine": affine,
18091822
"affine_lps_to_ras": self.affine_lps_to_ras,
18101823
}
1811-
xform: torch.Tensor = original_affine if self.invert_affine else linalg_inv(original_affine)
1824+
1825+
xform = orientation_ras_lps(linalg_inv(final_affine)) if self.affine_lps_to_ras else linalg_inv(final_affine)
18121826
meta_info = TraceableTransform.track_transform_meta(
18131827
data, affine=xform, extra_info=extra_info, transform_info=self.get_transform_info()
18141828
)
@@ -1834,16 +1848,12 @@ def __call__(self, data: torch.Tensor, affine: torch.Tensor | None = None):
18341848

18351849
def inverse(self, data: torch.Tensor) -> torch.Tensor:
18361850
transform = self.pop_transform(data)
1837-
# Create inverse transform
1838-
dtype = transform[TraceKeys.EXTRA_INFO]["dtype"]
1839-
invert_affine = not transform[TraceKeys.EXTRA_INFO]["invert_affine"]
1840-
affine = transform[TraceKeys.EXTRA_INFO]["image_affine"]
1841-
affine_lps_to_ras = transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"]
18421851
inverse_transform = ApplyTransformToPoints(
1843-
dtype=dtype, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
1852+
dtype=transform[TraceKeys.EXTRA_INFO]["dtype"],
1853+
invert_affine=not transform[TraceKeys.EXTRA_INFO]["invert_affine"],
1854+
affine_lps_to_ras=transform[TraceKeys.EXTRA_INFO]["affine_lps_to_ras"],
18441855
)
1845-
# Apply inverse
18461856
with inverse_transform.trace_transform(False):
1847-
data = inverse_transform(data, affine)
1857+
data = inverse_transform(data, transform[TraceKeys.EXTRA_INFO]["image_affine"])
18481858

18491859
return data

monai/transforms/utility/dictionary.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,8 +1758,9 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
17581758
Args:
17591759
keys: keys of the corresponding items to be transformed.
17601760
See also: monai.transforms.MapTransform
1761-
refer_key: The key of the reference item used for transformation.
1762-
It can directly refer to an affine or an image from which the affine can be derived.
1761+
refer_keys: The key of the reference item used for transformation.
1762+
It can directly refer to an affine or an image from which the affine can be derived. It can also be a
1763+
sequence of keys, in which case each refers to the affine applied to the matching points in `keys`.
17631764
dtype: The desired data type for the output.
17641765
affine: A 3x3 or 4x4 affine transformation matrix applied to points. This matrix typically originates
17651766
from the image. For 2D points, a 3x3 matrix can be provided, avoiding the need to add an unnecessary
@@ -1782,31 +1783,32 @@ class ApplyTransformToPointsd(MapTransform, InvertibleTransform):
17821783
def __init__(
17831784
self,
17841785
keys: KeysCollection,
1785-
refer_key: str | None = None,
1786+
refer_keys: KeysCollection | None = None,
17861787
dtype: DtypeLike | torch.dtype = torch.float64,
17871788
affine: torch.Tensor | None = None,
17881789
invert_affine: bool = True,
17891790
affine_lps_to_ras: bool = False,
17901791
allow_missing_keys: bool = False,
17911792
):
17921793
MapTransform.__init__(self, keys, allow_missing_keys)
1793-
self.refer_key = refer_key
1794+
self.refer_keys = ensure_tuple_rep(refer_keys, len(self.keys))
17941795
self.converter = ApplyTransformToPoints(
17951796
dtype=dtype, affine=affine, invert_affine=invert_affine, affine_lps_to_ras=affine_lps_to_ras
17961797
)
17971798

17981799
def __call__(self, data: Mapping[Hashable, torch.Tensor]):
17991800
d = dict(data)
1800-
if self.refer_key is not None:
1801-
if self.refer_key in d:
1802-
refer_data = d[self.refer_key]
1803-
else:
1804-
raise KeyError(f"The refer_key '{self.refer_key}' is not found in the data.")
1805-
else:
1806-
refer_data = None
1807-
affine = getattr(refer_data, "affine", refer_data)
1808-
for key in self.key_iterator(d):
1801+
for key, refer_key in self.key_iterator(d, self.refer_keys):
18091802
coords = d[key]
1803+
affine = None # represents using affine given in constructor
1804+
if refer_key is not None:
1805+
if refer_key in d:
1806+
refer_data = d[refer_key]
1807+
else:
1808+
raise KeyError(f"The refer_key '{refer_key}' is not found in the data.")
1809+
1810+
# use the "affine" member of refer_data, or refer_data itself, as the affine matrix
1811+
affine = getattr(refer_data, "affine", refer_data)
18101812
d[key] = self.converter(coords, affine)
18111813
return d
18121814

tests/test_apply_transform_to_pointsd.py

Lines changed: 94 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -30,72 +30,90 @@
3030
POINT_3D_WORLD = torch.tensor([[[2, 4, 6], [8, 10, 12]], [[14, 16, 18], [20, 22, 24]]])
3131
POINT_3D_IMAGE = torch.tensor([[[-8, 8, 6], [-2, 14, 12]], [[4, 20, 18], [10, 26, 24]]])
3232
POINT_3D_IMAGE_RAS = torch.tensor([[[-12, 0, 6], [-18, -6, 12]], [[-24, -12, 18], [-30, -18, 24]]])
33+
AFFINE_1 = torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])
34+
AFFINE_2 = torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])
3335

3436
TEST_CASES = [
37+
[MetaTensor(DATA_2D, affine=AFFINE_1), POINT_2D_WORLD, None, True, False, POINT_2D_IMAGE], # use image affine
38+
[None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), None, False, False, POINT_2D_WORLD], # use point affine
39+
[None, MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), AFFINE_1, False, False, POINT_2D_WORLD], # use input affine
40+
[None, POINT_2D_WORLD, AFFINE_1, True, False, POINT_2D_IMAGE], # use input affine
3541
[
36-
MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
42+
MetaTensor(DATA_2D, affine=AFFINE_1),
3743
POINT_2D_WORLD,
3844
None,
3945
True,
40-
False,
41-
POINT_2D_IMAGE,
42-
],
46+
True,
47+
POINT_2D_IMAGE_RAS,
48+
], # test affine_lps_to_ras
49+
[MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],
50+
["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE], # use refer_data itself
4351
[
44-
None,
45-
MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
52+
MetaTensor(DATA_3D, affine=AFFINE_2),
53+
MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2),
4654
None,
4755
False,
4856
False,
49-
POINT_2D_WORLD,
57+
POINT_3D_WORLD,
5058
],
59+
[MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],
60+
[MetaTensor(DATA_3D, affine=AFFINE_2), POINT_3D_WORLD, None, True, True, POINT_3D_IMAGE_RAS],
61+
]
62+
TEST_CASES_SEQUENCE = [
5163
[
64+
(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
65+
[POINT_2D_WORLD, POINT_3D_WORLD],
5266
None,
53-
MetaTensor(POINT_2D_IMAGE, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
54-
torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]),
55-
False,
67+
True,
5668
False,
57-
POINT_2D_WORLD,
58-
],
69+
["image_1", "image_2"],
70+
[POINT_2D_IMAGE, POINT_3D_IMAGE],
71+
], # use image affine
5972
[
60-
MetaTensor(DATA_2D, affine=torch.tensor([[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])),
61-
POINT_2D_WORLD,
73+
(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
74+
[POINT_2D_WORLD, POINT_3D_WORLD],
6275
None,
6376
True,
6477
True,
65-
POINT_2D_IMAGE_RAS,
66-
],
78+
["image_1", "image_2"],
79+
[POINT_2D_IMAGE_RAS, POINT_3D_IMAGE_RAS],
80+
], # test affine_lps_to_ras
6781
[
68-
MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
69-
POINT_3D_WORLD,
82+
(None, None),
83+
[MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],
7084
None,
85+
False,
86+
False,
87+
None,
88+
[POINT_2D_WORLD, POINT_3D_WORLD],
89+
], # use point affine
90+
[
91+
(None, None),
92+
[POINT_2D_WORLD, POINT_2D_WORLD],
93+
AFFINE_1,
7194
True,
7295
False,
73-
POINT_3D_IMAGE,
74-
],
75-
["affine", POINT_3D_WORLD, None, True, False, POINT_3D_IMAGE],
96+
None,
97+
[POINT_2D_IMAGE, POINT_2D_IMAGE],
98+
], # use input affine
7699
[
77-
MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
78-
MetaTensor(POINT_3D_IMAGE, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
100+
(MetaTensor(DATA_2D, affine=AFFINE_1), MetaTensor(DATA_3D, affine=AFFINE_2)),
101+
[MetaTensor(POINT_2D_IMAGE, affine=AFFINE_1), MetaTensor(POINT_3D_IMAGE, affine=AFFINE_2)],
79102
None,
80103
False,
81104
False,
82-
POINT_3D_WORLD,
83-
],
84-
[
85-
MetaTensor(DATA_3D, affine=torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]])),
86-
POINT_3D_WORLD,
87-
None,
88-
True,
89-
True,
90-
POINT_3D_IMAGE_RAS,
105+
["image_1", "image_2"],
106+
[POINT_2D_WORLD, POINT_3D_WORLD],
91107
],
92108
]
93109

94110
TEST_CASES_WRONG = [
95-
[POINT_2D_WORLD, True, None],
96-
[POINT_2D_WORLD.unsqueeze(0), False, None],
97-
[POINT_3D_WORLD[..., 0:1], False, None],
98-
[POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]])],
111+
[POINT_2D_WORLD, True, None, None],
112+
[POINT_2D_WORLD.unsqueeze(0), False, None, None],
113+
[POINT_3D_WORLD[..., 0:1], False, None, None],
114+
[POINT_3D_WORLD, False, torch.tensor([[[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]]), None],
115+
[POINT_3D_WORLD, False, None, "image"],
116+
[POINT_3D_WORLD, False, None, []],
99117
]
100118

101119

@@ -107,10 +125,10 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
107125
"point": points,
108126
"affine": torch.tensor([[1, 0, 0, 10], [0, 1, 0, -4], [0, 0, 1, 0], [0, 0, 0, 1]]),
109127
}
110-
refer_key = "image" if (image is not None and image != "affine") else image
128+
refer_keys = "image" if (image is not None and image != "affine") else image
111129
transform = ApplyTransformToPointsd(
112130
keys="point",
113-
refer_key=refer_key,
131+
refer_keys=refer_keys,
114132
dtype=torch.int64,
115133
affine=affine,
116134
invert_affine=invert_affine,
@@ -122,11 +140,45 @@ def test_transform_coordinates(self, image, points, affine, invert_affine, affin
122140
invert_out = transform.inverse(output)
123141
self.assertTrue(torch.allclose(invert_out["point"], points))
124142

143+
@parameterized.expand(TEST_CASES_SEQUENCE)
144+
def test_transform_coordinates_sequences(
145+
self, image, points, affine, invert_affine, affine_lps_to_ras, refer_keys, expected_output
146+
):
147+
data = {"image_1": image[0], "image_2": image[1], "point_1": points[0], "point_2": points[1]}
148+
keys = ["point_1", "point_2"]
149+
transform = ApplyTransformToPointsd(
150+
keys=keys,
151+
refer_keys=refer_keys,
152+
dtype=torch.int64,
153+
affine=affine,
154+
invert_affine=invert_affine,
155+
affine_lps_to_ras=affine_lps_to_ras,
156+
)
157+
output = transform(data)
158+
159+
self.assertTrue(torch.allclose(output["point_1"], expected_output[0]))
160+
self.assertTrue(torch.allclose(output["point_2"], expected_output[1]))
161+
invert_out = transform.inverse(output)
162+
self.assertTrue(torch.allclose(invert_out["point_1"], points[0]))
163+
125164
@parameterized.expand(TEST_CASES_WRONG)
126-
def test_wrong_input(self, input, invert_affine, affine):
127-
transform = ApplyTransformToPointsd(keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine)
128-
with self.assertRaises(ValueError):
129-
transform({"point": input})
165+
def test_wrong_input(self, input, invert_affine, affine, refer_keys):
166+
if refer_keys == []:
167+
with self.assertRaises(ValueError):
168+
ApplyTransformToPointsd(
169+
keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys
170+
)
171+
else:
172+
transform = ApplyTransformToPointsd(
173+
keys="point", dtype=torch.int64, invert_affine=invert_affine, affine=affine, refer_keys=refer_keys
174+
)
175+
data = {"point": input}
176+
if refer_keys == "image":
177+
with self.assertRaises(KeyError):
178+
transform(data)
179+
else:
180+
with self.assertRaises(ValueError):
181+
transform(data)
130182

131183

132184
if __name__ == "__main__":

0 commit comments

Comments
 (0)