
Commit 593657f

Merge branch 'dev' into worktree-fix-nested-compose-map-items
2 parents 46aecf2 + 65beb58 commit 593657f

7 files changed: 237 additions & 19 deletions


monai/apps/detection/utils/anchor_utils.py

Lines changed: 1 addition & 1 deletion

@@ -253,7 +253,7 @@ def grid_anchors(self, grid_sizes: list[list[int]], strides: list[list[Tensor]])
             # compute anchor centers regarding to the image.
             # shifts_centers is [x_center, y_center] or [x_center, y_center, z_center]
             shifts_centers = [
-                torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis]
+                torch.arange(0, size[axis], dtype=torch.int32, device=device) * stride[axis] + stride[axis] // 2
                 for axis in range(self.spatial_dims)
             ]
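
A standalone sketch of what this centering change does (the size and stride below are hypothetical, chosen for illustration):

import torch

# Hypothetical grid axis: 4 cells, stride 8.
size, stride = 4, 8

# Before this change, anchor centers sat on cell corners:
corners = torch.arange(0, size, dtype=torch.int32) * stride               # tensor([ 0,  8, 16, 24])

# After, each center is shifted by half a stride into its cell:
centered = torch.arange(0, size, dtype=torch.int32) * stride + stride // 2  # tensor([ 4, 12, 20, 28])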

monai/apps/nuclick/transforms.py

Lines changed: 2 additions & 2 deletions

@@ -367,14 +367,14 @@ def inclusion_map(self, mask, dtype):
 
     def exclusion_map(self, others, dtype, jitter_range, drop_rate):
         point_mask = torch.zeros_like(others, dtype=dtype)
-        if np.random.choice([True, False], p=[drop_rate, 1 - drop_rate]):
+        if self.R.choice([True, False], p=[drop_rate, 1 - drop_rate]):
             return point_mask
 
         max_x = point_mask.shape[0] - 1
         max_y = point_mask.shape[1] - 1
         stats = measure.regionprops(convert_to_numpy(others))
         for stat in stats:
-            if np.random.choice([True, False], p=[drop_rate, 1 - drop_rate]):
+            if self.R.choice([True, False], p=[drop_rate, 1 - drop_rate]):
                 continue
 
             # random jitter
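
Why `self.R` rather than `np.random`: MONAI's randomizable transforms draw from a per-transform `np.random.RandomState` (exposed as `self.R`) that `set_random_state` can seed, so draws are reproducible and independent of the global NumPy state. A minimal sketch, using a hypothetical `Randomizable` subclass:

import numpy as np
from monai.transforms import Randomizable

class DropCoin(Randomizable):
    """Hypothetical transform: returns True with probability drop_rate."""

    def __call__(self, drop_rate: float = 0.5) -> bool:
        # self.R is a np.random.RandomState owned by this transform instance.
        return bool(self.R.choice([True, False], p=[drop_rate, 1 - drop_rate]))

t = DropCoin()
t.set_random_state(seed=0)
first = [t() for _ in range(5)]
t.set_random_state(seed=0)
assert [t() for _ in range(5)] == first  # same seed, same draws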

monai/losses/image_dissimilarity.py

Lines changed: 8 additions & 4 deletions

@@ -111,14 +111,16 @@ def __init__(
             raise ValueError(f"kernel_size must be odd, got {self.kernel_size}")
 
         _kernel = look_up_option(kernel_type, kernel_dict)
-        self.kernel = _kernel(self.kernel_size)
-        self.kernel.require_grads = False
-        self.kernel_vol = self.get_kernel_vol()
+        self.kernel: torch.Tensor
+        self.kernel_vol: torch.Tensor
+        self.register_buffer("kernel", _kernel(self.kernel_size), persistent=False)
+        self.register_buffer("kernel_vol", self.get_kernel_vol(), persistent=False)
 
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
 
-    def get_kernel_vol(self):
+    def get_kernel_vol(self) -> torch.Tensor:
+        assert self.kernel is not None
         vol = self.kernel
         for _ in range(self.ndim - 1):
             vol = torch.matmul(vol.unsqueeze(-1), self.kernel.unsqueeze(0))
@@ -138,6 +140,8 @@ def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             raise ValueError(f"ground truth has differing shape ({target.shape}) from pred ({pred.shape})")
 
         t2, p2, tp = target * target, pred * pred, target * pred
+        assert self.kernel is not None
+        assert self.kernel_vol is not None
         kernel, kernel_vol = self.kernel.to(pred), self.kernel_vol.to(pred)
         kernels = [kernel] * self.ndim
         # sum over kernel
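
The motivation for `register_buffer` here: buffers move and convert with the module on `.to()`/`.cuda()`, and `persistent=False` keeps them out of the `state_dict`, so existing checkpoints still load. A minimal standalone sketch:

import torch
from torch import nn

class WithKernel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # A plain tensor attribute is not tracked by nn.Module; a buffer is.
        # persistent=False excludes it from state_dict, keeping checkpoints unchanged.
        self.register_buffer("kernel", torch.ones(3), persistent=False)

m = WithKernel()
assert "kernel" not in m.state_dict()  # not serialized
m.to(torch.float64)                    # but converted/moved with the module
print(m.kernel.dtype)                  # torch.float64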

monai/metrics/meandice.py

Lines changed: 98 additions & 2 deletions

@@ -13,11 +13,17 @@
 
 import torch
 
-from monai.metrics.utils import do_metric_reduction
+from monai.metrics.utils import compute_voronoi_regions_fast, do_metric_reduction
 from monai.utils import MetricReduction, deprecated_arg
+from monai.utils.module import optional_import
 
 from .metric import CumulativeIterationMetric
 
+scipy_ndimage, has_scipy_ndimage = optional_import("scipy.ndimage")
+cupy, has_cupy = optional_import("cupy")
+cupy_ndimage, has_cupy_ndimage = optional_import("cupyx.scipy.ndimage")
+
 __all__ = ["DiceMetric", "compute_dice", "DiceHelper"]

@@ -41,6 +47,18 @@ class DiceMetric(CumulativeIterationMetric):
     image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction
     and ground truth is BCHW[D].
 
+    The `per_component=True` approach computes the Dice metric on a per-connected component basis in the ground truth segmentation,
+    ensuring equal weighting for each component regardless of its size. This method eliminates biases in traditional metrics,
+    providing a more balanced evaluation, particularly in scenarios where object size does not correlate with clinical relevance.
+    This provides a more granular evaluation of segmentation quality, especially useful when dealing with fragmented or
+    disconnected objects in the foreground.
+    Note:
+        - The input prediction (`y_pred`) and ground truth (`y`) must both have 2 channels (foreground/background),
+          with binary segmentation (0 for background, 1 for foreground). That is, this assumes the shape of both prediction
+          and ground truth is B2HW[D].
+        - This method cannot be used with multiclass segmentation.
+    For more information, refer to the original paper: https://arxiv.org/abs/2410.18684
+
     The typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
 
     Further information can be found in the official
@@ -95,6 +113,9 @@ class DiceMetric(CumulativeIterationMetric):
             If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
             the index begins at "0", otherwise at "1". It can also take a list of label names.
             The outcome will then be returned as a dictionary.
+        per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
+            computed for each connected component in the ground truth, and then averaged. This requires binary
+            segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
 
     """
@@ -106,6 +127,7 @@ def __init__(
         ignore_empty: bool = True,
         num_classes: int | None = None,
         return_with_label: bool | list[str] = False,
+        per_component: bool = False,
     ) -> None:
         super().__init__()
         self.include_background = include_background
@@ -114,13 +136,15 @@ def __init__(
         self.ignore_empty = ignore_empty
         self.num_classes = num_classes
         self.return_with_label = return_with_label
+        self.per_component = per_component
         self.dice_helper = DiceHelper(
             include_background=self.include_background,
             reduction=MetricReduction.NONE,
             get_not_nans=False,
             apply_argmax=False,
             ignore_empty=self.ignore_empty,
             num_classes=self.num_classes,
+            per_component=self.per_component,
         )
 
     def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
@@ -175,6 +199,7 @@ def compute_dice(
     include_background: bool = True,
     ignore_empty: bool = True,
     num_classes: int | None = None,
+    per_component: bool = False,
 ) -> torch.Tensor:
     """
     Computes Dice score metric for a batch of predictions. This performs the same computation as
@@ -192,6 +217,9 @@ def compute_dice(
         num_classes: number of input channels (always including the background). When this is ``None``,
             ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
             single-channel class indices and the number of classes is not automatically inferred from data.
+        per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
+            computed for each connected component in the ground truth, and then averaged. This requires binary
+            segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
 
     Returns:
         Dice scores per batch and per class, (shape: [batch_size, num_classes]).
@@ -204,6 +232,7 @@ def compute_dice(
         apply_argmax=False,
         ignore_empty=ignore_empty,
         num_classes=num_classes,
+        per_component=per_component,
     )(y_pred=y_pred, y=y)
 
 
@@ -246,6 +275,9 @@ class DiceHelper:
         num_classes: number of input channels (always including the background). When this is ``None``,
             ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
             single-channel class indices and the number of classes is not automatically inferred from data.
+        per_component: whether to compute the Dice metric per connected component. If `True`, the metric will be
+            computed for each connected component in the ground truth, and then averaged. This requires binary
+            segmentations with 2 channels (background + foreground) as input. This is a more fine-grained computation.
     """
 
     @deprecated_arg("softmax", "1.5", "1.7", "Use `apply_argmax` instead.", new_name="apply_argmax")
@@ -262,6 +294,7 @@ def __init__(
         num_classes: int | None = None,
         sigmoid: bool | None = None,
         softmax: bool | None = None,
+        per_component: bool = False,
     ) -> None:
         # handling deprecated arguments
         if sigmoid is not None:
@@ -277,6 +310,50 @@ def __init__(
         self.activate = activate
         self.ignore_empty = ignore_empty
         self.num_classes = num_classes
+        self.per_component = per_component
+
+    def compute_cc_dice(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """
+        Compute per-component Dice for a single batch item.
+
+        Args:
+            y_pred (torch.Tensor): Predictions with shape (1, 2, D, H, W) or (1, 2, H, W).
+            y (torch.Tensor): Ground truth with shape (1, 2, D, H, W) or (1, 2, H, W).
+
+        Returns:
+            torch.Tensor: Mean Dice over connected components.
+        """
+        if y_pred.ndim == y.ndim:
+            y_pred_idx = torch.argmax(y_pred, dim=1)
+            y_idx = torch.argmax(y, dim=1)
+        else:
+            y_pred_idx = y_pred
+            y_idx = y
+        if y_idx[0].sum() == 0:
+            if self.ignore_empty:
+                data = torch.tensor(float("nan"), device=y_idx.device)
+            elif y_pred_idx.sum() == 0:
+                data = torch.tensor(1.0, device=y_idx.device)
+            else:
+                data = torch.tensor(0.0, device=y_idx.device)
+        else:
+            cc_assignment = compute_voronoi_regions_fast(y_idx[0])
+            if cc_assignment.device != y_idx.device:
+                cc_assignment = cc_assignment.to(y_idx.device)
+            uniq, inv = torch.unique(cc_assignment.view(-1), return_inverse=True)
+            nof_components = uniq.numel()
+            code = (y_idx.view(-1) << 1) | y_pred_idx.view(-1)
+            idx = (inv << 2) | code
+            hist = torch.bincount(idx, minlength=nof_components * 4).reshape(-1, 4)
+            _, fp, fn, tp = hist[:, 0], hist[:, 1], hist[:, 2], hist[:, 3]
+            denom = 2 * tp + fp + fn
+            dice_scores = torch.where(
+                denom > 0, (2 * tp).float() / denom.float(), torch.tensor(1.0, device=denom.device)
+            )
+            data = dice_scores.unsqueeze(-1)
+        data = torch.nan_to_num(data)
+        data = data.reshape(-1, 1)
+        return torch.stack([data.mean()])
 
     def compute_channel(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         """
@@ -305,6 +382,9 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
             y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
                 the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
             y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
+
+        Raises:
+            ValueError: when the shapes of `y_pred` and `y` are not compatible for the per-component computation.
         """
         _apply_argmax, _threshold = self.apply_argmax, self.threshold
         if self.num_classes is None:
@@ -322,15 +402,31 @@ def __call__(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
                 y_pred = torch.sigmoid(y_pred)
             y_pred = y_pred > 0.5
 
-        first_ch = 0 if self.include_background else 1
+        if self.per_component:
+            if y_pred.ndim not in (4, 5) or y.ndim not in (4, 5) or y_pred.shape[1] != 2 or y.shape[1] != 2:
+                same_rank = y_pred.ndim == y.ndim and y_pred.ndim in (4, 5)
+                binary_channels = y_pred.shape[1] == 2 and y.shape[1] == 2
+                same_shape = y_pred.shape == y.shape
+                if not (same_rank and binary_channels and same_shape):
+                    raise ValueError(
+                        "per_component requires matching 4D/5D binary tensors "
+                        "(B, 2, H, W) or (B, 2, D, H, W). "
+                        f"Got y_pred={tuple(y_pred.shape)}, y={tuple(y.shape)}."
+                    )
+
+        first_ch = 0 if self.include_background and not self.per_component else 1
         data = []
         for b in range(y_pred.shape[0]):
+            if self.per_component:
+                data.append(self.compute_cc_dice(y_pred=y_pred[b].unsqueeze(0), y=y[b].unsqueeze(0)).reshape(-1))
+                continue
             c_list = []
             for c in range(first_ch, n_pred_ch) if n_pred_ch > 1 else [1]:
                 x_pred = (y_pred[b, 0] == c) if (y_pred.shape[1] == 1) else y_pred[b, c].bool()
                 x = (y[b, 0] == c) if (y.shape[1] == 1) else y[b, c]
                 c_list.append(self.compute_channel(x_pred, x))
             data.append(torch.stack(c_list))
+
         data = torch.stack(data, dim=0).contiguous()  # type: ignore
 
         f, not_nans = do_metric_reduction(data, self.reduction)  # type: ignore
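
Taken together, a usage sketch of the flag introduced in this diff (input shapes and values are illustrative; the expected output is reasoned from the per-component definition above, with each component weighted equally regardless of size):

import torch
from monai.metrics import DiceMetric

# One-hot binary input: batch 1, 2 channels (background, foreground), 8x8.
y = torch.zeros(1, 2, 8, 8)
y[0, 1, 1:3, 1:3] = 1   # small component (4 voxels)
y[0, 1, 5:8, 5:8] = 1   # larger component (9 voxels)
y[0, 0] = 1 - y[0, 1]

y_pred = y.clone()
y_pred[0, 1, 5:8, 5:8] = 0   # completely miss the larger component
y_pred[0, 0] = 1 - y_pred[0, 1]

metric = DiceMetric(per_component=True)
metric(y_pred=y_pred, y=y)
print(metric.aggregate())    # (1.0 + 0.0) / 2, i.e. ~0.5, despite the size imbalance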

monai/metrics/utils.py

Lines changed: 57 additions & 0 deletions

@@ -39,6 +39,10 @@
 distance_transform_edt, _ = optional_import("scipy.ndimage", name="distance_transform_edt")
 distance_transform_cdt, _ = optional_import("scipy.ndimage", name="distance_transform_cdt")
 
+scipy_ndimage, has_scipy_ndimage = optional_import("scipy.ndimage")
+cupy, has_cupy = optional_import("cupy")
+cupy_ndimage, has_cupy_ndimage = optional_import("cupyx.scipy.ndimage")
+
 __all__ = [
     "ignore_background",
     "do_metric_reduction",
@@ -462,6 +466,59 @@ def prepare_spacing(
     )
 
 
+def compute_voronoi_regions_fast(labels: np.ndarray | torch.Tensor) -> torch.Tensor:
+    """
+    Voronoi assignment to connected components (CPU, single EDT) without cc3d.
+    Returns the ID of the nearest component for each voxel.
+
+    Args:
+        labels (np.ndarray | torch.Tensor): Label map where values > 0 are seeds.
+
+    Raises:
+        RuntimeError: when `scipy.ndimage` is not available.
+        ValueError: when `labels` has fewer than two dimensions.
+
+    Returns:
+        torch.Tensor: Voronoi region IDs (int32) on CPU.
+    """
+    if isinstance(labels, torch.Tensor) and labels.is_cuda and has_cupy and has_cupy_ndimage:
+        xp = cupy
+        nd_distance_transform_edt = cupy_ndimage.distance_transform_edt
+        nd_generate_binary_structure = cupy_ndimage.generate_binary_structure
+        nd_label = cupy_ndimage.label
+        x = cupy.asarray(labels.detach())
+    else:
+        xp = np
+        nd_distance_transform_edt = scipy_ndimage.distance_transform_edt
+        nd_generate_binary_structure = scipy_ndimage.generate_binary_structure
+        nd_label = scipy_ndimage.label
+
+        if not has_scipy_ndimage:
+            raise RuntimeError("scipy.ndimage is required for per_component Dice computation.")
+
+        if isinstance(labels, torch.Tensor):
+            warnings.warn(
+                "Voronoi computation is running on CPU. "
+                "To accelerate, move the input tensor to GPU and ensure 'cupy' with 'cupyx.scipy.ndimage' is installed."
+            )
+            x = labels.cpu().numpy()
+        else:
+            x = np.asarray(labels)
+    rank = conn_rank = x.ndim
+    structure = nd_generate_binary_structure(rank=rank, connectivity=conn_rank)
+    cc, num = nd_label(x > 0, structure=structure)
+    if num == 0:
+        return torch.zeros_like(torch.from_numpy(x), dtype=torch.int32)
+    edt_input = xp.ones(cc.shape, dtype=xp.uint8)
+    edt_input[cc > 0] = 0
+    indices = nd_distance_transform_edt(edt_input, sampling=None, return_distances=False, return_indices=True)
+    voronoi = cc[tuple(indices)]
+    if xp is cupy:
+        return torch.as_tensor(cupy.asnumpy(voronoi), dtype=torch.int32)
+    else:
+        return torch.as_tensor(voronoi, dtype=torch.int32)
+
+
 ENCODING_KERNEL = {2: [[8, 4], [2, 1]], 3: [[[128, 64], [32, 16]], [[8, 4], [2, 1]]]}
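
The core trick in `compute_voronoi_regions_fast` is a single EDT call: label the foreground, ask the distance transform of the background for the index of each voxel's nearest foreground voxel, and read the component id there. A small scipy-only sketch:

import numpy as np
from scipy import ndimage

mask = np.array([[1, 0, 0, 0],
                 [0, 0, 0, 0],
                 [0, 0, 0, 1]])
cc, num = ndimage.label(mask > 0)      # two components: ids 1 and 2

background = np.ones(cc.shape, dtype=np.uint8)
background[cc > 0] = 0                 # EDT measures distance to the seeds
indices = ndimage.distance_transform_edt(
    background, return_distances=False, return_indices=True
)
voronoi = cc[tuple(indices)]           # nearest-component id for every voxel
print(voronoi)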

tests/apps/detection/utils/test_anchor_box.py

Lines changed: 13 additions & 10 deletions

@@ -44,9 +44,9 @@ class TestAnchorGenerator(unittest.TestCase):
     @parameterized.expand(TEST_CASES_2D)
     def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes):
         torch_anchor_utils, _ = optional_import("torchvision.models.detection.anchor_utils")
-        image_list, _ = optional_import("torchvision.models.detection.image_list")
 
-        # test it behaves the same with torchvision for 2d
+        # test it behaves for new functionality of centered anchors
+        # pytorch does not follow this functionality
         anchor = AnchorGenerator(**input_param, indexing="xy")
         anchor_ref = torch_anchor_utils.AnchorGenerator(**input_param)
         for a, a_f in zip(anchor.cell_anchors, anchor_ref.cell_anchors):
@@ -56,15 +56,18 @@ def test_anchor_2d(self, input_param, image_shape, feature_maps_shapes):
 
         grid_sizes = [[2, 2], [1, 1]]
         strides = [[torch.tensor(1), torch.tensor(2)], [torch.tensor(2), torch.tensor(4)]]
-        for a, a_f in zip(anchor.grid_anchors(grid_sizes, strides), anchor_ref.grid_anchors(grid_sizes, strides)):
-            assert_allclose(a, a_f, type_test=True, device_test=False, atol=1e-3)
 
-        images = torch.rand(image_shape)
-        feature_maps = tuple(torch.rand(fs) for fs in feature_maps_shapes)
-        result = anchor(images, feature_maps)
-        result_ref = anchor_ref(image_list.ImageList(images, ([123, 122],)), feature_maps)
-        for a, a_f in zip(result, result_ref):
-            assert_allclose(a, a_f, type_test=True, device_test=False, atol=0.1)
+        monai_anchors = anchor.grid_anchors(grid_sizes, strides)
+        torchvision_anchors = anchor_ref.grid_anchors(grid_sizes, strides)
+
+        for a, a_f, s in zip(monai_anchors, torchvision_anchors, strides):
+            stride_y, stride_x = s
+
+            offset_x = stride_x // 2
+            offset_y = stride_y // 2
+            offset = torch.tensor([offset_x, offset_y, offset_x, offset_y], dtype=a_f.dtype, device=a_f.device)
+
+            assert_allclose(a, a_f + offset, type_test=True, device_test=False, atol=1e-3)
 
     @parameterized.expand(TEST_CASES_2D)
     def test_script_2d(self, input_param, image_shape, feature_maps_shapes):
