Skip to content

Commit e15da1c

Browse files
myronpre-commit-ci[bot]wyli
authored
CheckLabelShaped to check and correct label shapes (#5606)
Adds the CheckLabelShaped transform (named `EnsureSameShaped` in the code) under monai.apps.auto3dseg: - to check whether the image shape and the label shape are the same. Currently we usually don't check for this, which can result in a non-descriptive crash later, when computing the loss or a metric. - to correct the label shape if it is very similar to the image shape. Unfortunately some datasets, such as Hecktor22, have some labels (masks) with slightly different shapes (a difference of only 1 voxel). Because of that, a user has to either re-save the whole dataset or account for it during the training loop (which is inconvenient). This utility auto-corrects it. We also already have this auto-correction in DataAnalyzer, and with this transform the DataAnalyzer logic is simplified. Signed-off-by: myron <amyronenko@nvidia.com> Signed-off-by: Wenqi Li <wenqil@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com> Co-authored-by: Wenqi Li <wenqil@nvidia.com>
1 parent 4cf6cf8 commit e15da1c

5 files changed

Lines changed: 98 additions & 25 deletions

File tree

.github/workflows/pythonapp.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ jobs:
100100
python -m pip install -r requirements-dev.txt
101101
python -m pip list
102102
python setup.py develop # test no compile installation
103+
shell: bash
104+
- if: runner.os != 'windows'
105+
name: Run compiled (${{ runner.os }})
106+
run: |
103107
python setup.py develop --uninstall
104108
BUILD_MONAI=1 python setup.py develop # compile the cpp extensions
105109
shell: bash

monai/apps/auto3dseg/data_analyzer.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717
import torch
1818

19+
from monai.apps.auto3dseg.transforms import EnsureSameShaped
1920
from monai.apps.utils import get_logger
2021
from monai.auto3dseg import SegSummarizer
2122
from monai.auto3dseg.utils import datafold_read
@@ -69,9 +70,11 @@ class DataAnalyzer:
6970
hist_range: ranges to compute histogram for each image channel.
7071
fmt: format used to save the analysis results. Defaults to "yaml".
7172
histogram_only: whether to only compute histograms. Defaults to False.
73+
extra_params: other optional arguments. Currently supported arguments are :
74+
'allowed_shape_difference' (default 5) can be used to change the default tolerance of
75+
the allowed shape differences between the image and label items. In case of shape mismatch below
76+
the tolerance, the label image will be resized to match the image using nearest interpolation.
7277
73-
Raises:
74-
ValueError if device is GPU and worker > 0.
7578
7679
Examples:
7780
.. code-block:: python
@@ -121,6 +124,7 @@ def __init__(
121124
hist_range: Optional[list] = None,
122125
fmt: Optional[str] = "yaml",
123126
histogram_only: bool = False,
127+
**extra_params,
124128
):
125129
if path.isfile(output_path):
126130
warnings.warn(f"File {output_path} already exists and will be overwritten.")
@@ -139,6 +143,7 @@ def __init__(
139143
self.hist_range: list = [-500, 500] if hist_range is None else hist_range
140144
self.fmt = fmt
141145
self.histogram_only = histogram_only
146+
self.extra_params = extra_params
142147

143148
@staticmethod
144149
def _check_data_uniformity(keys: List[str], result: Dict):
@@ -206,6 +211,17 @@ def get_all_case_stats(self, key="training", transform_list=None):
206211
EnsureTyped(keys=keys, data_type="tensor", dtype=torch.float),
207212
Orientationd(keys=keys, axcodes="RAS"),
208213
]
214+
if self.label_key is not None:
215+
216+
allowed_shape_difference = self.extra_params.pop("allowed_shape_difference", 5)
217+
transform_list.append(
218+
EnsureSameShaped(
219+
keys=self.label_key,
220+
source_key=self.image_key,
221+
allowed_shape_difference=allowed_shape_difference,
222+
)
223+
)
224+
209225
transform = Compose(transform_list)
210226

211227
files, _ = datafold_read(datalist=self.datalist, basedir=self.dataroot, fold=-1, key=key)

monai/apps/auto3dseg/transforms.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import warnings
13+
from typing import Dict, Hashable, Mapping
14+
15+
import numpy as np
16+
import torch
17+
18+
from monai.config import KeysCollection
19+
from monai.networks.utils import pytorch_after
20+
from monai.transforms import MapTransform
21+
22+
23+
class EnsureSameShaped(MapTransform):
    """
    Checks if segmentation label images (in keys) have the same spatial shape as the main image (in source_key),
    and raises an error if the shapes are significantly different.
    If the shapes are only slightly different (within an allowed_shape_difference in each dim), then the label is
    resized using nearest interpolation. This transform is designed to correct datasets with slight label shape
    mismatches. Generally image and segmentation label must have the same spatial shape, however some public
    datasets have slight shape mismatches, which will cause potential crashes when calculating loss or metric
    functions.
    """

    def __init__(
        self,
        keys: KeysCollection = "label",
        allow_missing_keys: bool = False,
        source_key: str = "image",
        allowed_shape_difference: int = 5,
    ) -> None:
        """
        Args:
            keys: keys of the corresponding items to be compared to the source_key item shape.
            allow_missing_keys: do not raise exception if key is missing.
            source_key: key of the item with the reference shape.
            allowed_shape_difference: raises error if shapes are different more than this value in any dimension,
                otherwise corrects for the shape mismatch using nearest interpolation.

        """
        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
        self.source_key = source_key
        self.allowed_shape_difference = allowed_shape_difference

    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
        """
        Compare the shape of every item in ``keys`` against the ``source_key`` item and resize it on a small mismatch.

        Args:
            data: mapping that contains the ``source_key`` item and the items listed in ``keys``.

        Returns:
            A shallow copy of ``data`` in which mismatching items have been resized to the source shape.

        Raises:
            ValueError: if an item differs from the source item in the number of dimensions, or by more than
                ``allowed_shape_difference`` in any compared dimension.
        """
        d = dict(data)
        # the leading dimension is skipped in the comparison (presumably the channel dim -- confirm with callers)
        image_shape = d[self.source_key].shape[1:]
        for key in self.key_iterator(d):
            label_shape = d[key].shape[1:]
            if label_shape == image_shape:
                continue

            # Compare ranks explicitly: an element-wise check on shapes of different length would either
            # broadcast silently or fail with an unrelated error message.
            same_rank = len(label_shape) == len(image_shape)
            # Exact per-dimension integer comparison (no hidden relative tolerance).
            within_tolerance = same_rank and all(
                abs(int(lbl) - int(img)) <= self.allowed_shape_difference
                for lbl, img in zip(label_shape, image_shape)
            )
            if not within_tolerance:
                raise ValueError(f"The {key} shape {label_shape} is different from the source shape {image_shape}.")

            warnings.warn(
                f"The {key} with shape {label_shape} was resized to match the source shape {image_shape}, "
                f"the meta-data was not updated."
            )
            # interpolate expects a batch dimension, so one is added and removed around the call
            d[key] = torch.nn.functional.interpolate(
                input=d[key].unsqueeze(0),
                size=image_shape,
                mode="nearest-exact" if pytorch_after(1, 11) else "nearest",
            ).squeeze(0)
        return d

monai/auto3dseg/analyzer.py

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
import numpy as np
1818
import torch
19-
import torch.nn.functional as F
2019

2120
from monai.apps.utils import get_logger
2221
from monai.auto3dseg.operations import Operations, SampleOperations, SummaryOperations
@@ -33,7 +32,7 @@
3332
from monai.data import MetaTensor, affine_to_spacing
3433
from monai.transforms.transform import MapTransform
3534
from monai.transforms.utils_pytorch_numpy_unification import sum, unique
36-
from monai.utils import convert_to_numpy, pytorch_after
35+
from monai.utils import convert_to_numpy
3736
from monai.utils.enums import DataStatsKeys, ImageStatsKeys, LabelStatsKeys
3837
from monai.utils.misc import ImageMetaKey, label_union
3938

@@ -326,16 +325,7 @@ def __call__(self, data) -> dict:
326325
ndas_label = d[self.label_key] # (H,W,D)
327326

328327
if ndas_label.shape != ndas[0].shape:
329-
# if image and label shapes are different, check if they are close
330-
if np.allclose(list(ndas_label.shape), list(ndas[0].shape), atol=10):
331-
logger.info(f" Label shape {ndas_label.shape} is slightly different from image shape {ndas[0].shape}")
332-
ndas_label = F.interpolate(
333-
input=ndas_label.unsqueeze(0).unsqueeze(0),
334-
size=list(ndas[0].shape),
335-
mode="nearest-exact" if pytorch_after(1, 11) else "nearest",
336-
)[0, 0]
337-
else:
338-
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
328+
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
339329

340330
nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
341331
nda_foregrounds = [nda if nda.numel() > 0 else MetaTensor([0.0]) for nda in nda_foregrounds]
@@ -465,16 +455,7 @@ def __call__(self, data):
465455
ndas_label = d[self.label_key] # (H,W,D)
466456

467457
if ndas_label.shape != ndas[0].shape:
468-
# if image and label shapes are different, check if they are close
469-
if np.allclose(list(ndas_label.shape), list(ndas[0].shape), atol=10):
470-
logger.info(f" Label shape {ndas_label.shape} is slightly different from image shape {ndas[0].shape}")
471-
ndas_label = F.interpolate(
472-
input=ndas_label.unsqueeze(0).unsqueeze(0),
473-
size=list(ndas[0].shape),
474-
mode="nearest-exact" if pytorch_after(1, 11) else "nearest",
475-
)[0, 0]
476-
else:
477-
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
458+
raise ValueError(f"Label shape {ndas_label.shape} is different from image shape {ndas[0].shape}")
478459

479460
nda_foregrounds = [get_foreground_label(nda, ndas_label) for nda in ndas]
480461
nda_foregrounds = [nda if nda.numel() > 0 else torch.Tensor([0]) for nda in nda_foregrounds]

tests/test_spacing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from monai.data.utils import affine_to_spacing
2121
from monai.transforms import Spacing
2222
from monai.utils import fall_back_tuple
23-
from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose
23+
from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose, skip_if_quick
2424

2525
TESTS = []
2626
for device in TEST_DEVICES:
@@ -261,6 +261,7 @@
261261
TEST_INVERSE.append([*d, recompute, align, scale_extent])
262262

263263

264+
@skip_if_quick
264265
class TestSpacingCase(unittest.TestCase):
265266
@parameterized.expand(TESTS)
266267
def test_spacing(self, init_param, img, affine, data_param, expected_output, device):

0 commit comments

Comments
 (0)