Skip to content

Commit 825021d

Browse files
authored
EnsureTyped flexible dtype (#7104)
Fixes #7102 ### Description Make dtype in `EnsureTyped` configurable as different dtypes for difference keys ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <yunl@nvidia.com>
1 parent 2b0a95e commit 825021d

3 files changed

Lines changed: 28 additions & 15 deletions

File tree

monai/transforms/utility/array.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ class EnsureType(Transform):
424424
def __init__(
425425
self,
426426
data_type: str = "tensor",
427-
dtype: DtypeLike | torch.dtype | None = None,
427+
dtype: DtypeLike | torch.dtype = None,
428428
device: torch.device | None = None,
429429
wrap_sequence: bool = True,
430430
track_meta: bool | None = None,
@@ -435,13 +435,14 @@ def __init__(
435435
self.wrap_sequence = wrap_sequence
436436
self.track_meta = get_track_meta() if track_meta is None else bool(track_meta)
437437

438-
def __call__(self, data: NdarrayOrTensor):
438+
def __call__(self, data: NdarrayOrTensor, dtype: DtypeLike | torch.dtype = None):
439439
"""
440440
Args:
441441
data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
442442
will ensure Tensor, Numpy array, float, int, bool as Tensors or numpy arrays, strings and
443443
objects keep the original. for dictionary, list or tuple, ensure every item as expected type
444444
if applicable and `wrap_sequence=False`.
445+
dtype: target data content type to convert, for example: np.float32, torch.float, etc.
445446
446447
"""
447448
if self.data_type == "tensor":
@@ -452,7 +453,7 @@ def __call__(self, data: NdarrayOrTensor):
452453
out, *_ = convert_data_type(
453454
data=data,
454455
output_type=output_type, # type: ignore
455-
dtype=self.dtype,
456+
dtype=self.dtype if dtype is None else dtype,
456457
device=self.device,
457458
wrap_sequence=self.wrap_sequence,
458459
)

monai/transforms/utility/dictionary.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def __init__(
488488
self,
489489
keys: KeysCollection,
490490
data_type: str = "tensor",
491-
dtype: DtypeLike | torch.dtype = None,
491+
dtype: Sequence[DtypeLike | torch.dtype] | DtypeLike | torch.dtype = None,
492492
device: torch.device | None = None,
493493
wrap_sequence: bool = True,
494494
track_meta: bool | None = None,
@@ -500,6 +500,7 @@ def __init__(
500500
See also: :py:class:`monai.transforms.compose.MapTransform`
501501
data_type: target data type to convert, should be "tensor" or "numpy".
502502
dtype: target data content type to convert, for example: np.float32, torch.float, etc.
503+
It also can be a sequence of dtype, each element corresponds to a key in ``keys``.
503504
device: for Tensor data type, specify the target device.
504505
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
505506
E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
@@ -508,14 +509,15 @@ def __init__(
508509
allow_missing_keys: don't raise exception if key is missing.
509510
"""
510511
super().__init__(keys, allow_missing_keys)
512+
self.dtype = ensure_tuple_rep(dtype, len(self.keys))
511513
self.converter = EnsureType(
512-
data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta
514+
data_type=data_type, device=device, wrap_sequence=wrap_sequence, track_meta=track_meta
513515
)
514516

515517
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
516518
d = dict(data)
517-
for key in self.key_iterator(d):
518-
d[key] = self.converter(d[key])
519+
for key, dtype in self.key_iterator(d, self.dtype):
520+
d[key] = self.converter(d[key], dtype)
519521
return d
520522

521523

tests/test_ensure_typed.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,24 @@ def test_dict(self):
8686
"extra": None,
8787
}
8888
for dtype in ("tensor", "numpy"):
89-
result = EnsureTyped(keys="data", data_type=dtype, device="cpu")({"data": test_data})["data"]
90-
self.assertTrue(isinstance(result, dict))
91-
self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray))
92-
assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False)
93-
self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray))
94-
assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False)
95-
self.assertEqual(result["meta"]["path"], "temp/test")
96-
self.assertEqual(result["extra"], None)
89+
trans = EnsureTyped(keys=["data", "label"], data_type=dtype, dtype=[np.float32, np.int8], device="cpu")(
90+
{"data": test_data, "label": test_data}
91+
)
92+
for key in ("data", "label"):
93+
result = trans[key]
94+
self.assertTrue(isinstance(result, dict))
95+
self.assertTrue(isinstance(result["img"], torch.Tensor if dtype == "tensor" else np.ndarray))
96+
self.assertTrue(isinstance(result["meta"]["size"], torch.Tensor if dtype == "tensor" else np.ndarray))
97+
self.assertEqual(result["meta"]["path"], "temp/test")
98+
self.assertEqual(result["extra"], None)
99+
assert_allclose(result["img"], torch.as_tensor([1.0, 2.0]), type_test=False)
100+
assert_allclose(result["meta"]["size"], torch.as_tensor([1, 2, 3]), type_test=False)
101+
if dtype == "numpy":
102+
self.assertTrue(trans["data"]["img"].dtype == np.float32)
103+
self.assertTrue(trans["label"]["img"].dtype == np.int8)
104+
else:
105+
self.assertTrue(trans["data"]["img"].dtype == torch.float32)
106+
self.assertTrue(trans["label"]["img"].dtype == torch.int8)
97107

98108

99109
if __name__ == "__main__":

0 commit comments

Comments
 (0)