Skip to content

Commit 810d4a5

Browse files
committed
Merge remote-tracking branch 'igor/main' into augment
2 parents 7a514ec + 98f4a6c commit 810d4a5

1 file changed

Lines changed: 10 additions & 8 deletions

File tree

deepspeed/ops/adam/fused_adam.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,8 @@ def step(self,
108108
bias_correction = 1 if group['bias_correction'] else 0
109109
beta1, beta2 = group['betas']
110110

111-
# assume same step across group now to simplify things
112-
# per parameter step can be easily support by making it tensor, or pass list into kernel
113-
if 'step' in group:
114-
group['step'] += 1
115-
else:
116-
group['step'] = 1
111+
if 'step' not in group:
112+
group['step'] = 0
117113

118114
# create lists for multi-tensor apply
119115
g_16, p_16, m_16, v_16 = [], [], [], []
@@ -130,6 +126,10 @@ def step(self,
130126
state = self.state[p]
131127
# State initialization
132128
if len(state) == 0:
129+
# DeepSpeed ZeRO 3 processes one subgroup at a time, so we need to keep tracking the step count for each tensor separately.
130+
# This is not an issue for ZeRO 1 & 2, since they apply a single optimization step to the whole param group at the same time.
131+
# In order to keep backward compatibility for the existing checkpoints, we use group['step'] to initialize state['step'] if it exists.
132+
state['step'] = group.get('step', 0)
133133
# Exponential moving average of gradient values
134134
state['exp_avg'] = torch.zeros_like(p.data)
135135
# Exponential moving average of squared gradient values
@@ -149,6 +149,7 @@ def step(self,
149149
raise RuntimeError('FusedAdam only support fp16 and fp32.')
150150

151151
if (len(g_16) > 0):
152+
state['step'] += 1
152153
multi_tensor_applier(self.multi_tensor_adam,
153154
self._dummy_overflow_buf,
154155
[g_16,
@@ -159,11 +160,12 @@ def step(self,
159160
beta1,
160161
beta2,
161162
group['eps'],
162-
group['step'],
163+
state['step'],
163164
self.adam_w_mode,
164165
bias_correction,
165166
group['weight_decay'])
166167
if (len(g_32) > 0):
168+
state['step'] += 1
167169
multi_tensor_applier(self.multi_tensor_adam,
168170
self._dummy_overflow_buf,
169171
[g_32,
@@ -174,7 +176,7 @@ def step(self,
174176
beta1,
175177
beta2,
176178
group['eps'],
177-
group['step'],
179+
state['step'],
178180
self.adam_w_mode,
179181
bias_correction,
180182
group['weight_decay'])

0 commit comments

Comments
 (0)