Skip to content

Commit d9846af

Browse files
committed
refresh fp32 params from fp16 model weights if refreshing from state dict fails
1 parent 95a5f00 commit d9846af

2 files changed

Lines changed: 227 additions & 165 deletions

File tree

deepspeed/runtime/fp16/fused_optimizer.py

Lines changed: 124 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,32 @@
1-
'''
1+
"""
22
Copyright 2019 The Microsoft DeepSpeed Team
33
44
Copyright NVIDIA/apex
55
This file is adapted from FP16_Optimizer in NVIDIA/apex
6-
'''
6+
"""
77

88
import torch
99
import math
1010
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
1111

1212
from deepspeed.runtime.utils import get_grad_norm, CheckOverflow, get_weight_norm
13-
from deepspeed.runtime.fp16.loss_scaler import INITIAL_LOSS_SCALE, SCALE_WINDOW, MIN_LOSS_SCALE
13+
from deepspeed.runtime.fp16.loss_scaler import (
14+
INITIAL_LOSS_SCALE,
15+
SCALE_WINDOW,
16+
MIN_LOSS_SCALE,
17+
)
1418
from deepspeed.utils import logger, log_dist
1519

1620
from ...ops.adam import FusedAdam
21+
1722
FP16_FUSED_SUPPORTED_OPTIMIZERS = [
1823
FusedAdam,
1924
]
2025

2126
# Add apex FusedAdam to supported list if apex is installed
2227
try:
2328
import apex
29+
2430
FP16_FUSED_SUPPORTED_OPTIMIZERS.append(apex.optimizers.FusedAdam)
2531
except ImportError:
2632
pass
@@ -34,34 +40,35 @@ def is_fp16_fused_supported_optimizer(optimizer):
3440
bool: True if ``optimizer`` is compatible with ``FP16_Optimizer``.
3541
"""
3642
from deepspeed.runtime.config import ONEBIT_ADAM_OPTIMIZER
43+
3744
if isinstance(optimizer, tuple(FP16_FUSED_SUPPORTED_OPTIMIZERS)):
3845
return True
3946
if optimizer.__class__.__name__.lower() == ONEBIT_ADAM_OPTIMIZER.lower():
4047
return True
4148
return False
4249

4350

44-
45-
4651
class FP16_Optimizer(object):
4752
"""
4853
FP16 Optimizer for training fp16 models. Handles loss scaling.
4954
5055
For usage example please see, TODO: DeepSpeed V2 Tutorial
5156
"""
5257

53-
def __init__(self,
54-
init_optimizer,
55-
deepspeed=None,
56-
static_loss_scale=1.0,
57-
dynamic_loss_scale=False,
58-
initial_dynamic_scale=2 ** 32,
59-
dynamic_loss_args=None,
60-
verbose=True,
61-
mpu=None,
62-
clip_grad=0.0,
63-
fused_adam_legacy=False,
64-
timers=None):
58+
def __init__(
59+
self,
60+
init_optimizer,
61+
deepspeed=None,
62+
static_loss_scale=1.0,
63+
dynamic_loss_scale=False,
64+
initial_dynamic_scale=2 ** 32,
65+
dynamic_loss_args=None,
66+
verbose=True,
67+
mpu=None,
68+
clip_grad=0.0,
69+
fused_adam_legacy=False,
70+
timers=None,
71+
):
6572

6673
self.fused_adam_legacy = fused_adam_legacy
6774
self.timers = timers
@@ -78,23 +85,28 @@ def __init__(self,
7885
# loop to deal with groups
7986
for i, param_group in enumerate(self.optimizer.param_groups):
8087
# push this group to list before modify
81-
self.fp16_groups.append(param_group['params'])
88+
self.fp16_groups.append(param_group["params"])
8289
# init fp16 weight buffer, flattened
8390
self.fp16_groups_flat.append(
84-
_flatten_dense_tensors([p.clone().detach()
85-
for p in self.fp16_groups[i]]))
91+
_flatten_dense_tensors(
92+
[p.clone().detach() for p in self.fp16_groups[i]]
93+
)
94+
)
8695
# set model fp16 weight to slices of flattened buffer
87-
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
88-
self.fp16_groups[i])
96+
updated_params = _unflatten_dense_tensors(
97+
self.fp16_groups_flat[i], self.fp16_groups[i]
98+
)
8999
for p, q in zip(self.fp16_groups[i], updated_params):
90100
p.data = q.data
91101
# init master weight, flattened
92102
self.fp32_groups_flat.append(
93-
self.fp16_groups_flat[i].clone().float().detach())
103+
self.fp16_groups_flat[i].clone().float().detach()
104+
)
94105
# modify optimizer to have flat master weight
95106
self.fp32_groups_flat[
96-
i].requires_grad = True # keep this in case internal optimizer uses it
97-
param_group['params'] = [self.fp32_groups_flat[i]]
107+
i
108+
].requires_grad = True # keep this in case internal optimizer uses it
109+
param_group["params"] = [self.fp32_groups_flat[i]]
98110

