Skip to content

Commit f140e06

Browse files
Cucim support for get_mask_edges and get_surface_distance (#7008)
### Description Add support for cucim for `get_mask_edges` and `get_surface_distance`. This provides significant speedup in surface related metrics. Profiling on my system gave 3-20x speedups depending on the input shape: [---------------- (250, 250, 250) -----------------] | cpu | cuda 1 threads: ----------------------------------------- random()>0.2 | 26400.8 | 1306.3 random()>0.5 | 26411.8 | 1399.1 random()>0.8 | 29993.2 | 1009.5 create_spherical_seg_3d | 623.8 | 45.0 Times are in milliseconds (ms). [--------------- (100, 100, 100) ----------------] | cpu | cuda 1 threads: --------------------------------------- random()>0.2 | 1332.5 | 140.2 random()>0.5 | 1276.3 | 128.1 random()>0.8 | 1179.2 | 89.1 create_spherical_seg_3d | 111.7 | 44.0 Times are in milliseconds (ms). [---------------- (50, 50, 50) ----------------] | cpu | cuda 1 threads: ------------------------------------- random()>0.2 | 154.5 | 47.4 random()>0.5 | 166.7 | 39.3 random()>0.8 | 165.0 | 38.0 create_spherical_seg_3d | 77.2 | 44.4 Times are in milliseconds (ms). where create_spherical_seg_3d uses the same function from test_hausdorff_distance, and binarizes random array using `random(shape)>ratio`. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: John Zielke <john.zielke@snkeos.com>
1 parent 21028ee commit f140e06

9 files changed

Lines changed: 262 additions & 147 deletions

File tree

monai/losses/hausdorff_loss.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,11 @@
1919
import warnings
2020
from typing import Callable
2121

22-
import numpy as np
2322
import torch
2423
from torch.nn.modules.loss import _Loss
2524

26-
from monai.metrics.utils import distance_transform_edt
2725
from monai.networks import one_hot
26+
from monai.transforms.utils import distance_transform_edt
2827
from monai.utils import LossReduction
2928

3029

@@ -95,7 +94,7 @@ def __init__(
9594
self.batch = batch
9695

9796
@torch.no_grad()
98-
def distance_field(self, img: np.ndarray) -> np.ndarray:
97+
def distance_field(self, img: torch.Tensor) -> torch.Tensor:
9998
"""Generate distance transform.
10099
101100
Args:
@@ -104,18 +103,20 @@ def distance_field(self, img: np.ndarray) -> np.ndarray:
104103
Returns:
105104
np.ndarray: Distance field.
106105
"""
107-
field = np.zeros_like(img)
106+
field = torch.zeros_like(img)
108107

109-
for batch in range(len(img)):
110-
fg_mask = img[batch] > 0.5
108+
for batch_idx in range(len(img)):
109+
fg_mask = img[batch_idx] > 0.5
111110

112-
if fg_mask.any():
111+
# For cases where the mask is entirely background or entirely foreground
112+
# the distance transform is not well defined for all 1s,
113+
# which always would happen on either foreground or background, so skip
114+
if fg_mask.any() and not fg_mask.all():
115+
fg_dist: torch.Tensor = distance_transform_edt(fg_mask) # type: ignore
113116
bg_mask = ~fg_mask
117+
bg_dist: torch.Tensor = distance_transform_edt(bg_mask) # type: ignore
114118

115-
fg_dist = distance_transform_edt(fg_mask)
116-
bg_dist = distance_transform_edt(bg_mask)
117-
118-
field[batch] = fg_dist + bg_dist
119+
field[batch_idx] = fg_dist + bg_dist
119120

120121
return field
121122

@@ -181,8 +182,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
181182
for i in range(input.shape[1]):
182183
ch_input = input[:, [i]]
183184
ch_target = target[:, [i]]
184-
pred_dt = torch.from_numpy(self.distance_field(ch_input.detach().cpu().numpy())).float()
185-
target_dt = torch.from_numpy(self.distance_field(ch_target.detach().cpu().numpy())).float()
185+
pred_dt = self.distance_field(ch_input.detach()).float()
186+
target_dt = self.distance_field(ch_target.detach()).float()
186187

187188
pred_error = (ch_input - ch_target) ** 2
188189
distance = pred_dt**self.alpha + target_dt**self.alpha

monai/metrics/hausdorff_distance.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
from __future__ import annotations
1313

14-
import warnings
1514
from collections.abc import Sequence
1615
from typing import Any
1716

@@ -20,12 +19,12 @@
2019

2120
from monai.metrics.utils import (
2221
do_metric_reduction,
23-
get_mask_edges,
22+
get_edge_surface_distance,
2423
get_surface_distance,
2524
ignore_background,
2625
prepare_spacing,
2726
)
28-
from monai.utils import MetricReduction, convert_data_type
27+
from monai.utils import MetricReduction, convert_data_type, deprecated
2928

3029
from .metric import CumulativeIterationMetric
3130

@@ -180,31 +179,46 @@ def compute_hausdorff_distance(
180179
raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")
181180

182181
batch_size, n_class = y_pred.shape[:2]
183-
hd = np.empty((batch_size, n_class))
182+
hd = torch.empty((batch_size, n_class), dtype=torch.float, device=y_pred.device)
184183

185184
img_dim = y_pred.ndim - 2
186185
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)
187186

188187
for b, c in np.ndindex(batch_size, n_class):
189-
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c])
190-
if not np.any(edges_gt):
191-
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
192-
if not np.any(edges_pred):
193-
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")
194-
195-
distance_1 = compute_percent_hausdorff_distance(
196-
edges_pred, edges_gt, distance_metric, percentile, spacing_list[b]
188+
_, distances, _ = get_edge_surface_distance(
189+
y_pred[b, c],
190+
y[b, c],
191+
distance_metric=distance_metric,
192+
spacing=spacing_list[b],
193+
symetric=not directed,
194+
class_index=c,
197195
)
198-
if directed:
199-
hd[b, c] = distance_1
200-
else:
201-
distance_2 = compute_percent_hausdorff_distance(
202-
edges_gt, edges_pred, distance_metric, percentile, spacing_list[b]
203-
)
204-
hd[b, c] = max(distance_1, distance_2)
205-
return convert_data_type(hd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]
196+
percentile_distances = [_compute_percentile_hausdorff_distance(d, percentile) for d in distances]
197+
max_distance = torch.max(torch.stack(percentile_distances))
198+
hd[b, c] = max_distance
199+
return hd
206200

207201

202+
def _compute_percentile_hausdorff_distance(
203+
surface_distance: torch.Tensor, percentile: float | None = None
204+
) -> torch.Tensor:
205+
"""
206+
This function is used to compute the Hausdorff distance.
207+
"""
208+
209+
# for both pred and gt do not have foreground
210+
if surface_distance.shape == (0,):
211+
return torch.tensor(torch.nan, dtype=torch.float, device=surface_distance.device)
212+
213+
if not percentile:
214+
return surface_distance.max() # type: ignore[no-any-return]
215+
216+
if 0 <= percentile <= 100:
217+
return torch.quantile(surface_distance, percentile / 100) # type: ignore[no-any-return]
218+
raise ValueError(f"percentile should be a value between 0 and 100, get {percentile}.")
219+
220+
221+
@deprecated(since="1.3.0", removed="1.5.0")
208222
def compute_percent_hausdorff_distance(
209223
edges_pred: np.ndarray,
210224
edges_gt: np.ndarray,
@@ -216,7 +230,9 @@ def compute_percent_hausdorff_distance(
216230
This function is used to compute the directed Hausdorff distance.
217231
"""
218232

219-
surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing)
233+
surface_distance: np.ndarray = get_surface_distance(
234+
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing
235+
) # type: ignore
220236

221237
# for both pred and gt do not have foreground
222238
if surface_distance.shape == (0,):

monai/metrics/surface_dice.py

Lines changed: 21 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,14 @@
1111

1212
from __future__ import annotations
1313

14-
import warnings
1514
from collections.abc import Sequence
1615
from typing import Any
1716

1817
import numpy as np
1918
import torch
2019

21-
from monai.metrics.utils import (
22-
do_metric_reduction,
23-
get_mask_edges,
24-
get_surface_distance,
25-
ignore_background,
26-
prepare_spacing,
27-
)
28-
from monai.utils import MetricReduction, convert_data_type
20+
from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing
21+
from monai.utils import MetricReduction
2922

3023
from .metric import CumulativeIterationMetric
3124

@@ -251,47 +244,39 @@ def compute_surface_dice(
251244
if any(np.array(class_thresholds) < 0):
252245
raise ValueError("All class thresholds need to be >= 0.")
253246

254-
nsd = np.empty((batch_size, n_class))
247+
nsd = torch.empty((batch_size, n_class), device=y_pred.device, dtype=torch.float)
255248

256249
img_dim = y_pred.ndim - 2
257250
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)
258251

259252
for b, c in np.ndindex(batch_size, n_class):
253+
(edges_pred, edges_gt), (distances_pred_gt, distances_gt_pred), areas = get_edge_surface_distance( # type: ignore
254+
y_pred[b, c],
255+
y[b, c],
256+
distance_metric=distance_metric,
257+
spacing=spacing_list[b],
258+
use_subvoxels=use_subvoxels,
259+
symetric=True,
260+
class_index=c,
261+
)
262+
boundary_correct: int | torch.Tensor | float
263+
boundary_complete: int | torch.Tensor | float
260264
if not use_subvoxels:
261-
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=True)
262-
distances_pred_gt = get_surface_distance(
263-
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b]
264-
)
265-
distances_gt_pred = get_surface_distance(
266-
edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b]
267-
)
268-
269265
boundary_complete = len(distances_pred_gt) + len(distances_gt_pred)
270-
boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum(
266+
boundary_correct = torch.sum(distances_pred_gt <= class_thresholds[c]) + torch.sum(
271267
distances_gt_pred <= class_thresholds[c]
272268
)
273269
else:
274-
_spacing = spacing_list[b] if spacing_list[b] is not None else [1] * img_dim
275-
areas_pred: np.ndarray
276-
areas_gt: np.ndarray
277-
edges_pred, edges_gt, areas_pred, areas_gt = get_mask_edges( # type: ignore
278-
y_pred[b, c], y[b, c], crop=True, spacing=_spacing # type: ignore
279-
)
280-
dist_pred_to_gt = get_surface_distance(edges_pred, edges_gt, distance_metric, spacing=spacing_list[b])
281-
dist_gt_to_pred = get_surface_distance(edges_gt, edges_pred, distance_metric, spacing=spacing_list[b])
270+
areas_pred, areas_gt = areas # type: ignore
282271
areas_gt, areas_pred = areas_gt[edges_gt], areas_pred[edges_pred]
283-
boundary_complete = areas_gt.sum() + areas_pred.sum()
284-
gt_true = areas_gt[dist_gt_to_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0
285-
pred_true = areas_pred[dist_pred_to_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0
272+
boundary_complete = areas_gt.sum() + areas_pred.sum() # type: ignore
273+
gt_true = areas_gt[distances_gt_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0
274+
pred_true = areas_pred[distances_pred_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0
286275
boundary_correct = gt_true + pred_true
287-
if not np.any(edges_gt):
288-
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
289-
if not np.any(edges_pred):
290-
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")
291276
if boundary_complete == 0:
292277
# the class is neither present in the prediction, nor in the reference segmentation
293-
nsd[b, c] = np.nan
278+
nsd[b, c] = torch.nan
294279
else:
295280
nsd[b, c] = boundary_correct / boundary_complete
296281

297-
return convert_data_type(nsd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]
282+
return nsd

monai/metrics/surface_distance.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,13 @@
1111

1212
from __future__ import annotations
1313

14-
import warnings
1514
from collections.abc import Sequence
1615
from typing import Any
1716

1817
import numpy as np
1918
import torch
2019

21-
from monai.metrics.utils import (
22-
do_metric_reduction,
23-
get_mask_edges,
24-
get_surface_distance,
25-
ignore_background,
26-
prepare_spacing,
27-
)
20+
from monai.metrics.utils import do_metric_reduction, get_edge_surface_distance, ignore_background, prepare_spacing
2821
from monai.utils import MetricReduction, convert_data_type
2922

3023
from .metric import CumulativeIterationMetric
@@ -173,25 +166,21 @@ def compute_average_surface_distance(
173166
raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")
174167

175168
batch_size, n_class = y_pred.shape[:2]
176-
asd = np.empty((batch_size, n_class))
169+
asd = torch.empty((batch_size, n_class), dtype=torch.float32, device=y_pred.device)
177170

178171
img_dim = y_pred.ndim - 2
179172
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)
180173

181174
for b, c in np.ndindex(batch_size, n_class):
182-
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c])
183-
if not np.any(edges_gt):
184-
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
185-
if not np.any(edges_pred):
186-
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")
187-
surface_distance = get_surface_distance(
188-
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b]
175+
_, distances, _ = get_edge_surface_distance(
176+
y_pred[b, c],
177+
y[b, c],
178+
distance_metric=distance_metric,
179+
spacing=spacing_list[b],
180+
symetric=symmetric,
181+
class_index=c,
189182
)
190-
if symmetric:
191-
surface_distance_2 = get_surface_distance(
192-
edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b]
193-
)
194-
surface_distance = np.concatenate([surface_distance, surface_distance_2])
195-
asd[b, c] = np.nan if surface_distance.shape == (0,) else surface_distance.mean()
183+
surface_distance = torch.cat(distances)
184+
asd[b, c] = torch.nan if surface_distance.shape == (0,) else surface_distance.mean()
196185

197186
return convert_data_type(asd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]

0 commit comments

Comments
 (0)