Skip to content

Commit 3389e4f

Browse files
committed
add ability to return logits in eval
1 parent c7c2063 commit 3389e4f

1 file changed

Lines changed: 11 additions & 3 deletions

File tree

deepspeed/runtime/pipe/engine.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@ def __init__(self, *super_args, **super_kwargs):
6565

6666
# We schedule the all-reduces, so disable it in super().backward()
6767
self.enable_backward_allreduce = False
68+
self.eval_return_logits = False
69+
self.outputs = None
6870

6971
# used to disable the pipeline all-reduce when used with 1-bit Adam/1-bit LAMB
7072
self.pipeline_enable_backward_allreduce = True
@@ -336,7 +338,7 @@ def train_batch(self, data_iter=None):
336338
# TODO: should return precisely what loss returned and allow others to be queried?
337339
return self.agg_train_loss
338340

339-
def eval_batch(self, data_iter):
341+
def eval_batch(self, data_iter, return_logits=False):
340342
"""Evaluate the pipeline on a batch of data from ``data_iter``. The
341343
engine will evaluate ``self.train_batch_size()`` total samples
342344
collectively across all workers.
@@ -363,7 +365,7 @@ def eval_batch(self, data_iter):
363365
Returns:
364366
The arithmetic mean of the losses computed this batch.
365367
"""
366-
368+
self.eval_return_logits = return_logits
367369
self.module.eval()
368370
self.total_loss = None
369371

@@ -393,7 +395,11 @@ def eval_batch(self, data_iter):
393395

394396
# Reset any buffers that may have been populated during the forward passes.
395397
# ds_checkpointing.reset()
396-
398+
self.eval_return_logits = False
399+
if return_logits:
400+
outputs = self.outputs
401+
self.outputs = None
402+
return self.agg_eval_loss, outputs
397403
return self.agg_eval_loss
398404

399405
def inference_batch(self, data_iter):
@@ -666,6 +672,8 @@ def _exec_forward_pass(self, buffer_id):
666672
else:
667673
# Some models just return loss from forward()
668674
self.loss = outputs
675+
if self.eval_return_logits:
676+
self.outputs = outputs
669677

670678
if isinstance(self.loss, torch.Tensor):
671679
if self.total_loss is None:

0 commit comments

Comments
 (0)