Skip to content

Commit eb7f5cf

Browse files
committed
fix bug with inputs requiring gradient
1 parent d9846af commit eb7f5cf

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

deepspeed/runtime/pipe/engine.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -813,6 +813,8 @@ def _exec_load_micro_batch(self, buffer_id):
813 813
for x in batch[1]:
814 814
assert torch.is_tensor(x)
815 815
x = x.to(self.device).detach()
816+
if self.is_first_stage(): # first stage can also be last stage - need to ensure the inputs require grad
817+
x.requires_grad = x.is_floating_point()
816 818
loaded.append(x)
817 819
loaded = tuple(loaded)
818 820

0 commit comments

Comments
 (0)