Skip to content

Commit f239825

Browse files
SwinUNETR refactored img_size parameter and removed checkpointing dep… (#7093)
### Description Make two changes for the SwinUNETR network: - The img_size parameter does not seem to have an effect apart from checking shape compatibility in the beginning. This is now expressed in the docstring and the parameter will be deprecated in the future. Instead, the shape compatibility is checked during the forward pass on the actual shape. - Newer PyTorch versions accept a parameter [use_reentrant](https://pytorch.org/docs/2.1/checkpoint.html). The old default of True is deprecated in favor of False. This PR sets the parameter to False and therefore adopts the recommended value and removes the warning. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: John Zielke <john.zielke@snkeos.com>
1 parent 100db27 commit f239825

2 files changed

Lines changed: 34 additions & 10 deletions

File tree

monai/networks/nets/swin_unetr.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
import torch.nn.functional as F
2121
import torch.utils.checkpoint as checkpoint
2222
from torch.nn import LayerNorm
23+
from typing_extensions import Final
2324

2425
from monai.networks.blocks import MLPBlock as Mlp
2526
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
2627
from monai.networks.layers import DropPath, trunc_normal_
2728
from monai.utils import ensure_tuple_rep, look_up_option, optional_import
29+
from monai.utils.deprecate_utils import deprecated_arg
2830

2931
rearrange, _ = optional_import("einops", name="rearrange")
3032

@@ -49,6 +51,15 @@ class SwinUNETR(nn.Module):
4951
<https://arxiv.org/abs/2201.01266>"
5052
"""
5153

54+
patch_size: Final[int] = 2
55+
56+
@deprecated_arg(
57+
name="img_size",
58+
since="1.3",
59+
removed="1.5",
60+
msg_suffix="The img_size argument is not required anymore and "
61+
"checks on the input size are run during forward().",
62+
)
5263
def __init__(
5364
self,
5465
img_size: Sequence[int] | int,
@@ -69,7 +80,10 @@ def __init__(
6980
) -> None:
7081
"""
7182
Args:
72-
img_size: dimension of input image.
83+
img_size: spatial dimension of input image.
84+
This argument is only used for checking that the input image size is divisible by the patch size.
85+
The tensor passed to forward() can have a dynamic shape as long as its spatial dimensions are divisible by 2**5.
86+
It will be removed in an upcoming version.
7387
in_channels: dimension of input channels.
7488
out_channels: dimension of output channels.
7589
feature_size: dimension of network feature size.
@@ -103,16 +117,13 @@ def __init__(
103117
super().__init__()
104118

105119
img_size = ensure_tuple_rep(img_size, spatial_dims)
106-
patch_size = ensure_tuple_rep(2, spatial_dims)
120+
patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
107121
window_size = ensure_tuple_rep(7, spatial_dims)
108122

109123
if spatial_dims not in (2, 3):
110124
raise ValueError("spatial dimension should be 2 or 3.")
111125

112-
for m, p in zip(img_size, patch_size):
113-
for i in range(5):
114-
if m % np.power(p, i + 1) != 0:
115-
raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.")
126+
self._check_input_size(img_size)
116127

117128
if not (0 <= drop_rate <= 1):
118129
raise ValueError("dropout rate should be between 0 and 1.")
@@ -132,7 +143,7 @@ def __init__(
132143
in_chans=in_channels,
133144
embed_dim=feature_size,
134145
window_size=window_size,
135-
patch_size=patch_size,
146+
patch_size=patch_sizes,
136147
depths=depths,
137148
num_heads=num_heads,
138149
mlp_ratio=4.0,
@@ -297,7 +308,20 @@ def load_from(self, weights):
297308
weights["state_dict"]["module.layers4.0.downsample.norm.bias"]
298309
)
299310

311+
@torch.jit.unused
312+
def _check_input_size(self, spatial_shape):
313+
img_size = np.array(spatial_shape)
314+
remainder = (img_size % np.power(self.patch_size, 5)) > 0
315+
if remainder.any():
316+
wrong_dims = (np.where(remainder)[0] + 2).tolist()
317+
raise ValueError(
318+
f"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})"
319+
f" must be divisible by {self.patch_size}**5."
320+
)
321+
300322
def forward(self, x_in):
323+
if not torch.jit.is_scripting():
324+
self._check_input_size(x_in.shape[2:])
301325
hidden_states_out = self.swinViT(x_in, self.normalize)
302326
enc0 = self.encoder1(x_in)
303327
enc1 = self.encoder2(hidden_states_out[0])
@@ -669,12 +693,12 @@ def load_from(self, weights, n_block, layer):
669693
def forward(self, x, mask_matrix):
670694
shortcut = x
671695
if self.use_checkpoint:
672-
x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
696+
x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix, use_reentrant=False)
673697
else:
674698
x = self.forward_part1(x, mask_matrix)
675699
x = shortcut + self.drop_path(x)
676700
if self.use_checkpoint:
677-
x = x + checkpoint.checkpoint(self.forward_part2, x)
701+
x = x + checkpoint.checkpoint(self.forward_part2, x, use_reentrant=False)
678702
else:
679703
x = x + self.forward_part2(x)
680704
return x

runtests.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ do
261261
doBlackFormat=true
262262
doIsortFormat=true
263263
doFlake8Format=true
264-
doPylintFormat=true
264+
# doPylintFormat=true # https://github.com/Project-MONAI/MONAI/issues/7094
265265
doRuffFormat=true
266266
doCopyRight=true
267267
;;

0 commit comments

Comments
 (0)