Skip to content

Commit 8dd004a

Browse files
authored
extend SurfaceDiceMetric for 3D images (#6549)
### Description This PR extends the `SurfaceDiceMetric` for 3D images. The implementation already uses generic functions to obtain the boundary edges and compute the distance, so the extension just: removes the assertion that the input is 2D (`[B,C,W,H]`); updates the docstrings; adds a test case in `TestAllSurfaceDiceMetrics.test_tolerance_euclidean_distance_3d`. Fixes #5906; mentions #4103. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Bryn Lloyd <lloyd@itis.swiss>
1 parent 960249f commit 8dd004a

2 files changed

Lines changed: 63 additions & 12 deletions

File tree

monai/metrics/surface_dice.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,14 @@
3232

3333
class SurfaceDiceMetric(CumulativeIterationMetric):
3434
"""
35-
Computes the Normalized Surface Distance (NSD) for each batch sample and class of
35+
Computes the Normalized Surface Dice (NSD) for each batch sample and class of
3636
predicted segmentations `y_pred` and corresponding reference segmentations `y` according to equation :eq:`nsd`.
37-
This implementation supports 2D images. For 3D images, please refer to DeepMind's implementation
38-
https://github.com/deepmind/surface-distance.
37+
This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images.
38+
Be aware that the computation of boundaries is different from DeepMind's implementation
39+
https://github.com/deepmind/surface-distance. In this implementation, the length/area of a segmentation boundary is
40+
interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary
41+
depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430).
42+
This issue is discussed here: https://github.com/Project-MONAI/MONAI/issues/4103.
3943
4044
The class- and batch sample-wise NSD values can be aggregated with the function `aggregate`.
4145
@@ -79,9 +83,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
7983
r"""
8084
Args:
8185
y_pred: Predicted segmentation, typically segmentation model output.
82-
It must be a one-hot encoded, batch-first tensor [B,C,H,W].
86+
It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].
8387
y: Reference segmentation.
84-
It must be a one-hot encoded, batch-first tensor [B,C,H,W].
88+
It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].
8589
kwargs: additional parameters, e.g. ``spacing`` should be passed to correctly compute the metric.
8690
``spacing``: spacing of pixel (or voxel). This parameter is relevant only
8791
if ``distance_metric`` is set to ``"euclidean"``.
@@ -168,17 +172,17 @@ def compute_surface_dice(
168172
will be returned for this class. In the case of a class being present in only one of predicted segmentation or
169173
reference segmentation, the class NSD will be 0.
170174
171-
This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D images.
175+
This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images.
172176
Be aware that the computation of boundaries is different from DeepMind's implementation
173177
https://github.com/deepmind/surface-distance. In this implementation, the length of a segmentation boundary is
174178
interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary
175179
depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430).
176180
177181
Args:
178182
y_pred: Predicted segmentation, typically segmentation model output.
179-
It must be a one-hot encoded, batch-first tensor [B,C,H,W].
183+
It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].
180184
y: Reference segmentation.
181-
It must be a one-hot encoded, batch-first tensor [B,C,H,W].
185+
It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].
182186
class_thresholds: List of class-specific thresholds.
183187
The thresholds relate to the acceptable amount of deviation in the segmentation boundary in pixels.
184188
Each threshold needs to be a finite, non-negative number.
@@ -215,8 +219,8 @@ def compute_surface_dice(
215219
if not isinstance(y_pred, torch.Tensor) or not isinstance(y, torch.Tensor):
216220
raise ValueError("y_pred and y must be PyTorch Tensor.")
217221

218-
if y_pred.ndimension() != 4 or y.ndimension() != 4:
219-
raise ValueError("y_pred and y should have four dimensions: [B,C,H,W].")
222+
if y_pred.ndimension() not in (4, 5) or y.ndimension() not in (4, 5):
223+
raise ValueError("y_pred and y should be one-hot encoded: [B,C,H,W] or [B,C,H,W,D].")
220224

221225
if y_pred.shape != y.shape:
222226
raise ValueError(

tests/test_surface_dice.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,53 @@ def test_tolerance_euclidean_distance(self):
128128
np.testing.assert_array_equal(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0))
129129
np.testing.assert_equal(not_nans.cpu(), torch.tensor(2))
130130

131+
def test_tolerance_euclidean_distance_3d(self):
132+
batch_size = 2
133+
n_class = 2
134+
predictions = torch.zeros((batch_size, 200, 110, 80), dtype=torch.int64, device=_device)
135+
labels = torch.zeros((batch_size, 200, 110, 80), dtype=torch.int64, device=_device)
136+
predictions[0, :, :, 20:] = 1
137+
labels[0, :, :, 30:] = 1 # offset by 10
138+
predictions_hot = F.one_hot(predictions, num_classes=n_class).permute(0, 4, 1, 2, 3)
139+
labels_hot = F.one_hot(labels, num_classes=n_class).permute(0, 4, 1, 2, 3)
140+
141+
sd0 = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True)
142+
res0 = sd0(predictions_hot, labels_hot)
143+
agg0 = sd0.aggregate() # aggregation: nanmean across image then nanmean across batch
144+
sd0_nans = SurfaceDiceMetric(class_thresholds=[0, 0], include_background=True, get_not_nans=True)
145+
res0_nans = sd0_nans(predictions_hot, labels_hot)
146+
agg0_nans, not_nans = sd0_nans.aggregate()
147+
148+
np.testing.assert_array_equal(res0.cpu(), res0_nans.cpu())
149+
np.testing.assert_equal(res0.device, predictions.device)
150+
np.testing.assert_array_equal(agg0.cpu(), agg0_nans.cpu())
151+
np.testing.assert_equal(agg0.device, predictions.device)
152+
153+
res1 = SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels_hot)
154+
res10 = SurfaceDiceMetric(class_thresholds=[10, 10], include_background=True)(predictions_hot, labels_hot)
155+
res11 = SurfaceDiceMetric(class_thresholds=[11, 11], include_background=True)(predictions_hot, labels_hot)
156+
157+
for res in [res0, res1, res10, res11]:
158+
assert res.shape == torch.Size([2, 2])
159+
160+
assert res0[0, 0] < res1[0, 0] < res10[0, 0]
161+
assert res0[0, 1] < res1[0, 1] < res10[0, 1]
162+
np.testing.assert_array_equal(res10.cpu(), res11.cpu())
163+
164+
expected_res0 = np.zeros((batch_size, n_class))
165+
expected_res0[0, 1] = 1 - (200 * 110 + 198 * 108 + 9 * 200 * 2 + 9 * 108 * 2) / (
166+
200 * 110 * 4 + (58 + 48) * 200 * 2 + (58 + 48) * 108 * 2
167+
)
168+
expected_res0[0, 0] = 1 - (200 * 110 + 198 * 108 + 9 * 200 * 2 + 9 * 108 * 2) / (
169+
200 * 110 * 4 + (28 + 18) * 200 * 2 + (28 + 18) * 108 * 2
170+
)
171+
expected_res0[1, 0] = 1
172+
expected_res0[1, 1] = np.nan
173+
for b, c in np.ndindex(batch_size, n_class):
174+
np.testing.assert_allclose(expected_res0[b, c], res0[b, c].cpu())
175+
np.testing.assert_array_equal(agg0.cpu(), np.nanmean(np.nanmean(expected_res0, axis=1), axis=0))
176+
np.testing.assert_equal(not_nans.cpu(), torch.tensor(2))
177+
131178
def test_tolerance_all_distances(self):
132179
batch_size = 1
133180
n_class = 2
@@ -262,10 +309,10 @@ def test_asserts(self):
262309
# wrong dimensions
263310
with self.assertRaises(ValueError) as context:
264311
SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions, labels_hot)
265-
self.assertEqual("y_pred and y should have four dimensions: [B,C,H,W].", str(context.exception))
312+
self.assertEqual("y_pred and y should be one-hot encoded: [B,C,H,W] or [B,C,H,W,D].", str(context.exception))
266313
with self.assertRaises(ValueError) as context:
267314
SurfaceDiceMetric(class_thresholds=[1, 1], include_background=True)(predictions_hot, labels)
268-
self.assertEqual("y_pred and y should have four dimensions: [B,C,H,W].", str(context.exception))
315+
self.assertEqual("y_pred and y should be one-hot encoded: [B,C,H,W] or [B,C,H,W,D].", str(context.exception))
269316

270317
# mismatch of shape of input tensors
271318
input_bad_shape = torch.clone(predictions_hot)

0 commit comments

Comments (0)