99111
# we may have a way of fusing dynamic scale. Do not support for now
100112
if dynamic_loss_scale:
@@ -120,8 +132,8 @@ def __init__(self,
120132
self.clip_grad = clip_grad
121133
self.norm_type = 2
122134

123-
TORCH_MAJOR = int(torch.__version__.split('.')[0])
124-
TORCH_MINOR = int(torch.__version__.split('.')[1])
135+
TORCH_MAJOR = int(torch.__version__.split(".")[0])
136+
TORCH_MINOR = int(torch.__version__.split(".")[1])
125137
if TORCH_MAJOR == 0 and TORCH_MINOR <= 4:
126138
self.clip_grad_norm = torch.nn.utils.clip_grad_norm
127139
else:
@@ -131,16 +143,16 @@ def __init__(self,
131143
self.mpu = mpu
132144

133145
self.overflow = False
134-
self.overflow_checker = CheckOverflow(self.fp16_groups,
135-
mpu=self.mpu,
136-
deepspeed=deepspeed)
146+
self.overflow_checker = CheckOverflow(
147+
self.fp16_groups, mpu=self.mpu, deepspeed=deepspeed
148+
)
137149
self.initialize_optimizer_states()
138150

139151
def initialize_optimizer_states(self):
140152
for i, group in enumerate(self.fp16_groups):
141153
self.fp32_groups_flat[i].grad = torch.zeros(
142-
self.fp32_groups_flat[i].size(),
143-
device=self.fp32_groups_flat[i].device)
154+
self.fp32_groups_flat[i].size(), device=self.fp32_groups_flat[i].device
155+
)
144156

145157
self.optimizer.step()
146158

@@ -172,12 +184,15 @@ def step_fused_adam(self, closure=None):
172184
norm_groups = []
173185
for i, group in enumerate(self.fp16_groups):
174186
grads_groups_flat.append(
175-
_flatten_dense_tensors([
176-
torch.zeros(p.size(),
177-
dtype=p.dtype,
178-
device=p.device) if p.grad is None else p.grad
179-
for p in group
180-
]))
187+
_flatten_dense_tensors(
188+
[
189+
torch.zeros(p.size(), dtype=p.dtype, device=p.device)
190+
if p.grad is None
191+
else p.grad
192+
for p in group
193+
]
194+
)
195+
)
181196
norm_groups.append(get_weight_norm(grads_groups_flat[i], mpu=self.mpu))
182197

183198
self.overflow = self.overflow_checker.check_using_norm(norm_groups)
@@ -186,23 +201,26 @@ def step_fused_adam(self, closure=None):
186201

187202
if self.overflow:
188203
if self.verbose:
189-
logger.info("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
190-
"scale: {}, reducing to {}".format(
191-
prev_scale,
192-
self.cur_scale))
204+
logger.info(
205+
"[deepspeed] OVERFLOW! Skipping step. Attempted loss "
206+
"scale: {}, reducing to {}".format(prev_scale, self.cur_scale)
207+
)
193208
return self.overflow
194-
combined_scale = self.unscale_and_clip_grads(grads_groups_flat,
195-
norm_groups,
196-
apply_scale=False)
209+
combined_scale = self.unscale_and_clip_grads(
210+
grads_groups_flat, norm_groups, apply_scale=False
211+
)
197212
# norm is in fact norm*cur_scale
198-
self.optimizer.step(grads=[[g] for g in grads_groups_flat],
199-
output_params=[[p] for p in self.fp16_groups_flat],
200-
scale=combined_scale,
201-
grad_norms=norm_groups)
213+
self.optimizer.step(
214+
grads=[[g] for g in grads_groups_flat],
215+
output_params=[[p] for p in self.fp16_groups_flat],
216+
scale=combined_scale,
217+
grad_norms=norm_groups,
218+
)
202219
# TODO: we probably don't need this? just to be safe
203220
for i in range(len(norm_groups)):
204-
updated_params = _unflatten_dense_tensors(self.fp16_groups_flat[i],
205-
self.fp16_groups[i])
221+
updated_params = _unflatten_dense_tensors(
222+
self.fp16_groups_flat[i], self.fp16_groups[i]
223+
)
206224
for p, q in zip(self.fp16_groups[i], updated_params):
207225
p.data = q.data
208226
return self.overflow
@@ -230,11 +248,11 @@ def step(self, closure=None):
230248
return self.step_fused_adam()
231249

