@@ -10,6 +10,7 @@
 import torch.distributed as dist
 from collections import OrderedDict
 from shutil import copyfile
+import re
 
 from torch.nn.modules import Module
 from torch.distributed.distributed_c10d import _get_global_rank
@@ -224,7 +225,7 @@ def __init__(self,
 
 
     def register_forward_hook(self, layers_to_hook: list, layer_name_pattern: str = "transformerlayer"):
-        self.layer_name_pattern = layer_name_pattern
+        self.layer_name_pattern = re.compile(layer_name_pattern, re.IGNORECASE)
         self.layers_to_hook = layers_to_hook
 
         if self.hooks:
@@ -240,12 +241,17 @@ def hook_fn(module, input, output):
                 if int(key) not in self.layers_to_hook:
                     return
             else:
-                key = module
+                key = module.__class__.__name__
             self.layer_outputs[key] = output
 
-        for name, layer in self.module._modules.items():
-            if hasattr(layer, "__class__") and self.layer_name_pattern in layer.__class__.__name__.lower():
-                self.hooks.append(layer.register_forward_hook(hook_fn))
+        def get_all_layers(net):
+            for name, layer in net._modules.items():
+                if isinstance(layer, torch.nn.Sequential):
+                    get_all_layers(layer)
+                elif hasattr(layer, "__class__") and self.layer_name_pattern.search(layer.__class__.__name__.lower()):
+                    self.hooks.append(layer.register_forward_hook(hook_fn))
+
+        get_all_layers(self.module)
 
     def get_batch_info(self):
         """ Get all training batch related settings.
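For reference, here is a minimal, self-contained sketch of the hook-registration idea this change introduces: compile the layer-name pattern with re.IGNORECASE, recurse into torch.nn.Sequential containers so nested layers are also hooked, and key captured outputs by the layer's class name. The HookOwner and TransformerLayer classes below are illustrative stand-ins, not the project's actual API, and the layers_to_hook index filtering from the real method is omitted for brevity.

import re
from collections import OrderedDict

import torch
import torch.nn as nn


class TransformerLayer(nn.Module):
    """Toy stand-in for a transformer block (illustrative only)."""

    def __init__(self, dim):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))


class HookOwner:
    """Minimal holder of module/hooks/layer_outputs mirroring the patched method's logic."""

    def __init__(self, module):
        self.module = module
        self.hooks = []
        self.layer_outputs = OrderedDict()

    def register_forward_hook(self, layer_name_pattern: str = "transformerlayer"):
        # Case-insensitive regex match on class names, as in the patch.
        pattern = re.compile(layer_name_pattern, re.IGNORECASE)

        def hook_fn(module, input, output):
            # Store the activation under the layer's class name.
            self.layer_outputs[module.__class__.__name__] = output

        def get_all_layers(net):
            # Recurse into nn.Sequential containers so nested layers get hooked too.
            for name, layer in net._modules.items():
                if isinstance(layer, nn.Sequential):
                    get_all_layers(layer)
                elif pattern.search(layer.__class__.__name__.lower()):
                    self.hooks.append(layer.register_forward_hook(hook_fn))

        get_all_layers(self.module)


model = nn.Sequential(nn.Linear(8, 8), nn.Sequential(TransformerLayer(8)))
owner = HookOwner(model)
owner.register_forward_hook()
model(torch.randn(2, 8))
print(list(owner.layer_outputs.keys()))  # ['TransformerLayer']

Note that keying outputs by class name means two hooked layers of the same class would overwrite each other's entry; the patched method sidesteps this for indexed layers by filtering on layers_to_hook.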