|
66 | 66 | pytorch_after, |
67 | 67 | ) |
68 | 68 | from monai.utils.enums import TransformBackends |
69 | | -from monai.utils.type_conversion import convert_data_type, convert_to_cupy, convert_to_dst_type, convert_to_tensor |
| 69 | +from monai.utils.type_conversion import ( |
| 70 | + convert_data_type, |
| 71 | + convert_to_cupy, |
| 72 | + convert_to_dst_type, |
| 73 | + convert_to_numpy, |
| 74 | + convert_to_tensor, |
| 75 | +) |
70 | 76 |
|
71 | 77 | measure, has_measure = optional_import("skimage.measure", "0.14.2", min_version) |
72 | 78 | morphology, has_morphology = optional_import("skimage.morphology") |
73 | | -ndimage, _ = optional_import("scipy.ndimage") |
| 79 | +ndimage, has_ndimage = optional_import("scipy.ndimage") |
74 | 80 | cp, has_cp = optional_import("cupy") |
75 | 81 | cp_ndarray, _ = optional_import("cupy", name="ndarray") |
76 | 82 | exposure, has_skimage = optional_import("skimage.exposure") |
|
124 | 130 | "reset_ops_id", |
125 | 131 | "resolves_modes", |
126 | 132 | "has_status_keys", |
| 133 | + "distance_transform_edt", |
127 | 134 | ] |
128 | 135 |
|
129 | 136 |
|
@@ -2051,5 +2058,142 @@ def has_status_keys(data: torch.Tensor, status_key: Any, default_message: str = |
2051 | 2058 | return True, None |
2052 | 2059 |
|
2053 | 2060 |
|
| 2061 | +def distance_transform_edt( |
| 2062 | + img: NdarrayOrTensor, |
| 2063 | + sampling: None | float | list[float] = None, |
| 2064 | + return_distances: bool = True, |
| 2065 | + return_indices: bool = False, |
| 2066 | + distances: NdarrayOrTensor | None = None, |
| 2067 | + indices: NdarrayOrTensor | None = None, |
| 2068 | + *, |
| 2069 | + block_params: tuple[int, int, int] | None = None, |
| 2070 | + float64_distances: bool = False, |
| 2071 | +) -> None | NdarrayOrTensor | tuple[NdarrayOrTensor, NdarrayOrTensor]: |
| 2072 | + """ |
| 2073 | + Euclidean distance transform, either GPU based with CuPy / cuCIM or CPU based with scipy. |
| 2074 | + To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device. |
| 2075 | +
|
| 2076 | + Note that the results of the libraries can differ, so stick to one if possible. |
| 2077 | + For details, check out the `SciPy`_ and `cuCIM`_ documentation. |
| 2078 | +
|
| 2079 | + .. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html |
| 2080 | + .. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt |
| 2081 | +
|
| 2082 | + Args: |
| 2083 | + img: Input image on which the distance transform shall be run. |
| 2084 | + Has to be a channel first array, must have shape: (num_channels, H, W [,D]). |
| 2085 | + Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere. |
| 2086 | + Input gets passed channel-wise to the distance-transform, thus results from this function will differ |
| 2087 | + from directly calling ``distance_transform_edt()`` in CuPy or SciPy. |
| 2088 | + sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1; |
| 2089 | + if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied. |
| 2090 | + return_distances: Whether to calculate the distance transform. |
| 2091 | + return_indices: Whether to calculate the feature transform. |
| 2092 | + distances: An output array to store the calculated distance transform, instead of returning it. |
| 2093 | + `return_distances` must be True. |
| 2094 | + indices: An output array to store the calculated feature transform, instead of returning it. `return_indicies` must be True. |
| 2095 | + block_params: This parameter is specific to cuCIM and does not exist in SciPy. For details, look into `cuCIM`_. |
| 2096 | + float64_distances: This parameter is specific to cuCIM and does not exist in SciPy. |
| 2097 | + If True, use double precision in the distance computation (to match SciPy behavior). |
| 2098 | + Otherwise, single precision will be used for efficiency. |
| 2099 | +
|
| 2100 | + Returns: |
| 2101 | + distances: The calculated distance transform. Returned only when `return_distances` is True and `distances` is not supplied. |
| 2102 | + It will have the same shape as image. For cuCIM: Will have dtype torch.float64 if float64_distances is True, |
| 2103 | + otherwise it will have dtype torch.float32. For SciPy: Will have dtype np.float64. |
| 2104 | + indices: The calculated feature transform. It has an image-shaped array for each dimension of the image. |
| 2105 | + Returned only when `return_indices` is True and `indices` is not supplied. dtype np.float64. |
| 2106 | +
|
| 2107 | + """ |
| 2108 | + distance_transform_edt, has_cucim = optional_import( |
| 2109 | + "cucim.core.operations.morphology", name="distance_transform_edt" |
| 2110 | + ) |
| 2111 | + use_cp = has_cp and has_cucim and isinstance(img, torch.Tensor) and img.device.type == "cuda" |
| 2112 | + |
| 2113 | + if not return_distances and not return_indices: |
| 2114 | + raise RuntimeError("Neither return_distances nor return_indices True") |
| 2115 | + |
| 2116 | + if not (img.ndim >= 3 and img.ndim <= 4): |
| 2117 | + raise RuntimeError("Wrong input dimensionality. Use (num_channels, H, W [,D])") |
| 2118 | + |
| 2119 | + distances_original, indices_original = distances, indices |
| 2120 | + distances, indices = None, None |
| 2121 | + if use_cp: |
| 2122 | + distances_, indices_ = None, None |
| 2123 | + if return_distances: |
| 2124 | + dtype = torch.float64 if float64_distances else torch.float32 |
| 2125 | + if distances is None: |
| 2126 | + distances = torch.zeros_like(img, dtype=dtype) # type: ignore |
| 2127 | + else: |
| 2128 | + if not isinstance(distances, torch.Tensor) and distances.device != img.device: |
| 2129 | + raise TypeError("distances must be a torch.Tensor on the same device as img") |
| 2130 | + if not distances.dtype == dtype: |
| 2131 | + raise TypeError("distances must be a torch.Tensor of dtype float32 or float64") |
| 2132 | + distances_ = convert_to_cupy(distances) |
| 2133 | + if return_indices: |
| 2134 | + dtype = torch.int32 |
| 2135 | + if indices is None: |
| 2136 | + indices = torch.zeros((img.dim(),) + img.shape, dtype=dtype) # type: ignore |
| 2137 | + else: |
| 2138 | + if not isinstance(indices, torch.Tensor) and indices.device != img.device: |
| 2139 | + raise TypeError("indices must be a torch.Tensor on the same device as img") |
| 2140 | + if not indices.dtype == dtype: |
| 2141 | + raise TypeError("indices must be a torch.Tensor of dtype int32") |
| 2142 | + indices_ = convert_to_cupy(indices) |
| 2143 | + img_ = convert_to_cupy(img) |
| 2144 | + for channel_idx in range(img_.shape[0]): |
| 2145 | + distance_transform_edt( |
| 2146 | + img_[channel_idx], |
| 2147 | + sampling=sampling, |
| 2148 | + return_distances=return_distances, |
| 2149 | + return_indices=return_indices, |
| 2150 | + distances=distances_[channel_idx] if distances_ is not None else None, |
| 2151 | + indices=indices_[channel_idx] if indices_ is not None else None, |
| 2152 | + block_params=block_params, |
| 2153 | + float64_distances=float64_distances, |
| 2154 | + ) |
| 2155 | + else: |
| 2156 | + if not has_ndimage: |
| 2157 | + raise RuntimeError("scipy.ndimage required if cupy is not available") |
| 2158 | + img_ = convert_to_numpy(img) |
| 2159 | + if return_distances: |
| 2160 | + if distances is None: |
| 2161 | + distances = np.zeros_like(img_, dtype=np.float64) |
| 2162 | + else: |
| 2163 | + if not isinstance(distances, np.ndarray): |
| 2164 | + raise TypeError("distances must be a numpy.ndarray") |
| 2165 | + if not distances.dtype == np.float64: |
| 2166 | + raise TypeError("distances must be a numpy.ndarray of dtype float64") |
| 2167 | + if return_indices: |
| 2168 | + if indices is None: |
| 2169 | + indices = np.zeros((img_.ndim,) + img_.shape, dtype=np.int32) |
| 2170 | + else: |
| 2171 | + if not isinstance(indices, np.ndarray): |
| 2172 | + raise TypeError("indices must be a numpy.ndarray") |
| 2173 | + if not indices.dtype == np.int32: |
| 2174 | + raise TypeError("indices must be a numpy.ndarray of dtype int32") |
| 2175 | + |
| 2176 | + for channel_idx in range(img_.shape[0]): |
| 2177 | + ndimage.distance_transform_edt( |
| 2178 | + img_[channel_idx], |
| 2179 | + sampling=sampling, |
| 2180 | + return_distances=return_distances, |
| 2181 | + return_indices=return_indices, |
| 2182 | + distances=distances[channel_idx] if distances is not None else None, |
| 2183 | + indices=indices[channel_idx] if indices is not None else None, |
| 2184 | + ) |
| 2185 | + |
| 2186 | + r_vals = [] |
| 2187 | + if return_distances and distances_original is None: |
| 2188 | + r_vals.append(distances) |
| 2189 | + if return_indices and indices_original is None: |
| 2190 | + r_vals.append(indices) |
| 2191 | + if not r_vals: |
| 2192 | + return None |
| 2193 | + if len(r_vals) == 1: |
| 2194 | + return r_vals[0] |
| 2195 | + return tuple(r_vals) # type: ignore |
| 2196 | + |
| 2197 | + |
2054 | 2198 | if __name__ == "__main__": |
2055 | 2199 | print_transform_backends() |
0 commit comments