232250
COMPUTE_NORM = "compute_norm"
233-
OVERFLOW_CHECK = 'overflow_check'
251+
OVERFLOW_CHECK = "overflow_check"
234252
OVERFLOW_TIMERS = [COMPUTE_NORM, OVERFLOW_CHECK]
235-
UNSCALE_AND_CLIP = 'unscale_and_clip'
236-
BASIC_STEP = 'basic_step'
237-
UPDATE_FP16 = 'update_fp16'
253+
UNSCALE_AND_CLIP = "unscale_and_clip"
254+
BASIC_STEP = "basic_step"
255+
UPDATE_FP16 = "update_fp16"
238256
STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP, BASIC_STEP, UPDATE_FP16]
239257

240258
# First determine if there is overflow.
@@ -252,7 +270,8 @@ def step(self, closure=None):
252270
log_dist(
253271
"Overflow detected. Skipping step. Attempted loss "
254272
f"scale: {prev_scale}, reducing to {self.cur_scale}",
255-
ranks=[0])
273+
ranks=[0],
274+
)
256275
# Clear gradients
257276
for i, group in enumerate(self.fp16_groups):
258277
for p in group:
@@ -266,12 +285,15 @@ def step(self, closure=None):
266285
data_type = self.fp32_groups_flat[i].dtype
267286

268287
grads_groups_flat.append(
269-
_flatten_dense_tensors([
270-
torch.zeros(p.size(),
271-
dtype=data_type,
272-
device=p.device)
273-
if p.grad is None else p.grad.to(data_type) for p in group
274-
]))
288+
_flatten_dense_tensors(
289+
[
290+
torch.zeros(p.size(), dtype=data_type, device=p.device)
291+
if p.grad is None
292+
else p.grad.to(data_type)
293+
for p in group
294+
]
295+
)
296+
)
275297

276298
for p in group:
277299
p.grad = None
@@ -296,8 +318,9 @@ def step(self, closure=None):
296318

