Skip to content

Commit 69a3c6c

Browse files
authored
Merge pull request #1 from igor0/igor
Igor
2 parents eb7f5cf + 3987139 commit 69a3c6c

2 files changed

Lines changed: 8 additions & 1 deletion

File tree

deepspeed/runtime/engine.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1596,6 +1596,7 @@ def _load_checkpoint(self,
15961596
def _load_zero_checkpoint(self, load_dir, tag, load_optimizer_states=True):
15971597
zero_sd_list = self._get_all_zero_checkpoints(load_dir, tag)
15981598
if zero_sd_list is None:
1599+
self.optimizer._restore_from_fp16_weights()
15991600
return
16001601

16011602
self.optimizer.load_state_dict(

deepspeed/runtime/fp16/fused_optimizer.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,12 @@ def load_state_dict(self, state_dict, load_optimizer_states=True):
456456
model.load_state_dict(checkpoint['model'])
457457
optimizer.load_state_dict(checkpoint['optimizer'])
458458
"""
459+
460+
if state_dict is None or "dynamic_loss_scale" not in state_dict:
461+
state_dict = self.state_dict()
462+
self.refresh_fp32_params()
463+
return
464+
459465
# I think it should actually be ok to reload the optimizer before the model.
460466
self.dynamic_loss_scale = state_dict["dynamic_loss_scale"]
461467
self.cur_scale = state_dict["cur_scale"]
@@ -486,7 +492,7 @@ def load_state_dict(self, state_dict, load_optimizer_states=True):
486492
self.fp32_groups_flat, state_dict["fp32_groups_flat"]
487493
):
488494
current.data.copy_(saved.data)
489-
except RuntimeError as error:
495+
except (RuntimeError, KeyError) as error:
490496
print(error)
491497
print(
492498
"Error in loading fp32 model parameters!\nRefreshing fp32 model params from the model's fp16 params instead. This may incur some precision loss."

0 commit comments

Comments
 (0)