@@ -803,6 +803,7 @@ def _configure_fp16_optimizer(self, optimizer):
803803 def _configure_zero_optimizer (self , optimizer ):
804804 zero_stage = self .zero_optimization_stage ()
805805 log_dist ('Creating fp16 ZeRO stage {} optimizer' .format (zero_stage ), ranks = [0 ])
806+ assert not self .allreduce_always_fp32 (), "ZeRO does not support 'fp32_allreduce': true"
806807 timers = self .timers if self .wall_clock_breakdown () else None
807808
808809 if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES :
@@ -818,8 +819,7 @@ def _configure_zero_optimizer(self, optimizer):
818819 dp_process_group = self .data_parallel_group ,
819820 elastic_checkpoint = self .zero_elastic_checkpoint (),
820821 mpu = self .mpu ,
821- precision = self .precision (),
822- fp32_allreduce = self .allreduce_always_fp32 )
822+ precision = self .precision ())
823823 elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS :
824824 optimizer = FP16_DeepSpeedZeroOptimizer (
825825 optimizer ,
@@ -839,8 +839,7 @@ def _configure_zero_optimizer(self, optimizer):
839839 postscale_gradients = self .postscale_gradients (),
840840 gradient_predivide_factor = self .gradient_predivide_factor (),
841841 gradient_accumulation_steps = self .gradient_accumulation_steps (),
842- precision = self .precision (),
843- fp32_allreduce = self .allreduce_always_fp32 )
842+ precision = self .precision ())
844843 elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS :
845844 print ("Initializing ZeRO Stage 3" ) if dist .get_rank () == 0 else None
846845 from deepspeed .runtime .zero .stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
@@ -868,8 +867,7 @@ def _configure_zero_optimizer(self, optimizer):
868867 postscale_gradients = self .postscale_gradients (),
869868 gradient_predivide_factor = self .gradient_predivide_factor (),
870869 gradient_accumulation_steps = self .gradient_accumulation_steps (),
871- aio_config = self .aio_config (),
872- fp32_allreduce = self .allreduce_always_fp32 )
870+ aio_config = self .aio_config ())
873871
874872 else :
875873 raise NotImplementedError ("ZeRO stage {} not implemented" .format (zero_stage ))
0 commit comments