@@ -65,6 +65,8 @@ def __init__(self, *super_args, **super_kwargs):
 
         # We schedule the all-reduces, so disable it in super().backward()
         self.enable_backward_allreduce = False
+        self.eval_return_logits = False
+        self.outputs = None
 
         # used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB
         self.pipeline_enable_backward_allreduce = True
@@ -336,7 +338,7 @@ def train_batch(self, data_iter=None):
         # TODO: should return precisely what loss returned and allow others to be queried?
         return self.agg_train_loss
 
-    def eval_batch(self, data_iter):
+    def eval_batch(self, data_iter, return_logits=False):
         """Evaluate the pipeline on a batch of data from ``data_iter``. The
         engine will evaluate ``self.train_batch_size()`` total samples
         collectively across all workers.
@@ -363,7 +365,7 @@ def eval_batch(self, data_iter):
         Returns:
             The arithmetic mean of the losses computed this batch.
         """
-
+        self.eval_return_logits = return_logits
         self.module.eval()
         self.total_loss = None
@@ -393,7 +395,11 @@ def eval_batch(self, data_iter):
 
         # Reset any buffers that may have been populated during the forward passes.
         # ds_checkpointing.reset()
-
+        self.eval_return_logits = False
+        if return_logits:
+            outputs = self.outputs
+            self.outputs = None
+            return self.agg_eval_loss, outputs
         return self.agg_eval_loss
 
     def inference_batch(self, data_iter):
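With this change, ``eval_batch`` keeps its single-value return by default and only returns a ``(loss, outputs)`` pair when asked. A minimal usage sketch, assuming ``engine`` is an already-initialized DeepSpeed pipeline engine and ``eval_iter`` yields batches in the format the pipeline expects (both names are illustrative, not part of this commit):

```python
# Illustrative only: `engine` and `eval_iter` are assumed to exist already.
loss = engine.eval_batch(eval_iter)                              # unchanged default
loss, logits = engine.eval_batch(eval_iter, return_logits=True)  # new: also get outputs
```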
@@ -666,6 +672,8 @@ def _exec_forward_pass(self, buffer_id):
             else:
                 # Some models just return loss from forward()
                 self.loss = outputs
+            if self.eval_return_logits:
+                self.outputs = outputs
 
             if isinstance(self.loss, torch.Tensor):
                 if self.total_loss is None:
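A design note on the round trip above: ``_exec_forward_pass`` only stashes the raw forward outputs on the engine when ``eval_return_logits`` is set, and ``eval_batch`` clears both the flag and ``self.outputs`` before returning, so activations from one evaluation batch are neither retained nor leaked into the next unless the caller opted in. A condensed, self-contained sketch of that flag-stash-clear pattern (a toy model, not the DeepSpeed API):

```python
class StashSketch:
    """Toy model of the pattern added in this commit (not DeepSpeed code)."""

    def __init__(self):
        self.eval_return_logits = False
        self.outputs = None

    def _forward(self, outputs):
        # Stash outputs only when the current eval asked for them.
        if self.eval_return_logits:
            self.outputs = outputs
        return outputs  # stand-in for the loss computation

    def eval_batch(self, batch, return_logits=False):
        self.eval_return_logits = return_logits
        loss = self._forward(batch)
        self.eval_return_logits = False  # reset so later batches don't stash
        if return_logits:
            outputs, self.outputs = self.outputs, None  # hand back and free
            return loss, outputs
        return loss
```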