Skip to content

Commit a05738e

Browse files
committed
Fix fine-tuning of FP16_Optimizer
1 parent eb7f5cf commit a05738e

1 file changed

Lines changed: 6 additions & 0 deletions

File tree

deepspeed/runtime/fp16/fused_optimizer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,12 @@ def load_state_dict(self, state_dict, load_optimizer_states=True):
456456
model.load_state_dict(checkpoint['model'])
457457
optimizer.load_state_dict(checkpoint['optimizer'])
458458
"""
459+
460+
if state_dict is None:
461+
state_dict = self.state_dict()
462+
self.refresh_fp32_params()
463+
return
464+
459465
# I think it should actually be ok to reload the optimizer before the model.
460466
self.dynamic_loss_scale = state_dict["dynamic_loss_scale"]
461467
self.cur_scale = state_dict["cur_scale"]

0 commit comments

Comments
 (0)