Skip to content

Commit de2a819

Browse files
authored
Fix AttributeError when using torch.min and max (#8041)
Fixes #8040. ### Description Only return values if got a namedtuple when using torch.min and max ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent cea80a6 commit de2a819

2 files changed

Lines changed: 15 additions & 3 deletions

File tree

monai/transforms/utils_pytorch_numpy_unification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def max(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe
480480
else:
481481
ret = torch.max(x, int(dim), **kwargs) # type: ignore
482482

483-
return ret
483+
return ret[0] if isinstance(ret, tuple) else ret
484484

485485

486486
def mean(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTensor:
@@ -546,7 +546,7 @@ def min(x: NdarrayTensor, dim: int | tuple | None = None, **kwargs) -> NdarrayTe
546546
else:
547547
ret = torch.min(x, int(dim), **kwargs) # type: ignore
548548

549-
return ret
549+
return ret[0] if isinstance(ret, tuple) else ret
550550

551551

552552
def std(x: NdarrayTensor, dim: int | tuple | None = None, unbiased: bool = False) -> NdarrayTensor:

tests/test_utils_pytorch_numpy_unification.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch
1818
from parameterized import parameterized
1919

20-
from monai.transforms.utils_pytorch_numpy_unification import mode, percentile
20+
from monai.transforms.utils_pytorch_numpy_unification import max, min, mode, percentile
2121
from monai.utils import set_determinism
2222
from tests.utils import TEST_NDARRAYS, assert_allclose, skip_if_quick
2323

@@ -27,6 +27,13 @@
2727
TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False])
2828
TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True])
2929

30+
TEST_MIN_MAX = []
31+
for p in TEST_NDARRAYS:
32+
TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, min, p(1)])
33+
TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, min, p([3.1, 3])])
34+
TEST_MIN_MAX.append([p(np.array([1, 2, 3, 4, 4, 5])), {}, max, p(5)])
35+
TEST_MIN_MAX.append([p(np.array([[3.1, 4.1, 4.1, 5.1], [3, 5, 4.1, 5]])), {"dim": 1}, max, p([5.1, 5])])
36+
3037

3138
class TestPytorchNumpyUnification(unittest.TestCase):
3239

@@ -74,6 +81,11 @@ def test_mode(self, array, expected, to_long):
7481
res = mode(array, to_long=to_long)
7582
assert_allclose(res, expected)
7683

84+
@parameterized.expand(TEST_MIN_MAX)
85+
def test_min_max(self, array, input_params, func, expected):
86+
res = func(array, **input_params)
87+
assert_allclose(res, expected, type_test=False)
88+
7789

7890
if __name__ == "__main__":
7991
unittest.main()

0 commit comments

Comments
 (0)