Skip to content

Commit a09d2a2

Browse files
5717 Enhance hovernet to use standard resnet50's weights (#5688)
Signed-off-by: Yiheng Wang <vennw@nvidia.com> Fixes #5717 . ### Description This PR adds support for loading torchvision's resnet50 weights. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Yiheng Wang <vennw@nvidia.com>
1 parent f170e0f commit a09d2a2

1 file changed

Lines changed: 79 additions & 9 deletions

File tree

monai/networks/nets/hovernet.py

Lines changed: 79 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,7 @@ class HoVerNet(nn.Module):
414414
Medical Image Analysis 2019
415415
416416
https://github.com/vqdang/hover_net
417+
https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html
417418
418419
Args:
419420
mode: use original implementation (`HoVerNetMODE.ORIGINAL` or "original") or
@@ -429,10 +430,16 @@ class HoVerNet(nn.Module):
429430
Please note that to get consistent output size, `HoVerNetMode.FAST` mode should be employed.
430431
dropout_prob: dropout rate after each dense layer.
431432
pretrained_url: if specified, will load the pretrained weights downloaded from the url.
432-
The weights should be ImageNet pretrained preact-resnet50 weights coming from the referred hover_net
433+
There are two supported forms of weights:
434+
1. preact-resnet50 weights coming from the referred hover_net
433435
repository, each user is responsible for checking the content of model/datasets and the applicable licenses
434436
and determining if suitable for the intended use. Please check the following link for more details:
435437
https://github.com/vqdang/hover_net#data-format
438+
2. standard resnet50 weights of torchvision. Please check the following link for more details:
439+
https://pytorch.org/vision/main/_modules/torchvision/models/resnet.html#ResNet50_Weights
440+
adapt_standard_resnet: if the pretrained weights of the encoder follow the original format (preact-resnet50), this
441+
value should be `False`. If using the pretrained weights that follow torchvision's standard resnet50 format,
442+
this value should be `True`.
436443
freeze_encoder: whether to freeze the encoder of the network.
437444
"""
438445

@@ -450,6 +457,7 @@ def __init__(
450457
decoder_padding: bool = False,
451458
dropout_prob: float = 0.0,
452459
pretrained_url: Optional[str] = None,
460+
adapt_standard_resnet: bool = False,
453461
freeze_encoder: bool = False,
454462
) -> None:
455463

@@ -555,7 +563,11 @@ def __init__(
555563
nn.init.constant_(torch.as_tensor(m.bias), 0)
556564

557565
if pretrained_url is not None:
558-
_load_pretrained_encoder(self, pretrained_url)
566+
if adapt_standard_resnet:
567+
weights = _remap_standard_resnet_model(pretrained_url)
568+
else:
569+
weights = _remap_preact_resnet_model(pretrained_url)
570+
_load_pretrained_encoder(self, weights)
559571

560572
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
561573

@@ -588,7 +600,18 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
588600
return output
589601

590602

591-
def _load_pretrained_encoder(model: nn.Module, model_url: str):
603+
def _load_pretrained_encoder(model: nn.Module, state_dict: Union[OrderedDict, Dict]):
604+
605+
model_dict = model.state_dict()
606+
state_dict = {
607+
k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
608+
}
609+
610+
model_dict.update(state_dict)
611+
model.load_state_dict(model_dict)
612+
613+
614+
def _remap_preact_resnet_model(model_url: str):
592615

593616
pattern_conv0 = re.compile(r"^(conv0\.\/)(.+)$")
594617
pattern_block = re.compile(r"^(d\d+)\.(.+)$")
@@ -614,12 +637,59 @@ def _load_pretrained_encoder(model: nn.Module, model_url: str):
614637
if "upsample2x" in key:
615638
del state_dict[key]
616639

617-
model_dict = model.state_dict()
618-
state_dict = {
619-
k: v for k, v in state_dict.items() if (k in model_dict) and (model_dict[k].shape == state_dict[k].shape)
620-
}
621-
model_dict.update(state_dict)
622-
model.load_state_dict(model_dict)
640+
return state_dict
641+
642+
643+
def _remap_standard_resnet_model(model_url: str):
644+
645+
pattern_conv0 = re.compile(r"^conv1\.(.+)$")
646+
pattern_bn1 = re.compile(r"^bn1\.(.+)$")
647+
pattern_block = re.compile(r"^layer(\d+)\.(\d+)\.(.+)$")
648+
# bn3 to next denselayer's preact/bn
649+
pattern_block_bn3 = re.compile(r"^(res_blocks.d\d+\.layers\.denselayer_)(\d+)\.layers\.bn3\.(.+)$")
650+
# bn1, bn2 to conv1/bn, conv2/bn
651+
pattern_block_bn = re.compile(r"^(res_blocks.d\d+\.layers\.denselayer_\d+\.layers)\.bn(\d+)\.(.+)$")
652+
pattern_downsample0 = re.compile(r"^(res_blocks.d\d+).+\.downsample\.0\.(.+)")
653+
pattern_downsample1 = re.compile(r"^(res_blocks.d\d+).+\.downsample\.1\.(.+)")
654+
# download the pretrained weights into torch hub's default dir
655+
weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth")
656+
download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)
657+
state_dict = torch.load(weights_dir, map_location=None)
658+
659+
for key in list(state_dict.keys()):
660+
new_key = None
661+
if pattern_conv0.match(key):
662+
new_key = re.sub(pattern_conv0, r"conv0.conv.\1", key)
663+
elif pattern_bn1.match(key):
664+
new_key = re.sub(pattern_bn1, r"conv0.bn.\1", key)
665+
elif pattern_block.match(key):
666+
new_key = re.sub(
667+
pattern_block,
668+
lambda s: "res_blocks.d"
669+
+ str(int(s.group(1)) - 1)
670+
+ ".layers.denselayer_"
671+
+ s.group(2)
672+
+ ".layers."
673+
+ s.group(3),
674+
key,
675+
)
676+
if pattern_block_bn3.match(new_key):
677+
new_key = re.sub(
678+
pattern_block_bn3,
679+
lambda s: s.group(1) + str(int(s.group(2)) + 1) + ".layers.preact/bn." + s.group(3),
680+
new_key,
681+
)
682+
elif pattern_block_bn.match(new_key):
683+
new_key = re.sub(pattern_block_bn, r"\1.conv\2/bn.\3", new_key)
684+
elif pattern_downsample0.match(new_key):
685+
new_key = re.sub(pattern_downsample0, r"\1.shortcut.\2", new_key)
686+
elif pattern_downsample1.match(new_key):
687+
new_key = re.sub(pattern_downsample1, r"\1.bna_block.bn.\2", new_key)
688+
if new_key:
689+
state_dict[new_key] = state_dict[key]
690+
del state_dict[key]
691+
692+
return state_dict
623693

624694

625695
Hovernet = HoVernet = HoverNet = HoVerNet

0 commit comments

Comments
 (0)