Commit 8e6a834

Author: server-jack
Commit message: add forward hook functionality
Parent: 8900aa0 (1 parent)

2 files changed
Lines changed: 54 additions & 7 deletions

deepspeed/runtime/engine.py

Lines changed: 29 additions & 0 deletions
@@ -218,6 +218,35 @@ def __init__(self,
         self.flatten = util_ops.flatten
         self.unflatten = util_ops.unflatten
 
+        self.layer_outputs, self.layers_to_hook, self.hooks = {}, [], []
+        self.layer_name_pattern = "transformerlayer"
+        self.register_forward_hook(layers_to_hook=self.layers_to_hook)
+
+
+    def register_forward_hook(self, layers_to_hook: list, layer_name_pattern: str = "transformerlayer"):
+        self.layer_name_pattern = layer_name_pattern
+        self.layers_to_hook = layers_to_hook
+
+        if self.hooks:
+            # remove old hooks
+            for handle in self.hooks:
+                handle.remove()
+
+        def hook_fn(module, input, output):
+            if hasattr(module, 'layer_number'):
+                key = module.layer_number
+                if self.layers_to_hook == "all":
+                    pass
+                elif int(key) not in self.layers_to_hook:
+                    return
+            else:
+                key = module
+            self.layer_outputs[key] = output
+
+        for name, layer in self.module._modules.items():
+            if hasattr(layer, "__class__") and self.layer_name_pattern in layer.__class__.__name__.lower():
+                self.hooks.append(layer.register_forward_hook(hook_fn))
+
     def get_batch_info(self):
         """ Get all training batch related settings.
 
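A minimal usage sketch of the new hook API (illustration only, not part of the commit). It assumes the wrapped model's transformer-layer classes contain "transformerlayer" in their class names and expose a layer_number attribute, which is what hook_fn keys on; model, ds_config, and batch below are placeholders.

# Hypothetical usage sketch; names marked as placeholders above.
import deepspeed

engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

# Re-register the hooks so that only layers 0 and 3 are captured
# (the hooks installed in __init__ start with an empty layers_to_hook list,
# so nothing is recorded by default).
engine.register_forward_hook(layers_to_hook=[0, 3])

loss = engine(batch)  # the forward pass triggers hook_fn on matching layers

# layer_outputs maps layer_number -> that layer's forward output
for layer_id, output in engine.layer_outputs.items():
    print(layer_id, type(output))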

deepspeed/runtime/pipe/engine.py

Lines changed: 25 additions & 7 deletions
@@ -29,7 +29,6 @@
 LOG_STAGE = -2
 DATA_PARALLEL_ID = -2
 
-
 def is_even(number):
     return number % 2 == 0
 
@@ -76,7 +75,6 @@ def __init__(self, *super_args, **super_kwargs):
 
         # pipeline step for logging
         self.log_batch_step_id = -1
-
         self.micro_batch_size = self.train_micro_batch_size_per_gpu()
         self.micro_batches = self.gradient_accumulation_steps()
 
@@ -263,7 +261,7 @@ def _reserve_pipe_buffers(self, num_buffers):
             self.pipe_buffers[key].extend([None] * num_added)
         self.num_pipe_buffers = num_buffers
 
-    def train_batch(self, data_iter=None):
+    def train_batch(self, data_iter=None, layers_to_hook=None):
         """Progress the pipeline to train the next batch of data. The engine will ingest
         ``self.train_batch_size()`` total samples collectively across all workers.
 
@@ -291,9 +289,11 @@ def train_batch(self, data_iter=None):
             raise RuntimeError(
                 f'train_batch() requires gradients enabled. Use eval_batch() instead.')
 
+        if layers_to_hook is not None:
+            self.layers_to_hook = layers_to_hook
+
         if data_iter:
             self.set_dataiterator(data_iter)
-
         self.module.train()
         self.total_loss = None
 
@@ -342,9 +342,13 @@ def train_batch(self, data_iter=None):
         self.timer_values = timer_values
 
         # TODO: should return precisely what loss returned and allow others to be queried?
+
+        if layers_to_hook is not None:
+            self.layers_to_hook = []
+
         return self.agg_train_loss
 
-    def eval_batch(self, data_iter, return_logits=False):
+    def eval_batch(self, data_iter, return_logits=False, layers_to_hook=None):
         """Evaluate the pipeline on a batch of data from ``data_iter``. The
         engine will evaluate ``self.train_batch_size()`` total samples
         collectively across all workers.
@@ -375,6 +379,9 @@ def eval_batch(self, data_iter, return_logits=False):
         self.module.eval()
         self.total_loss = None
 
+        if layers_to_hook is not None:
+            self.layers_to_hook = layers_to_hook
+
         # Use the provided data iterator
         train_iterator = self.data_iterator
         self.set_dataiterator(data_iter)
@@ -399,6 +406,10 @@ def eval_batch(self, data_iter, return_logits=False):
         # Restore the training iterator
         self.set_dataiterator(train_iterator)
 
+        # reset layers to hook to empty
+        if layers_to_hook is not None:
+            self.layers_to_hook = []
+
         # Reset any buffers that may have been populated during the forward passes.
         # ds_checkpointing.reset()
         self.eval_return_logits = False
@@ -408,7 +419,7 @@ def eval_batch(self, data_iter, return_logits=False):
             return self.agg_eval_loss, outputs
         return self.agg_eval_loss
 
-    def inference_batch(self, data_iter):
+    def inference_batch(self, data_iter, layers_to_hook=None):
         """Inference the pipeline on a single batch of data from ``data_iter``.
 
         This method is equivalent to:
@@ -450,6 +461,9 @@ def inference_batch(self, data_iter):
         train_iterator = self.data_iterator
         self.set_dataiterator(data_iter)
 
+        if layers_to_hook is not None:
+            self.layers_to_hook = layers_to_hook
+
         # Do the work
         sched = schedule.InferenceSchedule(micro_batches=self.micro_batches,
                                            stages=self.num_stages,
@@ -528,6 +542,10 @@ def inference_batch(self, data_iter):
         self.set_dataiterator(train_iterator)
         self.set_batch_fn(train_batch_fn)
 
+        # reset layers to hook to empty
+        if layers_to_hook is not None:
+            self.layers_to_hook = []
+
         return logits, presents
 
     def is_first_stage(self):
@@ -661,7 +679,7 @@ def _exec_forward_pass(self, buffer_id):
         self._zero_grads(inputs)
 
         outputs = super().forward(inputs)
-
+
         # Partition the outputs if we are not the last stage
         if self.is_pipe_partitioned and not self.is_last_stage():
             part = PartitionedTensor(tensor=outputs[0],
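A usage sketch for the pipeline engine (illustration only, not part of the commit): layers_to_hook is applied for a single step and reset to [] afterwards, so outputs are only collected for that batch, and on each pipeline stage only the layers local to that stage are hooked. train_iter and eval_iter below are placeholder data iterators.

# Hypothetical usage sketch; iterators are placeholders.
loss = engine.train_batch(data_iter=train_iter, layers_to_hook=[0, 3])
step_outputs = dict(engine.layer_outputs)  # keep a copy of this step's captured outputs

# Evaluation takes the same keyword; passing "all" records every hooked layer.
eval_loss = engine.eval_batch(eval_iter, layers_to_hook="all")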
