|
21 | 21 | import torch |
22 | 22 | import torch.nn as nn |
23 | 23 |
|
| 24 | +from monai.networks.blocks.encoder import BaseEncoder |
24 | 25 | from monai.networks.layers.factories import Conv, Norm, Pool |
25 | 26 | from monai.networks.layers.utils import get_pool_layer |
26 | 27 | from monai.utils import ensure_tuple_rep |
|
45 | 46 | "resnet200", |
46 | 47 | ] |
47 | 48 |
|
| 49 | + |
| 50 | +resnet_params = { |
| 51 | + # model_name: (block, layers, shortcut_type, bias_downsample, datasets23) |
| 52 | + "resnet10": ("basic", [1, 1, 1, 1], "B", False, True), |
| 53 | + "resnet18": ("basic", [2, 2, 2, 2], "A", True, True), |
| 54 | + "resnet34": ("basic", [3, 4, 6, 3], "A", True, True), |
| 55 | + "resnet50": ("bottleneck", [3, 4, 6, 3], "B", False, True), |
| 56 | + "resnet101": ("bottleneck", [3, 4, 23, 3], "B", False, False), |
| 57 | + "resnet152": ("bottleneck", [3, 8, 36, 3], "B", False, False), |
| 58 | + "resnet200": ("bottleneck", [3, 24, 36, 3], "B", False, False), |
| 59 | +} |
| 60 | + |
| 61 | + |
48 | 62 | logger = logging.getLogger(__name__) |
49 | 63 |
|
50 | 64 |
|
@@ -335,6 +349,120 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: |
335 | 349 | return x |
336 | 350 |
|
337 | 351 |
|
| 352 | +class ResNetFeatures(ResNet): |
| 353 | + |
| 354 | + def __init__(self, model_name: str, pretrained: bool = True, spatial_dims: int = 3, in_channels: int = 1) -> None: |
| 355 | + """Initialize resnet18 to resnet200 models as a backbone, the backbone can be used as an encoder for |
| 356 | + segmentation and objection models. |
| 357 | +
|
| 358 | + Compared with the class `ResNet`, the only different place is the forward function. |
| 359 | +
|
| 360 | + Args: |
| 361 | + model_name: name of model to initialize, can be from [resnet10, ..., resnet200]. |
| 362 | + pretrained: whether to initialize pretrained MedicalNet weights, |
| 363 | + only available for spatial_dims=3 and in_channels=1. |
| 364 | + spatial_dims: number of spatial dimensions of the input image. |
| 365 | + in_channels: number of input channels for first convolutional layer. |
| 366 | + """ |
| 367 | + if model_name not in resnet_params: |
| 368 | + model_name_string = ", ".join(resnet_params.keys()) |
| 369 | + raise ValueError(f"invalid model_name {model_name} found, must be one of {model_name_string} ") |
| 370 | + |
| 371 | + block, layers, shortcut_type, bias_downsample, datasets23 = resnet_params[model_name] |
| 372 | + |
| 373 | + super().__init__( |
| 374 | + block=block, |
| 375 | + layers=layers, |
| 376 | + block_inplanes=get_inplanes(), |
| 377 | + spatial_dims=spatial_dims, |
| 378 | + n_input_channels=in_channels, |
| 379 | + conv1_t_stride=2, |
| 380 | + shortcut_type=shortcut_type, |
| 381 | + feed_forward=False, |
| 382 | + bias_downsample=bias_downsample, |
| 383 | + ) |
| 384 | + if pretrained: |
| 385 | + if spatial_dims == 3 and in_channels == 1: |
| 386 | + _load_state_dict(self, model_name, datasets23=datasets23) |
| 387 | + else: |
| 388 | + raise ValueError("Pretrained resnet models are only available for in_channels=1 and spatial_dims=3.") |
| 389 | + |
| 390 | + def forward(self, inputs: torch.Tensor): |
| 391 | + """ |
| 392 | + Args: |
| 393 | + inputs: input should have spatially N dimensions |
| 394 | + ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`. |
| 395 | +
|
| 396 | + Returns: |
| 397 | + a list of torch Tensors. |
| 398 | + """ |
| 399 | + x = self.conv1(inputs) |
| 400 | + x = self.bn1(x) |
| 401 | + x = self.relu(x) |
| 402 | + |
| 403 | + features = [] |
| 404 | + features.append(x) |
| 405 | + |
| 406 | + if not self.no_max_pool: |
| 407 | + x = self.maxpool(x) |
| 408 | + |
| 409 | + x = self.layer1(x) |
| 410 | + features.append(x) |
| 411 | + |
| 412 | + x = self.layer2(x) |
| 413 | + features.append(x) |
| 414 | + |
| 415 | + x = self.layer3(x) |
| 416 | + features.append(x) |
| 417 | + |
| 418 | + x = self.layer4(x) |
| 419 | + features.append(x) |
| 420 | + |
| 421 | + return features |
| 422 | + |
| 423 | + |
| 424 | +class ResNetEncoder(ResNetFeatures, BaseEncoder): |
| 425 | + """Wrap the original resnet to an encoder for flexible-unet.""" |
| 426 | + |
| 427 | + backbone_names = ["resnet10", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnet200"] |
| 428 | + |
| 429 | + @classmethod |
| 430 | + def get_encoder_parameters(cls) -> list[dict]: |
| 431 | + """Get the initialization parameter for resnet backbones.""" |
| 432 | + parameter_list = [] |
| 433 | + for backbone_name in cls.backbone_names: |
| 434 | + parameter_list.append( |
| 435 | + {"model_name": backbone_name, "pretrained": True, "spatial_dims": 3, "in_channels": 1} |
| 436 | + ) |
| 437 | + return parameter_list |
| 438 | + |
| 439 | + @classmethod |
| 440 | + def num_channels_per_output(cls) -> list[tuple[int, ...]]: |
| 441 | + """Get number of resnet backbone output feature maps channel.""" |
| 442 | + return [ |
| 443 | + (64, 64, 128, 256, 512), |
| 444 | + (64, 64, 128, 256, 512), |
| 445 | + (64, 64, 128, 256, 512), |
| 446 | + (64, 256, 512, 1024, 2048), |
| 447 | + (64, 256, 512, 1024, 2048), |
| 448 | + (64, 256, 512, 1024, 2048), |
| 449 | + (64, 256, 512, 1024, 2048), |
| 450 | + ] |
| 451 | + |
| 452 | + @classmethod |
| 453 | + def num_outputs(cls) -> list[int]: |
| 454 | + """Get number of resnet backbone output feature maps. |
| 455 | +
|
| 456 | + Since every backbone contains the same 5 output feature maps, the number list should be `[5] * 7`. |
| 457 | + """ |
| 458 | + return [5] * 7 |
| 459 | + |
| 460 | + @classmethod |
| 461 | + def get_encoder_names(cls) -> list[str]: |
| 462 | + """Get names of resnet backbones.""" |
| 463 | + return cls.backbone_names |
| 464 | + |
| 465 | + |
338 | 466 | def _resnet( |
339 | 467 | arch: str, |
340 | 468 | block: type[ResNetBlock | ResNetBottleneck], |
@@ -477,7 +605,7 @@ def resnet200(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> |
477 | 605 |
|
478 | 606 | def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", datasets23: bool = True): |
479 | 607 | """ |
480 | | - Donwlad resnet pretrained weights from https://huggingface.co/TencentMedicalNet |
| 608 | + Download resnet pretrained weights from https://huggingface.co/TencentMedicalNet |
481 | 609 |
|
482 | 610 | Args: |
483 | 611 | resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200 |
@@ -533,11 +661,24 @@ def get_pretrained_resnet_medicalnet(resnet_depth: int, device: str = "cpu", dat |
533 | 661 | def get_medicalnet_pretrained_resnet_args(resnet_depth: int): |
534 | 662 | """ |
535 | 663 | Return correct shortcut_type and bias_downsample |
536 | | - for pretrained MedicalNet weights according to resnet depth |
| 664 | + for pretrained MedicalNet weights according to resnet depth. |
537 | 665 | """ |
538 | 666 | # After testing |
539 | 667 | # False: 10, 50, 101, 152, 200 |
540 | 668 | # Any: 18, 34 |
541 | 669 | bias_downsample = -1 if resnet_depth in [18, 34] else 0 # 18, 10, 34 |
542 | 670 | shortcut_type = "A" if resnet_depth in [18, 34] else "B" |
543 | 671 | return bias_downsample, shortcut_type |
| 672 | + |
| 673 | + |
| 674 | +def _load_state_dict(model: nn.Module, model_name: str, datasets23: bool = True) -> None: |
| 675 | + search_res = re.search(r"resnet(\d+)", model_name) |
| 676 | + if search_res: |
| 677 | + resnet_depth = int(search_res.group(1)) |
| 678 | + datasets23 = model_name.endswith("_23datasets") |
| 679 | + else: |
| 680 | + raise ValueError("model_name argument should contain resnet depth. Example: resnet18 or resnet18_23datasets.") |
| 681 | + |
| 682 | + model_state_dict = get_pretrained_resnet_medicalnet(resnet_depth, device="cpu", datasets23=datasets23) |
| 683 | + model_state_dict = {key.replace("module.", ""): value for key, value in model_state_dict.items()} |
| 684 | + model.load_state_dict(model_state_dict) |
0 commit comments