@@ -803,7 +803,6 @@ def _configure_fp16_optimizer(self, optimizer):
803803 def _configure_zero_optimizer (self , optimizer ):
804804 zero_stage = self .zero_optimization_stage ()
805805 log_dist ('Creating fp16 ZeRO stage {} optimizer' .format (zero_stage ), ranks = [0 ])
806- assert not self .allreduce_always_fp32 (), "ZeRO does not support 'fp32_allreduce': true"
807806 timers = self .timers if self .wall_clock_breakdown () else None
808807
809808 if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES :
@@ -819,7 +818,8 @@ def _configure_zero_optimizer(self, optimizer):
819818 dp_process_group = self .data_parallel_group ,
820819 elastic_checkpoint = self .zero_elastic_checkpoint (),
821820 mpu = self .mpu ,
822- precision = self .precision ())
821+ precision = self .precision (),
822+ fp32_allreduce = self .allreduce_always_fp32 )
823823 elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS :
824824 optimizer = FP16_DeepSpeedZeroOptimizer (
825825 optimizer ,
@@ -839,7 +839,8 @@ def _configure_zero_optimizer(self, optimizer):
839839 postscale_gradients = self .postscale_gradients (),
840840 gradient_predivide_factor = self .gradient_predivide_factor (),
841841 gradient_accumulation_steps = self .gradient_accumulation_steps (),
842- precision = self .precision ())
842+ precision = self .precision (),
843+ fp32_allreduce = self .allreduce_always_fp32 )
843844 elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS :
844845 print ("Initializing ZeRO Stage 3" ) if dist .get_rank () == 0 else None
845846 from deepspeed .runtime .zero .stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
@@ -867,7 +868,8 @@ def _configure_zero_optimizer(self, optimizer):
867868 postscale_gradients = self .postscale_gradients (),
868869 gradient_predivide_factor = self .gradient_predivide_factor (),
869870 gradient_accumulation_steps = self .gradient_accumulation_steps (),
870- aio_config = self .aio_config ())
871+ aio_config = self .aio_config (),
872+ fp32_allreduce = self .allreduce_always_fp32 )
871873
872874 else :
873875 raise NotImplementedError ("ZeRO stage {} not implemented" .format (zero_stage ))
0 commit comments