Skip to content

Commit c745f3d

Browse files
committed
Revert "make fp32 allreduce optional"
This reverts commit 027485a. It was breaking neox training for a reason I can't yet figure out.
1 parent 027485a commit c745f3d

3 files changed

Lines changed: 10 additions & 10 deletions

File tree

deepspeed/runtime/engine.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -803,6 +803,7 @@ def _configure_fp16_optimizer(self, optimizer):
803803
def _configure_zero_optimizer(self, optimizer):
804804
zero_stage = self.zero_optimization_stage()
805805
log_dist('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage), ranks=[0])
806+
assert not self.allreduce_always_fp32(), "ZeRO does not support 'fp32_allreduce': true"
806807
timers = self.timers if self.wall_clock_breakdown() else None
807808

808809
if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
@@ -818,8 +819,7 @@ def _configure_zero_optimizer(self, optimizer):
818819
dp_process_group=self.data_parallel_group,
819820
elastic_checkpoint=self.zero_elastic_checkpoint(),
820821
mpu=self.mpu,
821-
precision=self.precision(),
822-
fp32_allreduce=self.allreduce_always_fp32)
822+
precision=self.precision())
823823
elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
824824
optimizer = FP16_DeepSpeedZeroOptimizer(
825825
optimizer,
@@ -839,8 +839,7 @@ def _configure_zero_optimizer(self, optimizer):
839839
postscale_gradients=self.postscale_gradients(),
840840
gradient_predivide_factor=self.gradient_predivide_factor(),
841841
gradient_accumulation_steps=self.gradient_accumulation_steps(),
842-
precision=self.precision(),
843-
fp32_allreduce=self.allreduce_always_fp32)
842+
precision=self.precision())
844843
elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
845844
print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
846845
from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
@@ -868,8 +867,7 @@ def _configure_zero_optimizer(self, optimizer):
868867
postscale_gradients=self.postscale_gradients(),
869868
gradient_predivide_factor=self.gradient_predivide_factor(),
870869
gradient_accumulation_steps=self.gradient_accumulation_steps(),
871-
aio_config=self.aio_config(),
872-
fp32_allreduce=self.allreduce_always_fp32)
870+
aio_config=self.aio_config())
873871

874872
else:
875873
raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))

deepspeed/runtime/zero/stage1.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,7 @@ def __init__(self,
121121
clip_grad=0.0,
122122
max_elements_per_comm=5e8,
123123
elastic_checkpoint=True,
124-
precision=torch.half,
125-
fp32_allreduce=False):
124+
precision=torch.half):
126125

127126
# Load pre-built or JIT compile (un)flatten ops
128127
util_ops = UtilsBuilder().load()
@@ -131,7 +130,10 @@ def __init__(self,
131130

132131
# set precision
133132
self.precision = precision
134-
self.fp32_allreduce = fp32_allreduce
133+
if self.precision == torch.bfloat16:
134+
self.fp32_allreduce = True
135+
else:
136+
self.fp32_allreduce = False
135137

136138
if dp_process_group is not None and partition_size is not None:
137139
raise ValueError("Cannot specify both dp_process_group "

deepspeed/runtime/zero/stage2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def __init__(self,
115115
raise SystemError("Cannot use fp16 without CUDA.")
116116
self.optimizer = init_optimizer
117117
self.precision = precision
118-
self.fp32_allreduce = allreduce_always_fp32
118+
self.fp32_allreduce = True if self.precision == torch.bfloat16 else allreduce_always_fp32
119119

120120

121121
# Load pre-built or JIT compile (un)flatten ops

0 commit comments

Comments (0)