Skip to content

Commit 84566d1

Browse files
authored
Implement distance_transform_edt and the DistanceTransformEDT transform (#6981)
Related to #6845, this commits adds an EDT distance transform to MONAI. Most importantly this enables GPU based distance transforms which lead to a huge speedup. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] In-line docstrings updated. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Matthias Hadlich <matthiashadlich@posteo.de>
1 parent a29ab04 commit 84566d1

7 files changed

Lines changed: 453 additions & 2 deletions

File tree

docs/source/transforms.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -602,6 +602,12 @@ Post-processing
602602
:members:
603603
:special-members: __call__
604604

605+
`DistanceTransformEDT`
606+
"""""""""""""""""""""""""""""""
607+
.. autoclass:: DistanceTransformEDT
608+
:members:
609+
:special-members: __call__
610+
605611
`RemoveSmallObjects`
606612
""""""""""""""""""""
607613
.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RemoveSmallObjects.png
@@ -1622,6 +1628,12 @@ Post-processing (Dict)
16221628
:members:
16231629
:special-members: __call__
16241630

1631+
`DistanceTransformEDTd`
1632+
""""""""""""""""""""""""""""""""
1633+
.. autoclass:: DistanceTransformEDTd
1634+
:members:
1635+
:special-members: __call__
1636+
16251637
`RemoveSmallObjectsd`
16261638
"""""""""""""""""""""
16271639
.. image:: https://raw.githubusercontent.com/Project-MONAI/DocImages/main/transforms/RemoveSmallObjectsd.png

monai/transforms/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@
277277
from .post.array import (
278278
Activations,
279279
AsDiscrete,
280+
DistanceTransformEDT,
280281
FillHoles,
281282
Invert,
282283
KeepLargestConnectedComponent,
@@ -295,6 +296,9 @@
295296
AsDiscreteD,
296297
AsDiscreted,
297298
AsDiscreteDict,
299+
DistanceTransformEDTd,
300+
DistanceTransformEDTD,
301+
DistanceTransformEDTDict,
298302
Ensembled,
299303
EnsembleD,
300304
EnsembleDict,

monai/transforms/post/array.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from monai.transforms.utility.array import ToTensor
3232
from monai.transforms.utils import (
3333
convert_applied_interp_mode,
34+
distance_transform_edt,
3435
fill_holes,
3536
get_largest_connected_component_mask,
3637
get_unique_labels,
@@ -53,6 +54,7 @@
5354
"SobelGradients",
5455
"VoteEnsemble",
5556
"Invert",
57+
"DistanceTransformEDT",
5658
]
5759

5860

@@ -936,3 +938,39 @@ def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:
936938
grads = convert_to_dst_type(grads.squeeze(0), image_tensor)[0]
937939

938940
return grads
941+
942+
943+
class DistanceTransformEDT(Transform):
944+
"""
945+
Applies the Euclidean distance transform on the input.
946+
Either GPU based with CuPy / cuCIM or CPU based with scipy.
947+
To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.
948+
949+
Note that the results of the libraries can differ, so stick to one if possible.
950+
For details, check out the `SciPy`_ and `cuCIM`_ documentation and / or :func:`monai.transforms.utils.distance_transform_edt`.
951+
952+
.. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html
953+
.. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt
954+
"""
955+
956+
backend = [TransformBackends.NUMPY, TransformBackends.CUPY]
957+
958+
def __init__(self, sampling: None | float | list[float] = None) -> None:
959+
super().__init__()
960+
self.sampling = sampling
961+
962+
def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
963+
"""
964+
Args:
965+
img: Input image on which the distance transform shall be run.
966+
Has to be a channel first array, must have shape: (num_channels, H, W [,D]).
967+
Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.
968+
Input gets passed channel-wise to the distance-transform, thus results from this function will differ
969+
from directly calling ``distance_transform_edt()`` in CuPy or SciPy.
970+
sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;
971+
if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.
972+
973+
Returns:
974+
An array with the same shape and data type as img
975+
"""
976+
return distance_transform_edt(img=img, sampling=self.sampling) # type: ignore

monai/transforms/post/dictionary.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from monai.transforms.post.array import (
3434
Activations,
3535
AsDiscrete,
36+
DistanceTransformEDT,
3637
FillHoles,
3738
KeepLargestConnectedComponent,
3839
LabelFilter,
@@ -91,6 +92,9 @@
9192
"VoteEnsembleD",
9293
"VoteEnsembleDict",
9394
"VoteEnsembled",
95+
"DistanceTransformEDTd",
96+
"DistanceTransformEDTD",
97+
"DistanceTransformEDTDict",
9498
]
9599

96100
DEFAULT_POST_FIX = PostFix.meta()
@@ -855,6 +859,51 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
855859
return d
856860

857861

862+
class DistanceTransformEDTd(MapTransform):
863+
"""
864+
Applies the Euclidean distance transform on the input.
865+
Either GPU based with CuPy / cuCIM or CPU based with scipy.
866+
To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.
867+
868+
Note that the results of the libraries can differ, so stick to one if possible.
869+
For details, check out the `SciPy`_ and `cuCIM`_ documentation and / or :func:`monai.transforms.utils.distance_transform_edt`.
870+
871+
872+
Note on the input shape:
873+
Has to be a channel first array, must have shape: (num_channels, H, W [,D]).
874+
Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.
875+
Input gets passed channel-wise to the distance-transform, thus results from this function will differ
876+
from directly calling ``distance_transform_edt()`` in CuPy or SciPy.
877+
878+
Args:
879+
keys: keys of the corresponding items to be transformed.
880+
allow_missing_keys: don't raise exception if key is missing.
881+
sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;
882+
if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.
883+
884+
.. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html
885+
.. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt
886+
887+
888+
"""
889+
890+
backend = DistanceTransformEDT.backend
891+
892+
def __init__(
893+
self, keys: KeysCollection, allow_missing_keys: bool = False, sampling: None | float | list[float] = None
894+
) -> None:
895+
super().__init__(keys, allow_missing_keys)
896+
self.sampling = sampling
897+
self.distance_transform = DistanceTransformEDT(sampling=self.sampling)
898+
899+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
900+
d = dict(data)
901+
for key in self.key_iterator(d):
902+
d[key] = self.distance_transform(img=d[key])
903+
904+
return d
905+
906+
858907
ActivationsD = ActivationsDict = Activationsd
859908
AsDiscreteD = AsDiscreteDict = AsDiscreted
860909
FillHolesD = FillHolesDict = FillHolesd
@@ -869,3 +918,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
869918
VoteEnsembleD = VoteEnsembleDict = VoteEnsembled
870919
EnsembleD = EnsembleDict = Ensembled
871920
SobelGradientsD = SobelGradientsDict = SobelGradientsd
921+
DistanceTransformEDTD = DistanceTransformEDTDict = DistanceTransformEDTd

monai/transforms/utils.py

Lines changed: 146 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,17 @@
6666
pytorch_after,
6767
)
6868
from monai.utils.enums import TransformBackends
69-
from monai.utils.type_conversion import convert_data_type, convert_to_cupy, convert_to_dst_type, convert_to_tensor
69+
from monai.utils.type_conversion import (
70+
convert_data_type,
71+
convert_to_cupy,
72+
convert_to_dst_type,
73+
convert_to_numpy,
74+
convert_to_tensor,
75+
)
7076

7177
measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version)
7278
morphology, has_morphology = optional_import("skimage.morphology")
73-
ndimage, _ = optional_import("scipy.ndimage")
79+
ndimage, has_ndimage = optional_import("scipy.ndimage")
7480
cp, has_cp = optional_import("cupy")
7581
cp_ndarray, _ = optional_import("cupy", name="ndarray")
7682
exposure, has_skimage = optional_import("skimage.exposure")
@@ -124,6 +130,7 @@
124130
"reset_ops_id",
125131
"resolves_modes",
126132
"has_status_keys",
133+
"distance_transform_edt",
127134
]
128135

129136

@@ -2051,5 +2058,142 @@ def has_status_keys(data: torch.Tensor, status_key: Any, default_message: str =
20512058
return True, None
20522059

20532060

2061+
def distance_transform_edt(
2062+
img: NdarrayOrTensor,
2063+
sampling: None | float | list[float] = None,
2064+
return_distances: bool = True,
2065+
return_indices: bool = False,
2066+
distances: NdarrayOrTensor | None = None,
2067+
indices: NdarrayOrTensor | None = None,
2068+
*,
2069+
block_params: tuple[int, int, int] | None = None,
2070+
float64_distances: bool = False,
2071+
) -> None | NdarrayOrTensor | tuple[NdarrayOrTensor, NdarrayOrTensor]:
2072+
"""
2073+
Euclidean distance transform, either GPU based with CuPy / cuCIM or CPU based with scipy.
2074+
To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.
2075+
2076+
Note that the results of the libraries can differ, so stick to one if possible.
2077+
For details, check out the `SciPy`_ and `cuCIM`_ documentation.
2078+
2079+
.. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html
2080+
.. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt
2081+
2082+
Args:
2083+
img: Input image on which the distance transform shall be run.
2084+
Has to be a channel first array, must have shape: (num_channels, H, W [,D]).
2085+
Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.
2086+
Input gets passed channel-wise to the distance-transform, thus results from this function will differ
2087+
from directly calling ``distance_transform_edt()`` in CuPy or SciPy.
2088+
sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;
2089+
if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.
2090+
return_distances: Whether to calculate the distance transform.
2091+
return_indices: Whether to calculate the feature transform.
2092+
distances: An output array to store the calculated distance transform, instead of returning it.
2093+
`return_distances` must be True.
2094+
indices: An output array to store the calculated feature transform, instead of returning it. `return_indicies` must be True.
2095+
block_params: This parameter is specific to cuCIM and does not exist in SciPy. For details, look into `cuCIM`_.
2096+
float64_distances: This parameter is specific to cuCIM and does not exist in SciPy.
2097+
If True, use double precision in the distance computation (to match SciPy behavior).
2098+
Otherwise, single precision will be used for efficiency.
2099+
2100+
Returns:
2101+
distances: The calculated distance transform. Returned only when `return_distances` is True and `distances` is not supplied.
2102+
It will have the same shape as image. For cuCIM: Will have dtype torch.float64 if float64_distances is True,
2103+
otherwise it will have dtype torch.float32. For SciPy: Will have dtype np.float64.
2104+
indices: The calculated feature transform. It has an image-shaped array for each dimension of the image.
2105+
Returned only when `return_indices` is True and `indices` is not supplied. dtype np.float64.
2106+
2107+
"""
2108+
distance_transform_edt, has_cucim = optional_import(
2109+
"cucim.core.operations.morphology", name="distance_transform_edt"
2110+
)
2111+
use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device.type == "cuda"
2112+
2113+
if not return_distances and not return_indices:
2114+
raise RuntimeError("Neither return_distances nor return_indices True")
2115+
2116+
if not (img.ndim >= 3 and img.ndim <= 4):
2117+
raise RuntimeError("Wrong input dimensionality. Use (num_channels, H, W [,D])")
2118+
2119+
distances_original, indices_original = distances, indices
2120+
distances, indices = None, None
2121+
if use_cp:
2122+
distances_, indices_ = None, None
2123+
if return_distances:
2124+
dtype = torch.float64 if float64_distances else torch.float32
2125+
if distances is None:
2126+
distances = torch.zeros_like(img, dtype=dtype) # type: ignore
2127+
else:
2128+
if not isinstance(distances, torch.Tensor) and distances.device != img.device:
2129+
raise TypeError("distances must be a torch.Tensor on the same device as img")
2130+
if not distances.dtype == dtype:
2131+
raise TypeError("distances must be a torch.Tensor of dtype float32 or float64")
2132+
distances_ = convert_to_cupy(distances)
2133+
if return_indices:
2134+
dtype = torch.int32
2135+
if indices is None:
2136+
indices = torch.zeros((img.dim(),) + img.shape, dtype=dtype) # type: ignore
2137+
else:
2138+
if not isinstance(indices, torch.Tensor) and indices.device != img.device:
2139+
raise TypeError("indices must be a torch.Tensor on the same device as img")
2140+
if not indices.dtype == dtype:
2141+
raise TypeError("indices must be a torch.Tensor of dtype int32")
2142+
indices_ = convert_to_cupy(indices)
2143+
img_ = convert_to_cupy(img)
2144+
for channel_idx in range(img_.shape[0]):
2145+
distance_transform_edt(
2146+
img_[channel_idx],
2147+
sampling=sampling,
2148+
return_distances=return_distances,
2149+
return_indices=return_indices,
2150+
distances=distances_[channel_idx] if distances_ is not None else None,
2151+
indices=indices_[channel_idx] if indices_ is not None else None,
2152+
block_params=block_params,
2153+
float64_distances=float64_distances,
2154+
)
2155+
else:
2156+
if not has_ndimage:
2157+
raise RuntimeError("scipy.ndimage required if cupy is not available")
2158+
img_ = convert_to_numpy(img)
2159+
if return_distances:
2160+
if distances is None:
2161+
distances = np.zeros_like(img_, dtype=np.float64)
2162+
else:
2163+
if not isinstance(distances, np.ndarray):
2164+
raise TypeError("distances must be a numpy.ndarray")
2165+
if not distances.dtype == np.float64:
2166+
raise TypeError("distances must be a numpy.ndarray of dtype float64")
2167+
if return_indices:
2168+
if indices is None:
2169+
indices = np.zeros((img_.ndim,) + img_.shape, dtype=np.int32)
2170+
else:
2171+
if not isinstance(indices, np.ndarray):
2172+
raise TypeError("indices must be a numpy.ndarray")
2173+
if not indices.dtype == np.int32:
2174+
raise TypeError("indices must be a numpy.ndarray of dtype int32")
2175+
2176+
for channel_idx in range(img_.shape[0]):
2177+
ndimage.distance_transform_edt(
2178+
img_[channel_idx],
2179+
sampling=sampling,
2180+
return_distances=return_distances,
2181+
return_indices=return_indices,
2182+
distances=distances[channel_idx] if distances is not None else None,
2183+
indices=indices[channel_idx] if indices is not None else None,
2184+
)
2185+
2186+
r_vals = []
2187+
if return_distances and distances_original is None:
2188+
r_vals.append(distances)
2189+
if return_indices and indices_original is None:
2190+
r_vals.append(indices)
2191+
if not r_vals:
2192+
return None
2193+
if len(r_vals) == 1:
2194+
return r_vals[0]
2195+
return tuple(r_vals) # type: ignore
2196+
2197+
20542198
if __name__ == "__main__":
20552199
print_transform_backends()

tests/min_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def run_testsuit():
6161
"test_deepgrow_transforms",
6262
"test_detect_envelope",
6363
"test_dints_network",
64+
"test_distance_transform_edt",
6465
"test_efficientnet",
6566
"test_ensemble_evaluator",
6667
"test_ensure_channel_first",

0 commit comments

Comments
 (0)