
Commit fe733b0

Pkaps25 and Peter Kaplinsky authored
Propagate kernel size through Attention-UNet (#7734)
Fixes #7726.

### Description

Passes the `kernel_size` parameter to the `ConvBlock`s within `AttentionUnet`, creating a net with the expected number of parameters. Running the example from #7726 on this branch:

```python
from monai.networks.nets import AttentionUnet

model = AttentionUnet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(2, 4, 8, 16),
    strides=(2, 2, 2),
    kernel_size=5,
    up_kernel_size=5,
)
```

outputs the expected values:

```
Total params: 18,846
Trainable params: 18,846
Non-trainable params: 0
Total mult-adds (M): 0.37
```

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Peter Kaplinsky <peterkaplinsky@gmail.com>
Co-authored-by: Peter Kaplinsky <peterkaplinsky@gmail.com>
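The summary quoted above appears to come from a tool such as `torchinfo` (an assumption; the tool is not named in the PR). A minimal sketch of how the totals could be verified on this branch, where only the plain parameter count needs nothing beyond PyTorch:

```python
# Sketch: verifying the parameter count quoted above.
# Assumes the third-party `torchinfo` package (not part of MONAI) for summary().
from torchinfo import summary

from monai.networks.nets import AttentionUnet

model = AttentionUnet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(2, 4, 8, 16),
    strides=(2, 2, 2),
    kernel_size=5,
    up_kernel_size=5,
)

# Plain parameter count, no extra dependency:
print(sum(p.numel() for p in model.parameters()))  # expected: 18846 on this branch

# Full table with trainable params and mult-adds; the input size is an arbitrary
# choice divisible by the cumulative stride of 8.
summary(model, input_size=(1, 1, 64, 64))
```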
1 parent e1a69b0 commit fe733b0

2 files changed: 30 additions & 2 deletions


monai/networks/nets/attentionunet.py

Lines changed: 10 additions & 2 deletions
```diff
@@ -29,7 +29,7 @@ def __init__(
         spatial_dims: int,
         in_channels: int,
         out_channels: int,
-        kernel_size: int = 3,
+        kernel_size: Sequence[int] | int = 3,
         strides: int = 1,
         dropout=0.0,
     ):
@@ -219,7 +219,13 @@ def __init__(
         self.kernel_size = kernel_size
         self.dropout = dropout

-        head = ConvBlock(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=channels[0], dropout=dropout)
+        head = ConvBlock(
+            spatial_dims=spatial_dims,
+            in_channels=in_channels,
+            out_channels=channels[0],
+            dropout=dropout,
+            kernel_size=self.kernel_size,
+        )
         reduce_channels = Convolution(
             spatial_dims=spatial_dims,
             in_channels=channels[0],
@@ -245,6 +251,7 @@ def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module:
                             out_channels=channels[1],
                             strides=strides[0],
                             dropout=self.dropout,
+                            kernel_size=self.kernel_size,
                         ),
                         subblock,
                     ),
@@ -271,6 +278,7 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) -
                 out_channels=out_channels,
                 strides=strides,
                 dropout=self.dropout,
+                kernel_size=self.kernel_size,
             ),
             up_kernel_size=self.up_kernel_size,
             strides=strides,
```
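Since the annotation is widened to `Sequence[int] | int`, a per-axis kernel should now propagate through the same path. A minimal sketch inferred from the new type, not an example taken from the PR itself:

```python
from monai.networks.nets import AttentionUnet

# Anisotropic kernel: 3 along the first spatial axis, 5 along the second.
# Assumes the sequence length must match spatial_dims, as elsewhere in MONAI.
model = AttentionUnet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(2, 4, 8, 16),
    strides=(2, 2, 2),
    kernel_size=(3, 5),
)
print(sum(p.numel() for p in model.parameters()))
```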

tests/test_attentionunet.py

Lines changed: 20 additions & 0 deletions
```diff
@@ -14,11 +14,17 @@
 import unittest

 import torch
+import torch.nn as nn

 import monai.networks.nets.attentionunet as att
 from tests.utils import skip_if_no_cuda, skip_if_quick


+def get_net_parameters(net: nn.Module) -> int:
+    """Returns the total number of parameters in a Module."""
+    return sum(param.numel() for param in net.parameters())
+
+
 class TestAttentionUnet(unittest.TestCase):

     def test_attention_block(self):
@@ -50,6 +56,20 @@ def test_attentionunet(self):
             self.assertEqual(output.shape[0], input.shape[0])
             self.assertEqual(output.shape[1], 2)

+    def test_attentionunet_kernel_size(self):
+        args_dict = {
+            "spatial_dims": 2,
+            "in_channels": 1,
+            "out_channels": 2,
+            "channels": (3, 4, 5),
+            "up_kernel_size": 5,
+            "strides": (1, 2),
+        }
+        model_a = att.AttentionUnet(**args_dict, kernel_size=5)
+        model_b = att.AttentionUnet(**args_dict, kernel_size=7)
+        self.assertEqual(get_net_parameters(model_a), 3534)
+        self.assertEqual(get_net_parameters(model_b), 5574)
+
     @skip_if_no_cuda
     def test_attentionunet_gpu(self):
         for dims in [2, 3]:
```
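The hard-coded expectations (3534 vs. 5574) are consistent with the standard convolution parameter count, `in_ch * out_ch * prod(kernel) + out_ch` per layer, so a wider kernel grows every `ConvBlock` in the net. A standalone illustration in plain PyTorch, not part of the test suite:

```python
import torch.nn as nn


def conv2d_params(in_ch: int, out_ch: int, k: int) -> int:
    """Weights (in_ch * out_ch * k * k) plus one bias per output channel."""
    return in_ch * out_ch * k * k + out_ch


print(conv2d_params(1, 3, 5))  # 1*3*25 + 3 = 78
print(conv2d_params(1, 3, 7))  # 1*3*49 + 3 = 150

# Cross-check against PyTorch directly:
layer = nn.Conv2d(1, 3, kernel_size=7)
print(sum(p.numel() for p in layer.parameters()))  # 150
```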
