 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
 from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
-from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer, FP16_FUSED_SUPPORTED_OPTIMIZERS, is_fp16_fused_supported_optimizer
+from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
     ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
@@ -397,6 +397,9 @@ def zero_gather_fp16_weights_on_model_save(self):
     def fp16_enabled(self):
         return self._config.fp16_enabled

+    def precision(self):
+        return self._config.precision
+
     def amp_enabled(self):
         return self._config.amp_enabled

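For orientation: `precision()` is a new accessor alongside `fp16_enabled()`, and later hunks compare its value against `torch.bfloat16`, so `self._config.precision` is evidently a `torch.dtype`. Below is a minimal sketch of how such a field could be derived from user flags; the flag names are illustrative assumptions, not the actual DeepSpeedConfig schema.

```python
import torch

# Hypothetical helper: maps config flags to the dtype that precision() would return.
# The real DeepSpeedConfig field names may differ.
def resolve_precision(fp16_enabled: bool, bf16_enabled: bool) -> torch.dtype:
    if bf16_enabled:
        return torch.bfloat16
    if fp16_enabled:
        return torch.float16
    return torch.float32
```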
@@ -569,14 +572,18 @@ def is_replicated(p):

         for p in self.module.parameters():
             if torch.is_tensor(p) and is_replicated(p):
+                if self.precision() == torch.bfloat16:
+                    p.data = p.data.float()  # upcast so the collective runs in fp32
                 dist.broadcast(p,
                                self.broadcast_src_rank,
                                group=self.data_parallel_group)
+                if self.precision() == torch.bfloat16:
+                    p.data = p.data.bfloat16()  # cast the received values back to bf16

     def _configure_distributed_model(self, model):
         self.module = model
         if self.fp16_enabled():
-            self.module.half()
+            self.module.to(self.precision())

         if not self.dont_change_device:
             self.module.to(self.device)
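The added branches above round-trip each replicated bf16 parameter through fp32 around the broadcast, since some communication backends do not handle bf16 buffers. A standalone sketch of that pattern (assuming `torch.distributed` has already been initialized; `bcast_rank` and `dp_group` stand in for the engine's `broadcast_src_rank` and `data_parallel_group`):

```python
import torch
import torch.distributed as dist

def broadcast_params(module, bcast_rank, dp_group, precision):
    """Broadcast parameters from bcast_rank to every rank in the data-parallel group.

    bf16 tensors are upcast to fp32 for the collective and cast back afterwards,
    so the broadcast itself only ever sees fp32 buffers.
    """
    for p in module.parameters():
        if precision == torch.bfloat16:
            p.data = p.data.float()        # upcast before the collective
        dist.broadcast(p, src=bcast_rank, group=dp_group)
        if precision == torch.bfloat16:
            p.data = p.data.bfloat16()     # cast the received values back to bf16
```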
@@ -714,7 +721,8 @@ def _configure_fp16_optimizer(self, optimizer):
         initial_dynamic_scale = self.initial_dynamic_scale()
         dynamic_loss_args = self.dynamic_loss_scale_args()
         clip_grad = self.gradient_clipping()
-        if is_fp16_fused_supported_optimizer(optimizer):
+        if isinstance(optimizer,
+                      FusedAdam) or self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
             if self.dynamic_loss_scale():
                 log_dist('Creating fp16 optimizer with dynamic loss scale', ranks=[0])
                 timers = self.timers if self.wall_clock_breakdown() else None
@@ -772,7 +780,8 @@ def _configure_zero_optimizer(self, optimizer):
                 max_elements_per_comm=self.zero_reduce_bucket_size(),
                 dp_process_group=self.data_parallel_group,
                 elastic_checkpoint=self.zero_elastic_checkpoint(),
-                mpu=self.mpu)
+                mpu=self.mpu,
+                precision=self.precision())
         elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
             optimizer = FP16_DeepSpeedZeroOptimizer(
                 optimizer,
@@ -791,7 +800,8 @@ def _configure_zero_optimizer(self, optimizer):
                 mpu=self.mpu,
                 postscale_gradients=self.postscale_gradients(),
                 gradient_predivide_factor=self.gradient_predivide_factor(),
-                gradient_accumulation_steps=self.gradient_accumulation_steps())
+                gradient_accumulation_steps=self.gradient_accumulation_steps(),
+                precision=self.precision())
         elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
             print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
@@ -979,6 +989,7 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):

         # Communicate only at gradient accumulation boundaries
         elif self.is_gradient_accumulation_boundary():
+            # TODO: communication in fp16 / fp32
             if self.zero_optimization_stage(
             ) == ZERO_OPTIMIZATION_OPTIMIZER_STATES and self.zero_reduce_scatter():
                 self.optimizer.reduce_scatter_gradients(