
Commit 027485a

make fp32 allreduce optional
1 parent 87fbb8f commit 027485a

3 files changed: 10 additions & 10 deletions
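Summary: before this commit, stage1.py and stage2.py forced fp32 gradient allreduce whenever precision == torch.bfloat16, and engine.py asserted that ZeRO could not be combined with 'fp32_allreduce': true. The commit removes the assert and the bfloat16 special case and instead forwards the engine's allreduce_always_fp32 setting to all three ZeRO stages, making fp32 allreduce opt-in. A minimal sketch of a config that would exercise the new path follows; everything other than the fp32_allreduce and zero_optimization keys is a placeholder, not taken from this commit, and the exact keyword for passing an in-memory dict to deepspeed.initialize depends on the DeepSpeed version:

import deepspeed

# Sketch only: combine ZeRO with fp32 gradient allreduce, which the removed
# assert previously rejected. Batch size and fp16 settings are placeholders.
ds_config = {
    "train_batch_size": 8,
    "fp16": {"enabled": True},
    "fp32_allreduce": True,            # gradients are allreduced in fp32
    "zero_optimization": {"stage": 2},
}

# config_params is assumed here; newer releases also accept config= for a dict.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params=ds_config)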


deepspeed/runtime/engine.py

Lines changed: 6 additions & 4 deletions
@@ -803,7 +803,6 @@ def _configure_fp16_optimizer(self, optimizer):
     def _configure_zero_optimizer(self, optimizer):
         zero_stage = self.zero_optimization_stage()
         log_dist('Creating fp16 ZeRO stage {} optimizer'.format(zero_stage), ranks=[0])
-        assert not self.allreduce_always_fp32(), "ZeRO does not support 'fp32_allreduce': true"
         timers = self.timers if self.wall_clock_breakdown() else None

         if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
@@ -819,7 +818,8 @@ def _configure_zero_optimizer(self, optimizer):
                 dp_process_group=self.data_parallel_group,
                 elastic_checkpoint=self.zero_elastic_checkpoint(),
                 mpu=self.mpu,
-                precision=self.precision())
+                precision=self.precision(),
+                fp32_allreduce=self.allreduce_always_fp32)
         elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
             optimizer = FP16_DeepSpeedZeroOptimizer(
                 optimizer,
@@ -839,7 +839,8 @@ def _configure_zero_optimizer(self, optimizer):
                 postscale_gradients=self.postscale_gradients(),
                 gradient_predivide_factor=self.gradient_predivide_factor(),
                 gradient_accumulation_steps=self.gradient_accumulation_steps(),
-                precision=self.precision())
+                precision=self.precision(),
+                fp32_allreduce=self.allreduce_always_fp32)
         elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
             print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
@@ -867,7 +868,8 @@ def _configure_zero_optimizer(self, optimizer):
                 postscale_gradients=self.postscale_gradients(),
                 gradient_predivide_factor=self.gradient_predivide_factor(),
                 gradient_accumulation_steps=self.gradient_accumulation_steps(),
-                aio_config=self.aio_config())
+                aio_config=self.aio_config(),
+                fp32_allreduce=self.allreduce_always_fp32)

         else:
             raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))

deepspeed/runtime/zero/stage1.py

Lines changed: 3 additions & 5 deletions
@@ -121,7 +121,8 @@ def __init__(self,
                  clip_grad=0.0,
                  max_elements_per_comm=5e8,
                  elastic_checkpoint=True,
-                 precision=torch.half):
+                 precision=torch.half,
+                 fp32_allreduce=False):

         # Load pre-built or JIT compile (un)flatten ops
         util_ops = UtilsBuilder().load()
@@ -130,10 +131,7 @@ def __init__(self,

         # set precision
         self.precision = precision
-        if self.precision == torch.bfloat16:
-            self.fp32_allreduce = True
-        else:
-            self.fp32_allreduce = False
+        self.fp32_allreduce = fp32_allreduce

         if dp_process_group is not None and partition_size is not None:
             raise ValueError("Cannot specify both dp_process_group "

deepspeed/runtime/zero/stage2.py

Lines changed: 1 addition & 1 deletion
@@ -115,7 +115,7 @@ def __init__(self,
             raise SystemError("Cannot use fp16 without CUDA.")
         self.optimizer = init_optimizer
         self.precision = precision
-        self.fp32_allreduce = True if self.precision == torch.bfloat16 else allreduce_always_fp32
+        self.fp32_allreduce = allreduce_always_fp32


         # Load pre-built or JIT compile (un)flatten ops
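
For context, a sketch of what an fp32_allreduce flag typically controls during gradient averaging. This is an illustrative stand-alone example, not the stage1/stage2 implementation: the flattened gradient bucket is upcast to fp32 before dist.all_reduce and copied back to its original dtype afterwards, trading communication volume for accumulation accuracy with fp16/bf16 gradients.

import torch
import torch.distributed as dist

def allreduce_bucket(bucket, fp32_allreduce, dp_group):
    # Illustrative only (not the DeepSpeed code): average a flattened gradient
    # bucket across the data-parallel group, optionally in fp32.
    tensor = bucket.float() if fp32_allreduce and bucket.dtype != torch.float32 else bucket
    tensor.div_(dist.get_world_size(group=dp_group))  # pre-divide, then sum-allreduce
    dist.all_reduce(tensor, group=dp_group)
    if tensor is not bucket:
        bucket.copy_(tensor)  # copy_ downcasts back to the fp16/bf16 storage dtype
    return bucket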
