Skip to content

Commit dc58e5c

Browse files
k-sukharev and KumoLiu authored
Add ResNet backbones for FlexibleUNet (#7571)
Fixes #7570. ### Description Add ResNet backbones (with option to use pretrained Med3D weights) for FlexibleUNet. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Konstantin Sukharev <k.sukharev@gmail.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent ec6aa33 commit dc58e5c

6 files changed

Lines changed: 222 additions & 117 deletions

File tree

docs/source/networks.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,11 @@ Nets
491491
.. autoclass:: ResNet
492492
:members:
493493

494+
`ResNetFeatures`
495+
~~~~~~~~~~~~~~~~
496+
.. autoclass:: ResNetFeatures
497+
:members:
498+
494499
`SENet`
495500
~~~~~~~
496501
.. autoclass:: SENet

monai/networks/nets/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
ResNet,
6060
ResNetBlock,
6161
ResNetBottleneck,
62+
ResNetEncoder,
63+
ResNetFeatures,
6264
get_medicalnet_pretrained_resnet_args,
6365
get_pretrained_resnet_medicalnet,
6466
resnet10,

monai/networks/nets/flexible_unet.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from monai.networks.layers.utils import get_act_layer
2525
from monai.networks.nets import EfficientNetEncoder
2626
from monai.networks.nets.basic_unet import UpCat
27+
from monai.networks.nets.resnet import ResNetEncoder
2728
from monai.utils import InterpolateMode, optional_import
2829

2930
__all__ = ["FlexibleUNet", "FlexUNet", "FLEXUNET_BACKBONE", "FlexUNetEncoderRegister"]
@@ -78,6 +79,7 @@ def register_class(self, name: type[Any] | str):
7879

7980
FLEXUNET_BACKBONE = FlexUNetEncoderRegister()
8081
FLEXUNET_BACKBONE.register_class(EfficientNetEncoder)
82+
FLEXUNET_BACKBONE.register_class(ResNetEncoder)
8183

8284

8385
class UNetDecoder(nn.Module):
@@ -238,7 +240,7 @@ def __init__(
238240
) -> None:
239241
"""
240242
A flexible implement of UNet, in which the backbone/encoder can be replaced with
241-
any efficient network. Currently the input must have a 2 or 3 spatial dimension
243+
any efficient or residual network. Currently the input must have a 2 or 3 spatial dimension
242244
and the spatial size of each dimension must be a multiple of 32 if is_pad parameter
243245
is False.
244246
Please notice each output of backbone must be 2x downsample in spatial dimension
@@ -248,10 +250,11 @@ def __init__(
248250
Args:
249251
in_channels: number of input channels.
250252
out_channels: number of output channels.
251-
backbone: name of backbones to initialize, only support efficientnet right now,
252-
can be from [efficientnet-b0,..., efficientnet-b8, efficientnet-l2].
253-
pretrained: whether to initialize pretrained ImageNet weights, only available
254-
for spatial_dims=2 and batch norm is used, default to False.
253+
backbone: name of backbones to initialize, only support efficientnet and resnet right now,
254+
can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2, resnet10, ..., resnet200].
255+
pretrained: whether to initialize pretrained weights. ImageNet weights are available for efficient networks
256+
if spatial_dims=2 and batch norm is used. MedicalNet weights are available for residual networks
257+
if spatial_dims=3 and in_channels=1. Default to False.
255258
decoder_channels: number of output channels for all feature maps in decoder.
256259
`len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default
257260
to (256, 128, 64, 32, 16).

monai/networks/nets/resnet.py

Lines changed: 143 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import torch
2222
import torch.nn as nn
2323

24+
from monai.networks.blocks.encoder import BaseEncoder
2425
from monai.networks.layers.factories import Conv, Norm, Pool
2526
from monai.networks.layers.utils import get_pool_layer
2627
from monai.utils import ensure_tuple_rep
@@ -45,6 +46,19 @@
4546
"resnet200",
4647
]
4748

49+
50+
# Per-backbone construction settings, keyed by model name. Each value is the tuple
#   (block, layers, shortcut_type, bias_downsample, datasets23)
# where:
#   block           - "basic" or "bottleneck", forwarded as the ``block`` argument of ``ResNet``
#   layers          - number of residual blocks in each of the four stages
#   shortcut_type   - forwarded as the ``shortcut_type`` argument of ``ResNet``
#   bias_downsample - forwarded as the ``bias_downsample`` argument of ``ResNet``
#   datasets23      - forwarded to the pretrained-weight loader
#                     (presumably selects the MedicalNet "23 datasets" checkpoint - confirm)
resnet_params = {
    "resnet10": ("basic", [1, 1, 1, 1], "B", False, True),
    "resnet18": ("basic", [2, 2, 2, 2], "A", True, True),
    "resnet34": ("basic", [3, 4, 6, 3], "A", True, True),
    "resnet50": ("bottleneck", [3, 4, 6, 3], "B", False, True),
    "resnet101": ("bottleneck", [3, 4, 23, 3], "B", False, False),
    "resnet152": ("bottleneck", [3, 8, 36, 3], "B", False, False),
    "resnet200": ("bottleneck", [3, 24, 36, 3], "B", False, False),
}
60+
61+
4862
logger = logging.getLogger(__name__)
4963

5064

@@ -335,6 +349,120 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
335349
return x
336350

337351

352+
class ResNetFeatures(ResNet):

    def __init__(self, model_name: str, pretrained: bool = True, spatial_dims: int = 3, in_channels: int = 1) -> None:
        """Initialize resnet10 to resnet200 models as a backbone, the backbone can be used as an encoder for
        segmentation and object detection models.

        Compared with the class `ResNet`, the only difference is the forward function, which returns
        the intermediate feature maps instead of a single tensor.

        Args:
            model_name: name of model to initialize, can be from [resnet10, ..., resnet200].
            pretrained: whether to initialize pretrained MedicalNet weights,
                only available for spatial_dims=3 and in_channels=1.
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of input channels for first convolutional layer.

        Raises:
            ValueError: when `model_name` is not a key of `resnet_params`, or when `pretrained` is
                requested with a `spatial_dims`/`in_channels` combination other than 3/1.
        """
        if model_name not in resnet_params:
            model_name_string = ", ".join(resnet_params)
            raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ")

        block, layers, shortcut_type, bias_downsample, datasets23 = resnet_params[model_name]

        super().__init__(
            block=block,
            layers=layers,
            block_inplanes=get_inplanes(),
            shortcut_type=shortcut_type,
            bias_downsample=bias_downsample,
            spatial_dims=spatial_dims,
            n_input_channels=in_channels,
            conv1_t_stride=2,
            feed_forward=False,
        )

        if not pretrained:
            return

        # Pretrained weights are only loaded for single-channel 3D inputs.
        if spatial_dims != 3 or in_channels != 1:
            raise ValueError("Pretrained resnet models are only available for in_channels=1 and spatial_dims=3.")
        _load_state_dict(self, model_name, datasets23=datasets23)

    def forward(self, inputs: torch.Tensor):
        """
        Args:
            inputs: input should have spatially N dimensions
                ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `spatial_dims`.

        Returns:
            a list of five torch Tensors: the stem output (taken before the optional max pooling)
            followed by the outputs of `layer1` to `layer4`.
        """
        x = self.relu(self.bn1(self.conv1(inputs)))

        # The stem activation is recorded before pooling so it is the first feature map.
        features = [x]

        if not self.no_max_pool:
            x = self.maxpool(x)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            features.append(x)

        return features
423+
424+
class ResNetEncoder(ResNetFeatures, BaseEncoder):
    """Wrap the original resnet to an encoder for flexible-unet."""

    # Order matters: the lists returned by the classmethods below are positionally
    # aligned with this list of backbone names.
    backbone_names = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"]

    @classmethod
    def get_encoder_parameters(cls) -> list[dict]:
        """Get the initialization parameter for resnet backbones.

        Returns:
            one keyword-argument dict per entry of `backbone_names`, each requesting
            pretrained weights with `spatial_dims=3` and `in_channels=1`.
        """
        return [
            {"model_name": name, "pretrained": True, "spatial_dims": 3, "in_channels": 1}
            for name in cls.backbone_names
        ]

    @classmethod
    def num_channels_per_output(cls) -> list[tuple[int, ...]]:
        """Get number of resnet backbone output feature maps channel.

        Returns:
            one 5-tuple of channel counts per backbone: the first three backbones
            (resnet10/18/34) share one layout, the remaining four share the wider one.
        """
        narrow_channels = (64, 64, 128, 256, 512)
        wide_channels = (64, 256, 512, 1024, 2048)
        return [narrow_channels] * 3 + [wide_channels] * 4

    @classmethod
    def num_outputs(cls) -> list[int]:
        """Get number of resnet backbone output feature maps.

        Since every backbone contains the same 5 output feature maps, the number list should be `[5] * 7`.
        """
        outputs_per_backbone = 5
        backbone_count = 7
        return [outputs_per_backbone] * backbone_count

    @classmethod
    def get_encoder_names(cls) -> list[str]:
        """Get names of resnet backbones."""
        return cls.backbone_names
464+
465+
338466
def _resnet(
339467
arch: str,
340468
block: type[ResNetBlock | ResNetBottleneck],
@@ -477,7 +605,7 @@ def resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) ->
477605

478606
def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True):
479607
"""
480-
Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet
608+
Download resnet pretrained weights from https://huggingface.co/TencentMedicalNet
481609
482610
Args:
483611
resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200
@@ -533,11 +661,24 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat
533661
def get_medicalnet_pretrained_resnet_args(resnet_depth: int):
    """
    Return correct shortcut_type and bias_downsample
    for pretrained MedicalNet weights according to resnet depth.

    Args:
        resnet_depth: depth of the resnet, e.g. 10, 18, 34, 50, 101, 152 or 200.

    Returns:
        a ``(bias_downsample, shortcut_type)`` tuple. ``bias_downsample`` is ``-1`` for
        depths 18 and 34 and ``0`` otherwise; ``shortcut_type`` is ``"A"`` for depths
        18 and 34 and ``"B"`` otherwise.
    """
    # After testing, the required bias_downsample value is:
    #   False: 10, 50, 101, 152, 200
    #   Any:   18, 34   (encoded as -1 -- presumably a "don't care" marker; confirm with callers)
    if resnet_depth in [18, 34]:
        return -1, "A"
    return 0, "B"
672+
673+
674+
def _load_state_dict(model: nn.Module, model_name: str, datasets23: bool = True) -> None:
    """
    Download pretrained MedicalNet weights for ``model_name`` and load them into ``model``.

    Args:
        model: network that receives the weights via ``load_state_dict``.
        model_name: name containing the resnet depth, e.g. ``resnet18`` or ``resnet18_23datasets``.
        datasets23: forwarded to ``get_pretrained_resnet_medicalnet``. A ``_23datasets`` suffix
            on ``model_name`` always enables it; without the suffix this argument decides.

    Raises:
        ValueError: when ``model_name`` does not contain ``resnet<depth>``.
    """
    search_res = re.search(r"resnet(\d+)", model_name)
    if search_res is None:
        raise ValueError("model_name argument should contain resnet depth. Example: resnet18 or resnet18_23datasets.")
    resnet_depth = int(search_res.group(1))

    # Previously the argument was unconditionally overwritten by the suffix test, so the
    # caller's flag was ignored for plain names such as "resnet18". Honor it, and let an
    # explicit "_23datasets" suffix still force the flag on.
    use_datasets23 = datasets23 or model_name.endswith("_23datasets")

    model_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device="cpu", datasets23=use_datasets23)
    # Drop the "module." fragment from checkpoint keys so they match the bare model's
    # parameter names (presumably left by a DataParallel wrapper at save time - confirm).
    model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()}
    model.load_state_dict(model_state_dict)

0 commit comments

Comments
 (0)