Skip to content

Commit 7a38c3b

Browse files
KumoLiuwyli
andauthored
Fix min max function error in safe_dtype_range (#5706)
Fixes #5705. ### Description fix min max function error in `safe_dtype_range` add `safe` flag in `ShiftIntensity` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: KumoLiu <yunl@nvidia.com> Signed-off-by: Wenqi Li <wenqil@nvidia.com> Co-authored-by: Wenqi Li <wenqil@nvidia.com>
1 parent c489de0 commit 7a38c3b

5 files changed

Lines changed: 24 additions & 8 deletions

File tree

monai/transforms/intensity/array.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,12 +220,15 @@ class ShiftIntensity(Transform):
220220
221221
Args:
222222
offset: offset value to shift the intensity of image.
223+
safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
224+
E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
223225
"""
224226

225227
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
226228

227-
def __init__(self, offset: float) -> None:
229+
def __init__(self, offset: float, safe: bool = False) -> None:
228230
self.offset = offset
231+
self.safe = safe
229232

230233
def __call__(self, img: NdarrayOrTensor, offset: Optional[float] = None) -> NdarrayOrTensor:
231234
"""
@@ -235,7 +238,7 @@ def __call__(self, img: NdarrayOrTensor, offset: Optional[float] = None) -> Ndar
235238
img = convert_to_tensor(img, track_meta=get_track_meta())
236239
offset = self.offset if offset is None else offset
237240
out = img + offset
238-
out, *_ = convert_data_type(data=out, dtype=img.dtype)
241+
out, *_ = convert_data_type(data=out, dtype=img.dtype, safe=self.safe)
239242

240243
return out
241244

@@ -247,11 +250,13 @@ class RandShiftIntensity(RandomizableTransform):
247250

248251
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
249252

250-
def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1) -> None:
253+
def __init__(self, offsets: Union[Tuple[float, float], float], safe: bool = False, prob: float = 0.1) -> None:
251254
"""
252255
Args:
253256
offsets: offset range to randomly shift.
254257
if single number, offset value is picked from (-offsets, offsets).
258+
safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
259+
E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
255260
prob: probability of shift.
256261
"""
257262
RandomizableTransform.__init__(self, prob)
@@ -262,7 +267,7 @@ def __init__(self, offsets: Union[Tuple[float, float], float], prob: float = 0.1
262267
else:
263268
self.offsets = (min(offsets), max(offsets))
264269
self._offset = self.offsets[0]
265-
self._shifter = ShiftIntensity(self._offset)
270+
self._shifter = ShiftIntensity(self._offset, safe)
266271

267272
def randomize(self, data: Optional[Any] = None) -> None:
268273
super().randomize(None)

monai/transforms/intensity/dictionary.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,7 @@ def __init__(
304304
self,
305305
keys: KeysCollection,
306306
offset: float,
307+
safe: bool = False,
307308
factor_key: Optional[str] = None,
308309
meta_keys: Optional[KeysCollection] = None,
309310
meta_key_postfix: str = DEFAULT_POST_FIX,
@@ -314,6 +315,8 @@ def __init__(
314315
keys: keys of the corresponding items to be transformed.
315316
See also: :py:class:`monai.transforms.compose.MapTransform`
316317
offset: offset value to shift the intensity of image.
318+
safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
319+
E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
317320
factor_key: if not None, use it as the key to extract a value from the corresponding
318321
metadata dictionary of `key` at runtime, and multiply the `offset` to shift intensity.
319322
Usually, `IntensityStatsd` transform can pre-compute statistics of intensity values
@@ -336,7 +339,7 @@ def __init__(
336339
if len(self.keys) != len(self.meta_keys):
337340
raise ValueError("meta_keys should have the same length as keys.")
338341
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
339-
self.shifter = ShiftIntensity(offset)
342+
self.shifter = ShiftIntensity(offset, safe)
340343

341344
def __call__(self, data) -> Dict[Hashable, NdarrayOrTensor]:
342345
d = dict(data)
@@ -361,6 +364,7 @@ def __init__(
361364
self,
362365
keys: KeysCollection,
363366
offsets: Union[Tuple[float, float], float],
367+
safe: bool = False,
364368
factor_key: Optional[str] = None,
365369
meta_keys: Optional[KeysCollection] = None,
366370
meta_key_postfix: str = DEFAULT_POST_FIX,
@@ -373,6 +377,8 @@ def __init__(
373377
See also: :py:class:`monai.transforms.compose.MapTransform`
374378
offsets: offset range to randomly shift.
375379
if single number, offset value is picked from (-offsets, offsets).
380+
safe: if `True`, then do safe dtype convert when intensity overflow. default to `False`.
381+
E.g., `[256, -12]` -> `[array(0), array(244)]`. If `True`, then `[256, -12]` -> `[array(255), array(0)]`.
376382
factor_key: if not None, use it as the key to extract a value from the corresponding
377383
metadata dictionary of `key` at runtime, and multiply the random `offset` to shift intensity.
378384
Usually, `IntensityStatsd` transform can pre-compute statistics of intensity values
@@ -399,7 +405,7 @@ def __init__(
399405
if len(self.keys) != len(self.meta_keys):
400406
raise ValueError("meta_keys should have the same length as keys.")
401407
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))
402-
self.shifter = RandShiftIntensity(offsets=offsets, prob=1.0)
408+
self.shifter = RandShiftIntensity(offsets=offsets, safe=safe, prob=1.0)
403409

404410
def set_random_state(
405411
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None

monai/utils/type_conversion.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,10 @@ def _safe_dtype_range(data, dtype):
423423
if data.ndim == 0:
424424
data_bound = (data, data)
425425
else:
426-
data_bound = (min(data), max(data))
426+
if isinstance(data, torch.Tensor):
427+
data_bound = (torch.min(data), torch.max(data))
428+
else:
429+
data_bound = (np.min(data), np.max(data))
427430
if (data_bound[1] > dtype_bound_value[1]) or (data_bound[0] < dtype_bound_value[0]):
428431
if isinstance(data, torch.Tensor):
429432
return torch.clamp(data, dtype_bound_value[0], dtype_bound_value[1])

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ no_implicit_optional = True
159159
# Warns about casting an expression to its inferred type.
160160
warn_redundant_casts = True
161161
# Warns about unneeded # type: ignore comments.
162-
warn_unused_ignores = True
162+
# warn_unused_ignores = True
163163
# Shows a warning when returning a value with type Any from a function declared with a non-Any return type.
164164
warn_return_any = True
165165
# Prohibit equality checks, identity checks, and container checks between non-overlapping types.

tests/test_safe_dtype_range.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
if in_type is not float:
2929
TESTS.append((in_type(np.array(256)), in_type(np.array(255)), np.uint8)) # type: ignore
3030
TESTS.append((in_type(np.array(-12)), in_type(np.array(0)), np.uint8)) # type: ignore
31+
for in_type in TEST_NDARRAYS_ALL:
32+
TESTS.append((in_type(np.array([[256, 255], [-12, 0]])), in_type(np.array([[255, 255], [0, 0]])), np.uint8)) # type: ignore
3133

3234
TESTS_LIST: List[Tuple] = []
3335
for in_type in TEST_NDARRAYS_ALL + (int, float):

0 commit comments

Comments
 (0)