@@ -108,12 +108,8 @@ def step(self,
108108 bias_correction = 1 if group ['bias_correction' ] else 0
109109 beta1 , beta2 = group ['betas' ]
110110
111- # assume same step across group now to simplify things
112- # per parameter step can be easily support by making it tensor, or pass list into kernel
113- if 'step' in group :
114- group ['step' ] += 1
115- else :
116- group ['step' ] = 1
111+ if 'step' not in group :
112+ group ['step' ] = 0
117113
118114 # create lists for multi-tensor apply
119115 g_16 , p_16 , m_16 , v_16 = [], [], [], []
@@ -130,6 +126,10 @@ def step(self,
130126 state = self .state [p ]
131127 # State initialization
132128 if len (state ) == 0 :
129+ # DeepSpeed ZeRO 3 processes one subgroup at a time, so we need to track the step count for each tensor separately.
130+ # This is not an issue for ZeRO 1 & 2, since they apply a single optimization step to the whole param group at the same time.
131+ # In order to keep backward compatibility with existing checkpoints, we use group['step'] to initialize state['step'] if it exists.
132+ state ['step' ] = group .get ('step' , 0 )
133133 # Exponential moving average of gradient values
134134 state ['exp_avg' ] = torch .zeros_like (p .data )
135135 # Exponential moving average of squared gradient values
@@ -149,6 +149,7 @@ def step(self,
149149 raise RuntimeError ('FusedAdam only support fp16 and fp32.' )
150150
151151 if (len (g_16 ) > 0 ):
152+ state ['step' ] += 1
152153 multi_tensor_applier (self .multi_tensor_adam ,
153154 self ._dummy_overflow_buf ,
154155 [g_16 ,
@@ -159,11 +160,12 @@ def step(self,
159160 beta1 ,
160161 beta2 ,
161162 group ['eps' ],
162- group ['step' ],
163+ state ['step' ],
163164 self .adam_w_mode ,
164165 bias_correction ,
165166 group ['weight_decay' ])
166167 if (len (g_32 ) > 0 ):
168+ state ['step' ] += 1
167169 multi_tensor_applier (self .multi_tensor_adam ,
168170 self ._dummy_overflow_buf ,
169171 [g_32 ,
@@ -174,7 +176,7 @@ def step(self,
174176 beta1 ,
175177 beta2 ,
176178 group ['eps' ],
177- group ['step' ],
179+ state ['step' ],
178180 self .adam_w_mode ,
179181 bias_correction ,
180182 group ['weight_decay' ])
0 commit comments