We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent d9846af, commit eb7f5cf (Copy full SHA for eb7f5cf)
1 file changed
deepspeed/runtime/pipe/engine.py
@@ -813,6 +813,8 @@ def _exec_load_micro_batch(self, buffer_id):
813
for x in batch[1]:
814
assert torch.is_tensor(x)
815
x = x.to(self.device).detach()
816
+ if self.is_first_stage(): # first stage can also be last stage - need to ensure the inputs require grad
817
+ x.requires_grad = x.is_floating_point()
818
loaded.append(x)
819
loaded = tuple(loaded)
820
0 commit comments