@@ -414,6 +414,7 @@ class HoVerNet(nn.Module):
414414 Medical Image Analysis 2019
415415
416416 https://github.com/vqdang/hover_net
417+ https://pytorch.org/vision/main/models/generated/torchvision.models.resnet50.html
417418
418419 Args:
419420 mode: use original implementation (`HoVerNetMode.ORIGINAL` or "original") or
@@ -429,10 +430,16 @@ class HoVerNet(nn.Module):
429430 Please note that to get consistent output size, `HoVerNetMode.FAST` mode should be employed.
430431 dropout_prob: dropout rate after each dense layer.
431432 pretrained_url: if specified, the pretrained weights downloaded from this url will be loaded.
432- The weights should be ImageNet pretrained preact-resnet50 weights coming from the referred hover_net
433+ There are two supported forms of weights:
434+ 1. preact-resnet50 weights coming from the referred hover_net
433435 repository, each user is responsible for checking the content of model/datasets and the applicable licenses
434436 and determining if suitable for the intended use. please check the following link for more details:
435437 https://github.com/vqdang/hover_net#data-format
438+ 2. standard resnet50 weights of torchvision. Please check the following link for more details:
439+ https://pytorch.org/vision/main/_modules/torchvision/models/resnet.html#ResNet50_Weights
440+ adapt_standard_resnet: if the pretrained weights of the encoder follow the original format (preact-resnet50), this
441+ value should be `False`. If using the pretrained weights that follow torchvision's standard resnet50 format,
442+ this value should be `True`.
436443 freeze_encoder: whether to freeze the encoder of the network.
437444 """
438445
@@ -450,6 +457,7 @@ def __init__(
450457 decoder_padding : bool = False ,
451458 dropout_prob : float = 0.0 ,
452459 pretrained_url : Optional [str ] = None ,
460+ adapt_standard_resnet : bool = False ,
453461 freeze_encoder : bool = False ,
454462 ) -> None :
455463
@@ -555,7 +563,11 @@ def __init__(
555563 nn .init .constant_ (torch .as_tensor (m .bias ), 0 )
556564
557565 if pretrained_url is not None :
558- _load_pretrained_encoder (self , pretrained_url )
566+ if adapt_standard_resnet :
567+ weights = _remap_standard_resnet_model (pretrained_url )
568+ else :
569+ weights = _remap_preact_resnet_model (pretrained_url )
570+ _load_pretrained_encoder (self , weights )
559571
560572 def forward (self , x : torch .Tensor ) -> Dict [str , torch .Tensor ]:
561573
@@ -588,7 +600,18 @@ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
588600 return output
589601
590602
591- def _load_pretrained_encoder (model : nn .Module , model_url : str ):
def _load_pretrained_encoder(model: nn.Module, state_dict: Union[OrderedDict, Dict]):
    """Load the compatible subset of ``state_dict`` into ``model``.

    An entry of ``state_dict`` is used only when its key already exists in the
    model's own state dict and the two tensors have the same shape. Every other
    entry is ignored, and model entries without a usable counterpart keep their
    current values.

    Args:
        model: network whose parameters/buffers are overwritten in place.
        state_dict: mapping from parameter name to tensor, already remapped to
            the naming used by ``model``.
    """
    merged = model.state_dict()

    for name, tensor in state_dict.items():
        # unknown parameter name: nothing in the model to receive it
        if name not in merged:
            continue
        # same name but incompatible shape: keep the model's own value
        if merged[name].shape != tensor.shape:
            continue
        merged[name] = tensor

    model.load_state_dict(merged)
612+
613+
614+ def _remap_preact_resnet_model (model_url : str ):
592615
593616 pattern_conv0 = re .compile (r"^(conv0\.\/)(.+)$" )
594617 pattern_block = re .compile (r"^(d\d+)\.(.+)$" )
@@ -614,12 +637,59 @@ def _load_pretrained_encoder(model: nn.Module, model_url: str):
614637 if "upsample2x" in key :
615638 del state_dict [key ]
616639
617- model_dict = model .state_dict ()
618- state_dict = {
619- k : v for k , v in state_dict .items () if (k in model_dict ) and (model_dict [k ].shape == state_dict [k ].shape )
620- }
621- model_dict .update (state_dict )
622- model .load_state_dict (model_dict )
640+ return state_dict
641+
642+
def _remap_standard_resnet_model(model_url: str):
    """Download resnet50-style weights and rename their keys for the HoVerNet encoder.

    The checkpoint is downloaded to ``resnet50.pth`` inside torch hub's directory,
    loaded, and its keys are rewritten as follows:

    - ``conv1.*`` -> ``conv0.conv.*``
    - ``bn1.*`` -> ``conv0.bn.*``
    - ``layer<N>.<M>.*`` -> ``res_blocks.d<N-1>.layers.denselayer_<M>.layers.*``, after which
        - ``bn3.*`` moves to ``preact/bn.*`` of dense layer ``<M+1>``,
        - ``bn1.*`` / ``bn2.*`` become ``conv1/bn.*`` / ``conv2/bn.*``,
        - ``downsample.0.*`` becomes ``res_blocks.d<N-1>.shortcut.*``,
        - ``downsample.1.*`` becomes ``res_blocks.d<N-1>.bna_block.bn.*``.

    Keys matching none of the patterns are left untouched.

    Args:
        model_url: url of the weights file to download.

    Returns:
        The loaded state dict with its keys renamed in place.
    """
    pattern_conv0 = re.compile(r"^conv1\.(.+)$")
    pattern_bn1 = re.compile(r"^bn1\.(.+)$")
    pattern_block = re.compile(r"^layer(\d+)\.(\d+)\.(.+)$")
    # bn3 to next denselayer's preact/bn
    pattern_block_bn3 = re.compile(r"^(res_blocks.d\d+\.layers\.denselayer_)(\d+)\.layers\.bn3\.(.+)$")
    # bn1, bn2 to conv1/bn, conv2/bn
    pattern_block_bn = re.compile(r"^(res_blocks.d\d+\.layers\.denselayer_\d+\.layers)\.bn(\d+)\.(.+)$")
    pattern_downsample0 = re.compile(r"^(res_blocks.d\d+).+\.downsample\.0\.(.+)")
    pattern_downsample1 = re.compile(r"^(res_blocks.d\d+).+\.downsample\.1\.(.+)")

    # download the pretrained weights into torch hub's default dir
    # NOTE(review): the target filename is fixed, so a different `model_url` may end up
    # reusing a previously downloaded file, depending on how `download_url` treats
    # existing files -- confirm.
    weights_dir = os.path.join(torch.hub.get_dir(), "resnet50.pth")
    download_url(model_url, fuzzy=True, filepath=weights_dir, progress=False)

    # Fall back to CPU when CUDA is unavailable: with `map_location=None`, a checkpoint
    # containing CUDA tensors cannot be deserialized on a CPU-only host.
    map_location = None if torch.cuda.is_available() else torch.device("cpu")
    # NOTE(security): `torch.load` unpickles the downloaded file; only use trusted urls.
    state_dict = torch.load(weights_dir, map_location=map_location)

    # iterate over a snapshot of the keys because entries are renamed inside the loop
    for key in list(state_dict.keys()):
        new_key = None
        if pattern_conv0.match(key):
            new_key = re.sub(pattern_conv0, r"conv0.conv.\1", key)
        elif pattern_bn1.match(key):
            new_key = re.sub(pattern_bn1, r"conv0.bn.\1", key)
        elif pattern_block.match(key):
            # resnet stages are numbered from 1, the encoder's residual blocks from 0
            new_key = re.sub(
                pattern_block,
                lambda s: f"res_blocks.d{int(s.group(1)) - 1}.layers.denselayer_{s.group(2)}.layers.{s.group(3)}",
                key,
            )
            if pattern_block_bn3.match(new_key):
                # the last bn of a unit is stored as the pre-activation bn of the next unit
                new_key = re.sub(
                    pattern_block_bn3,
                    lambda s: f"{s.group(1)}{int(s.group(2)) + 1}.layers.preact/bn.{s.group(3)}",
                    new_key,
                )
            elif pattern_block_bn.match(new_key):
                new_key = re.sub(pattern_block_bn, r"\1.conv\2/bn.\3", new_key)
            elif pattern_downsample0.match(new_key):
                new_key = re.sub(pattern_downsample0, r"\1.shortcut.\2", new_key)
            elif pattern_downsample1.match(new_key):
                new_key = re.sub(pattern_downsample1, r"\1.bna_block.bn.\2", new_key)
        if new_key:
            state_dict[new_key] = state_dict.pop(key)

    return state_dict
623693
624694
# Aliases binding alternative capitalizations of the name to the same HoVerNet class.
625695Hovernet = HoVernet = HoverNet = HoVerNet
0 commit comments