2929LOG_STAGE = - 2
3030DATA_PARALLEL_ID = - 2
3131
32-
def is_even(number):
    """Return ``True`` when ``number`` is divisible by 2, else ``False``.

    Relies on Python's ``%`` semantics, so negative integers are handled
    correctly (e.g. ``-4`` is even, ``-3`` is not).
    """
    remainder = number % 2
    return remainder == 0
3534
@@ -76,7 +75,6 @@ def __init__(self, *super_args, **super_kwargs):
7675
7776 # pipeline step for logging
7877 self .log_batch_step_id = - 1
79-
8078 self .micro_batch_size = self .train_micro_batch_size_per_gpu ()
8179 self .micro_batches = self .gradient_accumulation_steps ()
8280
@@ -263,7 +261,7 @@ def _reserve_pipe_buffers(self, num_buffers):
263261 self .pipe_buffers [key ].extend ([None ] * num_added )
264262 self .num_pipe_buffers = num_buffers
265263
266- def train_batch (self , data_iter = None ):
264+ def train_batch (self , data_iter = None , layers_to_hook = None ):
267265 """Progress the pipeline to train the next batch of data. The engine will ingest
268266 ``self.train_batch_size()`` total samples collectively across all workers.
269267
@@ -291,9 +289,11 @@ def train_batch(self, data_iter=None):
291289 raise RuntimeError (
292290 f'train_batch() requires gradients enabled. Use eval_batch() instead.' )
293291
292+ if layers_to_hook is not None :
293+ self .layers_to_hook = layers_to_hook
294+
294295 if data_iter :
295296 self .set_dataiterator (data_iter )
296-
297297 self .module .train ()
298298 self .total_loss = None
299299
@@ -342,9 +342,13 @@ def train_batch(self, data_iter=None):
342342 self .timer_values = timer_values
343343
344344 # TODO: should return precisely what loss returned and allow others to be queried?
345+
346+ if layers_to_hook is not None :
347+ self .layers_to_hook = []
348+
345349 return self .agg_train_loss
346350
347- def eval_batch (self , data_iter , return_logits = False ):
351+ def eval_batch (self , data_iter , return_logits = False , layers_to_hook = None ):
348352 """Evaluate the pipeline on a batch of data from ``data_iter``. The
349353 engine will evaluate ``self.train_batch_size()`` total samples
350354 collectively across all workers.
@@ -375,6 +379,9 @@ def eval_batch(self, data_iter, return_logits=False):
375379 self .module .eval ()
376380 self .total_loss = None
377381
382+ if layers_to_hook is not None :
383+ self .layers_to_hook = layers_to_hook
384+
378385 # Use the provided data iterator
379386 train_iterator = self .data_iterator
380387 self .set_dataiterator (data_iter )
@@ -399,6 +406,10 @@ def eval_batch(self, data_iter, return_logits=False):
399406 # Restore the training iterator
400407 self .set_dataiterator (train_iterator )
401408
409+ # reset layers to hook to empty
410+ if layers_to_hook is not None :
411+ self .layers_to_hook = []
412+
402413 # Reset any buffers that may have been populated during the forward passes.
403414 # ds_checkpointing.reset()
404415 self .eval_return_logits = False
@@ -408,7 +419,7 @@ def eval_batch(self, data_iter, return_logits=False):
408419 return self .agg_eval_loss , outputs
409420 return self .agg_eval_loss
410421
411- def inference_batch (self , data_iter ):
422+ def inference_batch (self , data_iter , layers_to_hook = None ):
412423 """Inference the pipeline on a single batch of data from ``data_iter``.
413424
414425 This method is equivalent to:
@@ -450,6 +461,9 @@ def inference_batch(self, data_iter):
450461 train_iterator = self .data_iterator
451462 self .set_dataiterator (data_iter )
452463
464+ if layers_to_hook is not None :
465+ self .layers_to_hook = layers_to_hook
466+
453467 # Do the work
454468 sched = schedule .InferenceSchedule (micro_batches = self .micro_batches ,
455469 stages = self .num_stages ,
@@ -528,6 +542,10 @@ def inference_batch(self, data_iter):
528542 self .set_dataiterator (train_iterator )
529543 self .set_batch_fn (train_batch_fn )
530544
545+ # reset layers to hook to empty
546+ if layers_to_hook is not None :
547+ self .layers_to_hook = []
548+
531549 return logits , presents
532550
533551 def is_first_stage (self ):
@@ -661,7 +679,7 @@ def _exec_forward_pass(self, buffer_id):
661679 self ._zero_grads (inputs )
662680
663681 outputs = super ().forward (inputs )
664-
682+
665683 # Partition the outputs if we are not the last stage
666684 if self .is_pipe_partitioned and not self .is_last_stage ():
667685 part = PartitionedTensor (tensor = outputs [0 ],
0 commit comments