Skip to content

Commit 4e0b179

Browse files
authored
6497 test_script_save utils to raise potential errors (#6498)
Fixes #6497 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent d3e9826 commit 4e0b179

11 files changed

Lines changed: 32 additions & 24 deletions

File tree

.github/workflows/pythonapp-min.yml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,10 @@ jobs:
4949
with:
5050
path: ${{ steps.pip-cache.outputs.dir }}
5151
key: ${{ matrix.os }}-latest-pip-${{ steps.pip-cache.outputs.datew }}
52-
- if: runner.os == 'windows'
53-
name: Install torch cpu from pytorch.org (Windows only)
54-
run: |
55-
python -m pip install torch==1.13.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
5652
- name: Install the dependencies
5753
run: |
5854
# min. requirements
59-
python -m pip install torch==1.13.1
55+
python -m pip install torch --index-url https://download.pytorch.org/whl/cpu
6056
python -m pip install -r requirements-min.txt
6157
python -m pip list
6258
BUILD_MONAI=0 python setup.py develop # no compile of extensions

.pre-commit-config.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ repos:
4343
(?x)(
4444
^versioneer.py|
4545
^monai/_version.py|
46-
^monai/networks/ # no PEP 604 for torchscript tensorrt
46+
^monai/networks/| # no PEP 604 for torchscript tensorrt
47+
^monai/losses/ # no PEP 604 for torchscript tensorrt
4748
)
4849
- id: pyupgrade
4950
args: [--py37-plus, --keep-runtime-typing]

monai/losses/ds_loss.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111

1212
from __future__ import annotations
1313

14+
from typing import Union
15+
1416
import torch
1517
import torch.nn.functional as F
1618
from torch.nn.modules.loss import _Loss
@@ -70,13 +72,15 @@ def get_loss(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
7072
target = F.interpolate(target, size=input.shape[2:], mode=self.interp_mode)
7173
return self.loss(input, target) # type: ignore[no-any-return]
7274

73-
def forward(self, input: torch.Tensor | list[torch.Tensor], target: torch.Tensor) -> torch.Tensor:
75+
def forward(self, input: Union[None, torch.Tensor, list[torch.Tensor]], target: torch.Tensor) -> torch.Tensor:
7476
if isinstance(input, (list, tuple)):
7577
weights = self.get_weights(levels=len(input))
7678
loss = torch.tensor(0, dtype=torch.float, device=target.device)
7779
for l in range(len(input)):
7880
loss += weights[l] * self.get_loss(input[l].float(), target)
7981
return loss
82+
if input is None:
83+
raise ValueError("input shouldn't be None.")
8084

8185
return self.loss(input.float(), target) # type: ignore[no-any-return]
8286

monai/losses/focal_loss.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import warnings
1515
from collections.abc import Sequence
16+
from typing import Optional
1617

1718
import torch
1819
import torch.nn.functional as F
@@ -154,7 +155,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
154155
ce = i - i * t + max_val + ((-max_val).exp() + (-i - max_val).exp()).log()
155156

156157
if self.weight is not None:
157-
class_weight: torch.Tensor | None = None
158+
class_weight: Optional[torch.Tensor] = None
158159
if isinstance(self.weight, (float, int)):
159160
class_weight = torch.as_tensor([self.weight] * i.size(1))
160161
else:

monai/losses/spatial_mask.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import inspect
1515
import warnings
1616
from collections.abc import Callable
17-
from typing import Any
17+
from typing import Any, Optional
1818

1919
import torch
2020
from torch.nn.modules.loss import _Loss
@@ -47,7 +47,7 @@ def __init__(
4747
if not callable(self.loss):
4848
raise ValueError("The loss function is not callable.")
4949

50-
def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
50+
def forward(self, input: torch.Tensor, target: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
5151
"""
5252
Args:
5353
input: the shape should be BNH[WD].

monai/networks/nets/ahnet.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import math
1515
from collections.abc import Sequence
16+
from typing import Union
1617

1718
import torch
1819
import torch.nn as nn
@@ -279,7 +280,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
279280
else:
280281
for project_module, pool_module in zip(self.project_modules, self.pool_modules):
281282
interpolate_size = x.shape[2:]
282-
align_corners: bool | None = None
283+
align_corners: Union[bool, None] = None
283284
if self.upsample_mode in ["trilinear", "bilinear"]:
284285
align_corners = True
285286
output = F.interpolate(

monai/networks/nets/dynunet.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,10 @@ def forward(self, x):
269269
out = self.skip_layers(x)
270270
out = self.output_block(out)
271271
if self.training and self.deep_supervision:
272-
out_all = torch.zeros(out.shape[0], len(self.heads) + 1, *out.shape[1:], device=out.device, dtype=out.dtype)
273-
out_all[:, 0] = out
274-
for idx, feature_map in enumerate(self.heads):
275-
out_all[:, idx + 1] = interpolate(feature_map, out.shape[2:])
276-
return out_all
272+
out_all = [out]
273+
for feature_map in self.heads:
274+
out_all.append(interpolate(feature_map, out.shape[2:]))
275+
return torch.stack(out_all, dim=1)
277276
return out
278277

279278
def get_input_block(self):

monai/networks/nets/segresnet_ds.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
from collections.abc import Callable
15+
from typing import Union
1516

1617
import numpy as np
1718
import torch
@@ -387,7 +388,7 @@ def is_valid_shape(self, x):
387388
a = [i % j == 0 for i, j in zip(x.shape[2:], self.shape_factor())]
388389
return all(a)
389390

390-
def _forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
391+
def _forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]:
391392
if self.preprocess is not None:
392393
x = self.preprocess(x)
393394

@@ -423,5 +424,5 @@ def _forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
423424
# return a list of DS outputs
424425
return outputs
425426

426-
def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]:
427+
def forward(self, x: torch.Tensor) -> Union[None, torch.Tensor, list[torch.Tensor]]:
427428
return self._forward(x)

monai/networks/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,9 @@ def convert_to_onnx(
695695
set_determinism(seed=None)
696696
# compare onnx/ort and PyTorch results
697697
for r1, r2 in zip(torch_out, onnx_out):
698-
torch.testing.assert_allclose(r1.cpu(), r2, rtol=rtol, atol=atol)
698+
if isinstance(r1, torch.Tensor):
699+
assert_fn = torch.testing.assert_close if pytorch_after(1, 11) else torch.testing.assert_allclose
700+
assert_fn(r1.cpu(), convert_to_tensor(r2, track_meta=False), rtol=rtol, atol=atol) # type: ignore
699701

700702
return onnx_model
701703

tests/test_ssim_loss.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818

1919
from monai.losses.ssim_loss import SSIMLoss
2020
from monai.utils import set_determinism
21-
from tests.utils import test_script_save
21+
22+
# from tests.utils import test_script_save
2223

2324

2425
class TestSSIMLoss(unittest.TestCase):
@@ -47,10 +48,10 @@ def test_shape(self):
4748
expected_val = [[0.9121], [0.9971]]
4849
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-4)
4950

50-
def test_script(self):
51-
loss = SSIMLoss(spatial_dims=2)
52-
test_input = torch.ones(2, 2, 16, 16)
53-
test_script_save(loss, test_input, test_input)
51+
# def test_script(self):
52+
# loss = SSIMLoss(spatial_dims=2)
53+
# test_input = torch.ones(2, 2, 16, 16)
54+
# test_script_save(loss, test_input, test_input)
5455

5556

5657
if __name__ == "__main__":

0 commit comments

Comments
 (0)