297319
self.start_timers([UPDATE_FP16])
298320
for i in range(len(self.fp16_groups)):
299-
updated_params = _unflatten_dense_tensors(self.fp32_groups_flat[i],
300-
self.fp16_groups[i])
321+
updated_params = _unflatten_dense_tensors(
322+
self.fp32_groups_flat[i], self.fp16_groups[i]
323+
)
301324
for p, q in zip(self.fp16_groups[i], updated_params):
302325
p.data.copy_(q.data)
303326
self.stop_timers([UPDATE_FP16])
@@ -314,15 +337,15 @@ def unscale_and_clip_grads(self, grad_groups_flat, norm_groups, apply_scale=True
314337

315338
# compute combined scale factor for this group
316339
combined_scale = self.cur_scale
317-
if self.clip_grad > 0.:
340+
if self.clip_grad > 0.0:
318341
# norm is in fact norm*scale
319342
clip = ((total_norm / self.cur_scale) + 1e-6) / self.clip_grad
320343
if clip > 1:
321344
combined_scale = clip * self.cur_scale
322345

323346
if apply_scale:
324347
for grad in grad_groups_flat:
325-
grad.data.mul_(1. / combined_scale)
348+
grad.data.mul_(1.0 / combined_scale)
326349

327350
return combined_scale
328351

@@ -341,8 +364,9 @@ def _update_scale(self, skip):
341364
if self.dynamic_loss_scale:
342365
prev_scale = self.cur_scale
343366
if skip:
344-
self.cur_scale = max(self.cur_scale / self.scale_factor,
345-
self.min_loss_scale)
367+
self.cur_scale = max(
368+
self.cur_scale / self.scale_factor, self.min_loss_scale
369+
)
346370
self.last_overflow_iter = self.cur_iter
347371
if self.verbose:
348372
logger.info(f"\nGrad overflow on iteration {self.cur_iter}")
@@ -356,7 +380,8 @@ def _update_scale(self, skip):
356380
self.cur_scale *= self.scale_factor
357381
if self.verbose:
358382
logger.info(
359-
f"No Grad overflow for {self.scale_window} iterations")
383+
f"No Grad overflow for {self.scale_window} iterations"
384+
)
360385
logger.info(
361386
f"Increasing dynamic loss scale from {prev_scale} to {self.cur_scale}"
362387
)
@@ -398,16 +423,16 @@ def state_dict(self):
398423
torch.save(checkpoint, "saved.pth")
399424
"""
400425
state_dict = {}
401-
state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
402-
state_dict['cur_scale'] = self.cur_scale
403-
state_dict['cur_iter'] = self.cur_iter
404-
if state_dict['dynamic_loss_scale']:
405-
state_dict['last_overflow_iter'] = self.last_overflow_iter
406-
state_dict['scale_factor'] = self.scale_factor
407-
state_dict['scale_window'] = self.scale_window
408-
state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
409-
state_dict['fp32_groups_flat'] = self.fp32_groups_flat
410-
state_dict['clip_grad'] = self.clip_grad
426+
state_dict["dynamic_loss_scale"] = self.dynamic_loss_scale
427+
state_dict["cur_scale"] = self.cur_scale
428+
state_dict["cur_iter"] = self.cur_iter
429+
if state_dict["dynamic_loss_scale"]:
430+
state_dict["last_overflow_iter"] = self.last_overflow_iter
431+
state_dict["scale_factor"] = self.scale_factor
432+
state_dict["scale_window"] = self.scale_window
433+
state_dict["optimizer_state_dict"] = self.optimizer.state_dict()
434+
state_dict["fp32_groups_flat"] = self.fp32_groups_flat
435+
state_dict["clip_grad"] = self.clip_grad
411436
return state_dict
412437

413438
# Refresh fp32 master params from fp16 copies
@@ -432,16 +457,16 @@ def load_state_dict(self, state_dict, load_optimizer_states=True):
432457
optimizer.load_state_dict(checkpoint['optimizer'])
433458
"""
434459
# I think it should actually be ok to reload the optimizer before the model.
435-
self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
436-
self.cur_scale = state_dict['cur_scale']
437-
self.cur_iter = state_dict['cur_iter']
438-
if state_dict['dynamic_loss_scale']:
439-
self.last_overflow_iter = state_dict['last_overflow_iter']
440-
self.scale_factor = state_dict['scale_factor']
441-
self.scale_window = state_dict['scale_window']
460+
self.dynamic_loss_scale = state_dict["dynamic_loss_scale"]
461+
self.cur_scale = state_dict["cur_scale"]
462+
self.cur_iter = state_dict["cur_iter"]
463+
if state_dict["dynamic_loss_scale"]:
464+
self.last_overflow_iter = state_dict["last_overflow_iter"]
465+
self.scale_factor = state_dict["scale_factor"]
466+
self.scale_window = state_dict["scale_window"]
442467
if load_optimizer_states:
443-
self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
444-
self.clip_grad = state_dict['clip_grad']
468+
self.optimizer.load_state_dict(state_dict["optimizer_state_dict"])
469+
self.clip_grad = state_dict["clip_grad"]
445470
# At this point, the optimizer's references to the model's fp32 parameters are up to date.
446471
# The optimizer's hyperparameters and internal buffers are also up to date.
447472
# However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
@@ -456,8 +481,17 @@ def load_state_dict(self, state_dict, load_optimizer_states=True):
456481
# the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
457482
# constructed in the same way as the one whose state_dict we are loading, the same master params
458483
# are guaranteed to exist, so we can just copy_() from the saved master params.
459-
for current, saved in zip(self.fp32_groups_flat, state_dict['fp32_groups_flat']):
460-
current.data.copy_(saved.data)
484+
try:
485+
for current, saved in zip(
486+
self.fp32_groups_flat, state_dict["fp32_groups_flat"]
487+
):
488+
current.data.copy_(saved.data)
489+
except RuntimeError as error:
490+
print(error)
491+
print(
492+
"Error in loading fp32 model parameters!\nRefreshing fp32 model params from the model's fp16 params instead. This may incur some precision loss."
493+
)
494+
self.refresh_fp32_params()
461495

462496
def __repr__(self):
463497
return repr(self.optimizer)

0 commit comments

Comments
 (0)