Skip to content

Commit c45ec1c

Browse files
authored
update forward hook fn
1 parent 8e6a834 commit c45ec1c

1 file changed

Lines changed: 11 additions & 5 deletions

File tree

deepspeed/runtime/engine.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import torch.distributed as dist
1111
from collections import OrderedDict
1212
from shutil import copyfile
13+
import re
1314

1415
from torch.nn.modules import Module
1516
from torch.distributed.distributed_c10d import _get_global_rank
@@ -224,7 +225,7 @@ def __init__(self,
224225

225226

226227
def register_forward_hook(self, layers_to_hook: list, layer_name_pattern: str = "transformerlayer"):
227-
self.layer_name_pattern = layer_name_pattern
228+
self.layer_name_pattern = re.compile(layer_name_pattern, re.IGNORECASE)
228229
self.layers_to_hook = layers_to_hook
229230

230231
if self.hooks:
@@ -240,12 +241,17 @@ def hook_fn(module, input, output):
240241
if int(key) not in self.layers_to_hook:
241242
return
242243
else:
243-
key = module
244+
key = module.__class__.__name__
244245
self.layer_outputs[key] = output
245246

246-
for name, layer in self.module._modules.items():
247-
if hasattr(layer, "__class__") and self.layer_name_pattern in layer.__class__.__name__.lower():
248-
self.hooks.append(layer.register_forward_hook(hook_fn))
247+
def get_all_layers(net):
    """Register ``hook_fn`` on the matching layers found under ``net``.

    Walks the direct children stored in ``net._modules``. A child that is
    a ``torch.nn.Sequential`` is not hooked itself; the walk recurses into
    it instead. Any other child is hooked when its lower-cased class name
    matches ``self.layer_name_pattern``, and the handle returned by
    ``register_forward_hook`` is appended to ``self.hooks``.
    """
    for child in net._modules.values():
        if isinstance(child, torch.nn.Sequential):
            # Containers are only traversed, never hooked directly.
            get_all_layers(child)
            continue
        class_name = child.__class__.__name__.lower()
        if self.layer_name_pattern.search(class_name):
            handle = child.register_forward_hook(hook_fn)
            self.hooks.append(handle)
253+
254+
get_all_layers(self.module)
249255

250256
def get_batch_info(self):
251257
""" Get all training batch related settings.

0 commit comments

Comments
 (0)