1- '''
1+ """
22Copyright 2019 The Microsoft DeepSpeed Team
33
44Copyright NVIDIA/apex
55This file is adapted from FP16_Optimizer in NVIDIA/apex
6- '''
6+ """
77
88import torch
99import math
1010from torch ._utils import _flatten_dense_tensors , _unflatten_dense_tensors
1111
1212from deepspeed .runtime .utils import get_grad_norm , CheckOverflow , get_weight_norm
13- from deepspeed .runtime .fp16 .loss_scaler import INITIAL_LOSS_SCALE , SCALE_WINDOW , MIN_LOSS_SCALE
13+ from deepspeed .runtime .fp16 .loss_scaler import (
14+ INITIAL_LOSS_SCALE ,
15+ SCALE_WINDOW ,
16+ MIN_LOSS_SCALE ,
17+ )
1418from deepspeed .utils import logger , log_dist
1519
1620from ...ops .adam import FusedAdam
21+
1722FP16_FUSED_SUPPORTED_OPTIMIZERS = [
1823 FusedAdam ,
1924]
2025
2126# Add apex FusedAdam to supported list if apex is installed
2227try :
2328 import apex
29+
2430 FP16_FUSED_SUPPORTED_OPTIMIZERS .append (apex .optimizers .FusedAdam )
2531except ImportError :
2632 pass
@@ -34,34 +40,35 @@ def is_fp16_fused_supported_optimizer(optimizer):
3440 bool: True if ``optimizer`` is compatible with ``FP16_Optimizer``.
3541 """
3642 from deepspeed .runtime .config import ONEBIT_ADAM_OPTIMIZER
43+
3744 if isinstance (optimizer , tuple (FP16_FUSED_SUPPORTED_OPTIMIZERS )):
3845 return True
3946 if optimizer .__class__ .__name__ .lower () == ONEBIT_ADAM_OPTIMIZER .lower ():
4047 return True
4148 return False
4249
4350
44-
45-
4651class FP16_Optimizer (object ):
4752 """
4853 FP16 Optimizer for training fp16 models. Handles loss scaling.
4954
5055 For usage example please see, TODO: DeepSpeed V2 Tutorial
5156 """
5257
53- def __init__ (self ,
54- init_optimizer ,
55- deepspeed = None ,
56- static_loss_scale = 1.0 ,
57- dynamic_loss_scale = False ,
58- initial_dynamic_scale = 2 ** 32 ,
59- dynamic_loss_args = None ,
60- verbose = True ,
61- mpu = None ,
62- clip_grad = 0.0 ,
63- fused_adam_legacy = False ,
64- timers = None ):
58+ def __init__ (
59+ self ,
60+ init_optimizer ,
61+ deepspeed = None ,
62+ static_loss_scale = 1.0 ,
63+ dynamic_loss_scale = False ,
64+ initial_dynamic_scale = 2 ** 32 ,
65+ dynamic_loss_args = None ,
66+ verbose = True ,
67+ mpu = None ,
68+ clip_grad = 0.0 ,
69+ fused_adam_legacy = False ,
70+ timers = None ,
71+ ):
6572
6673 self .fused_adam_legacy = fused_adam_legacy
6774 self .timers = timers
@@ -78,23 +85,28 @@ def __init__(self,
7885 # loop to deal with groups
7986 for i , param_group in enumerate (self .optimizer .param_groups ):
8087 # push this group to list before modify
81- self .fp16_groups .append (param_group [' params' ])
88+ self .fp16_groups .append (param_group [" params" ])
8289 # init fp16 weight buffer, flattened
8390 self .fp16_groups_flat .append (
84- _flatten_dense_tensors ([p .clone ().detach ()
85- for p in self .fp16_groups [i ]]))
91+ _flatten_dense_tensors (
92+ [p .clone ().detach () for p in self .fp16_groups [i ]]
93+ )
94+ )
8695 # set model fp16 weight to slices of flattened buffer
87- updated_params = _unflatten_dense_tensors (self .fp16_groups_flat [i ],
88- self .fp16_groups [i ])
96+ updated_params = _unflatten_dense_tensors (
97+ self .fp16_groups_flat [i ], self .fp16_groups [i ]
98+ )
8999 for p , q in zip (self .fp16_groups [i ], updated_params ):
90100 p .data = q .data
91101 # init master weight, flattened
92102 self .fp32_groups_flat .append (
93- self .fp16_groups_flat [i ].clone ().float ().detach ())
103+ self .fp16_groups_flat [i ].clone ().float ().detach ()
104+ )
94105 # modify optimizer of have flat master weight
95106 self .fp32_groups_flat [
96- i ].requires_grad = True # keep this in case internal optimizer uses it
97- param_group ['params' ] = [self .fp32_groups_flat [i ]]
107+ i
108+ ].requires_grad = True # keep this in case internal optimizer uses it
109+ param_group ["params" ] = [self .fp32_groups_flat [i ]]
98110
99111 # we may have a way of fusing dynamic scale. Do not support for now
100112 if dynamic_loss_scale :
@@ -120,8 +132,8 @@ def __init__(self,
120132 self .clip_grad = clip_grad
121133 self .norm_type = 2
122134
123- TORCH_MAJOR = int (torch .__version__ .split ('.' )[0 ])
124- TORCH_MINOR = int (torch .__version__ .split ('.' )[1 ])
135+ TORCH_MAJOR = int (torch .__version__ .split ("." )[0 ])
136+ TORCH_MINOR = int (torch .__version__ .split ("." )[1 ])
125137 if TORCH_MAJOR == 0 and TORCH_MINOR <= 4 :
126138 self .clip_grad_norm = torch .nn .utils .clip_grad_norm
127139 else :
@@ -131,16 +143,16 @@ def __init__(self,
131143 self .mpu = mpu
132144
133145 self .overflow = False
134- self .overflow_checker = CheckOverflow (self . fp16_groups ,
135- mpu = self .mpu ,
136- deepspeed = deepspeed )
146+ self .overflow_checker = CheckOverflow (
147+ self . fp16_groups , mpu = self .mpu , deepspeed = deepspeed
148+ )
137149 self .initialize_optimizer_states ()
138150
139151 def initialize_optimizer_states (self ):
140152 for i , group in enumerate (self .fp16_groups ):
141153 self .fp32_groups_flat [i ].grad = torch .zeros (
142- self .fp32_groups_flat [i ].size (),
143- device = self . fp32_groups_flat [ i ]. device )
154+ self .fp32_groups_flat [i ].size (), device = self . fp32_groups_flat [ i ]. device
155+ )
144156
145157 self .optimizer .step ()
146158
@@ -172,12 +184,15 @@ def step_fused_adam(self, closure=None):
172184 norm_groups = []
173185 for i , group in enumerate (self .fp16_groups ):
174186 grads_groups_flat .append (
175- _flatten_dense_tensors ([
176- torch .zeros (p .size (),
177- dtype = p .dtype ,
178- device = p .device ) if p .grad is None else p .grad
179- for p in group
180- ]))
187+ _flatten_dense_tensors (
188+ [
189+ torch .zeros (p .size (), dtype = p .dtype , device = p .device )
190+ if p .grad is None
191+ else p .grad
192+ for p in group
193+ ]
194+ )
195+ )
181196 norm_groups .append (get_weight_norm (grads_groups_flat [i ], mpu = self .mpu ))
182197
183198 self .overflow = self .overflow_checker .check_using_norm (norm_groups )
@@ -186,23 +201,26 @@ def step_fused_adam(self, closure=None):
186201
187202 if self .overflow :
188203 if self .verbose :
189- logger .info ("[deepspeed] OVERFLOW! Skipping step. Attempted loss "
190- "scale: {}, reducing to {}" . format (
191- prev_scale ,
192- self . cur_scale ) )
204+ logger .info (
205+ "[deepspeed] OVERFLOW! Skipping step. Attempted loss "
206+ "scale: {}, reducing to {}" . format ( prev_scale , self . cur_scale )
207+ )
193208 return self .overflow
194- combined_scale = self .unscale_and_clip_grads (grads_groups_flat ,
195- norm_groups ,
196- apply_scale = False )
209+ combined_scale = self .unscale_and_clip_grads (
210+ grads_groups_flat , norm_groups , apply_scale = False
211+ )
197212 # norm is in fact norm*cur_scale
198- self .optimizer .step (grads = [[g ] for g in grads_groups_flat ],
199- output_params = [[p ] for p in self .fp16_groups_flat ],
200- scale = combined_scale ,
201- grad_norms = norm_groups )
213+ self .optimizer .step (
214+ grads = [[g ] for g in grads_groups_flat ],
215+ output_params = [[p ] for p in self .fp16_groups_flat ],
216+ scale = combined_scale ,
217+ grad_norms = norm_groups ,
218+ )
202219 # TODO: we probably don't need this? just to be safe
203220 for i in range (len (norm_groups )):
204- updated_params = _unflatten_dense_tensors (self .fp16_groups_flat [i ],
205- self .fp16_groups [i ])
221+ updated_params = _unflatten_dense_tensors (
222+ self .fp16_groups_flat [i ], self .fp16_groups [i ]
223+ )
206224 for p , q in zip (self .fp16_groups [i ], updated_params ):
207225 p .data = q .data
208226 return self .overflow
@@ -230,11 +248,11 @@ def step(self, closure=None):
230248 return self .step_fused_adam ()
231249
232250 COMPUTE_NORM = "compute_norm"
233- OVERFLOW_CHECK = ' overflow_check'
251+ OVERFLOW_CHECK = " overflow_check"
234252 OVERFLOW_TIMERS = [COMPUTE_NORM , OVERFLOW_CHECK ]
235- UNSCALE_AND_CLIP = ' unscale_and_clip'
236- BASIC_STEP = ' basic_step'
237- UPDATE_FP16 = ' update_fp16'
253+ UNSCALE_AND_CLIP = " unscale_and_clip"
254+ BASIC_STEP = " basic_step"
255+ UPDATE_FP16 = " update_fp16"
238256 STEP_TIMERS = OVERFLOW_TIMERS + [UNSCALE_AND_CLIP , BASIC_STEP , UPDATE_FP16 ]
239257
240258 # First determine if there is overflow.
@@ -252,7 +270,8 @@ def step(self, closure=None):
252270 log_dist (
253271 "Overflow detected. Skipping step. Attempted loss "
254272 f"scale: { prev_scale } , reducing to { self .cur_scale } " ,
255- ranks = [0 ])
273+ ranks = [0 ],
274+ )
256275 # Clear gradients
257276 for i , group in enumerate (self .fp16_groups ):
258277 for p in group :
@@ -266,12 +285,15 @@ def step(self, closure=None):
266285 data_type = self .fp32_groups_flat [i ].dtype
267286
268287 grads_groups_flat .append (
269- _flatten_dense_tensors ([
270- torch .zeros (p .size (),
271- dtype = data_type ,
272- device = p .device )
273- if p .grad is None else p .grad .to (data_type ) for p in group
274- ]))
288+ _flatten_dense_tensors (
289+ [
290+ torch .zeros (p .size (), dtype = data_type , device = p .device )
291+ if p .grad is None
292+ else p .grad .to (data_type )
293+ for p in group
294+ ]
295+ )
296+ )
275297
276298 for p in group :
277299 p .grad = None
@@ -296,8 +318,9 @@ def step(self, closure=None):
296318
297319 self .start_timers ([UPDATE_FP16 ])
298320 for i in range (len (self .fp16_groups )):
299- updated_params = _unflatten_dense_tensors (self .fp32_groups_flat [i ],
300- self .fp16_groups [i ])
321+ updated_params = _unflatten_dense_tensors (
322+ self .fp32_groups_flat [i ], self .fp16_groups [i ]
323+ )
301324 for p , q in zip (self .fp16_groups [i ], updated_params ):
302325 p .data .copy_ (q .data )
303326 self .stop_timers ([UPDATE_FP16 ])
@@ -314,15 +337,15 @@ def unscale_and_clip_grads(self, grad_groups_flat, norm_groups, apply_scale=True
314337
315338 # compute combined scale factor for this group
316339 combined_scale = self .cur_scale
317- if self .clip_grad > 0. :
340+ if self .clip_grad > 0.0 :
318341 # norm is in fact norm*scale
319342 clip = ((total_norm / self .cur_scale ) + 1e-6 ) / self .clip_grad
320343 if clip > 1 :
321344 combined_scale = clip * self .cur_scale
322345
323346 if apply_scale :
324347 for grad in grad_groups_flat :
325- grad .data .mul_ (1. / combined_scale )
348+ grad .data .mul_ (1.0 / combined_scale )
326349
327350 return combined_scale
328351
@@ -341,8 +364,9 @@ def _update_scale(self, skip):
341364 if self .dynamic_loss_scale :
342365 prev_scale = self .cur_scale
343366 if skip :
344- self .cur_scale = max (self .cur_scale / self .scale_factor ,
345- self .min_loss_scale )
367+ self .cur_scale = max (
368+ self .cur_scale / self .scale_factor , self .min_loss_scale
369+ )
346370 self .last_overflow_iter = self .cur_iter
347371 if self .verbose :
348372 logger .info (f"\n Grad overflow on iteration { self .cur_iter } " )
@@ -356,7 +380,8 @@ def _update_scale(self, skip):
356380 self .cur_scale *= self .scale_factor
357381 if self .verbose :
358382 logger .info (
359- f"No Grad overflow for { self .scale_window } iterations" )
383+ f"No Grad overflow for { self .scale_window } iterations"
384+ )
360385 logger .info (
361386 f"Increasing dynamic loss scale from { prev_scale } to { self .cur_scale } "
362387 )
@@ -398,16 +423,16 @@ def state_dict(self):
398423 torch.save(checkpoint, "saved.pth")
399424 """
400425 state_dict = {}
401- state_dict [' dynamic_loss_scale' ] = self .dynamic_loss_scale
402- state_dict [' cur_scale' ] = self .cur_scale
403- state_dict [' cur_iter' ] = self .cur_iter
404- if state_dict [' dynamic_loss_scale' ]:
405- state_dict [' last_overflow_iter' ] = self .last_overflow_iter
406- state_dict [' scale_factor' ] = self .scale_factor
407- state_dict [' scale_window' ] = self .scale_window
408- state_dict [' optimizer_state_dict' ] = self .optimizer .state_dict ()
409- state_dict [' fp32_groups_flat' ] = self .fp32_groups_flat
410- state_dict [' clip_grad' ] = self .clip_grad
426+ state_dict [" dynamic_loss_scale" ] = self .dynamic_loss_scale
427+ state_dict [" cur_scale" ] = self .cur_scale
428+ state_dict [" cur_iter" ] = self .cur_iter
429+ if state_dict [" dynamic_loss_scale" ]:
430+ state_dict [" last_overflow_iter" ] = self .last_overflow_iter
431+ state_dict [" scale_factor" ] = self .scale_factor
432+ state_dict [" scale_window" ] = self .scale_window
433+ state_dict [" optimizer_state_dict" ] = self .optimizer .state_dict ()
434+ state_dict [" fp32_groups_flat" ] = self .fp32_groups_flat
435+ state_dict [" clip_grad" ] = self .clip_grad
411436 return state_dict
412437
413438 # Refresh fp32 master params from fp16 copies
@@ -432,16 +457,16 @@ def load_state_dict(self, state_dict, load_optimizer_states=True):
432457 optimizer.load_state_dict(checkpoint['optimizer'])
433458 """
434459 # I think it should actually be ok to reload the optimizer before the model.
435- self .dynamic_loss_scale = state_dict [' dynamic_loss_scale' ]
436- self .cur_scale = state_dict [' cur_scale' ]
437- self .cur_iter = state_dict [' cur_iter' ]
438- if state_dict [' dynamic_loss_scale' ]:
439- self .last_overflow_iter = state_dict [' last_overflow_iter' ]
440- self .scale_factor = state_dict [' scale_factor' ]
441- self .scale_window = state_dict [' scale_window' ]
460+ self .dynamic_loss_scale = state_dict [" dynamic_loss_scale" ]
461+ self .cur_scale = state_dict [" cur_scale" ]
462+ self .cur_iter = state_dict [" cur_iter" ]
463+ if state_dict [" dynamic_loss_scale" ]:
464+ self .last_overflow_iter = state_dict [" last_overflow_iter" ]
465+ self .scale_factor = state_dict [" scale_factor" ]
466+ self .scale_window = state_dict [" scale_window" ]
442467 if load_optimizer_states :
443- self .optimizer .load_state_dict (state_dict [' optimizer_state_dict' ])
444- self .clip_grad = state_dict [' clip_grad' ]
468+ self .optimizer .load_state_dict (state_dict [" optimizer_state_dict" ])
469+ self .clip_grad = state_dict [" clip_grad" ]
445470 # At this point, the optimizer's references to the model's fp32 parameters are up to date.
446471 # The optimizer's hyperparameters and internal buffers are also up to date.
447472 # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
@@ -456,8 +481,17 @@ def load_state_dict(self, state_dict, load_optimizer_states=True):
456481 # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been
457482 # constructed in the same way as the one whose state_dict we are loading, the same master params
458483 # are guaranteed to exist, so we can just copy_() from the saved master params.
459- for current , saved in zip (self .fp32_groups_flat , state_dict ['fp32_groups_flat' ]):
460- current .data .copy_ (saved .data )
484+ try :
485+ for current , saved in zip (
486+ self .fp32_groups_flat , state_dict ["fp32_groups_flat" ]
487+ ):
488+ current .data .copy_ (saved .data )
489+ except RuntimeError as error :
490+ print (error )
491+ print (
492+ "Error in loading fp32 model parameters!\n Refreshing fp32 model params from the model's fp16 params instead. This may incur some precision loss."
493+ )
494+ self .refresh_fp32_params ()
461495
462496 def __repr__ (self ):
463497 return repr (self .optimizer )
0 commit comments