Commit c814fca: Update engine.py
1 parent 75de46c

1 file changed: deepspeed/runtime/engine.py (16 additions, 5 deletions)
@@ -21,7 +21,7 @@
 from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
 from deepspeed.runtime.zero.utils import is_zero_supported_optimizer
 from deepspeed.runtime.activation_checkpointing import checkpointing as activation_checkpointing
-from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer, FP16_FUSED_SUPPORTED_OPTIMIZERS, is_fp16_fused_supported_optimizer
+from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \
     ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \
@@ -397,6 +397,9 @@ def zero_gather_fp16_weights_on_model_save(self):
     def fp16_enabled(self):
         return self._config.fp16_enabled
 
+    def precision(self):
+        return self._config.precision
+
     def amp_enabled(self):
         return self._config.amp_enabled
 
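The new precision() accessor simply surfaces a value from the DeepSpeed config. Judging from how it is used below (compared against torch.bfloat16 and passed to module.to()), it must resolve to a torch dtype. A minimal sketch of how such a config field could be wired up; the JSON key name and the string-to-dtype mapping are assumptions for illustration only:

import torch

# Illustrative only -- the real field lives in DeepSpeedConfig; the key name
# ("precision") and this mapping are assumptions, not part of the commit.
_DTYPES = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}

class ConfigSketch:
    def __init__(self, config_dict):
        fp16_section = config_dict.get("fp16", {})
        self.fp16_enabled = fp16_section.get("enabled", False)
        # Default to fp16 so existing configs keep their previous behavior.
        self.precision = _DTYPES[config_dict.get("precision", "fp16")]

# Example: ConfigSketch({"fp16": {"enabled": True}, "precision": "bf16"}).precision
# evaluates to torch.bfloat16.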
@@ -569,14 +572,18 @@ def is_replicated(p):
 
         for p in self.module.parameters():
             if torch.is_tensor(p) and is_replicated(p):
+                if self.precision() == torch.bfloat16:
+                    p = p.float()
                 dist.broadcast(p,
                                self.broadcast_src_rank,
                                group=self.data_parallel_group)
+                if self.precision() == torch.bfloat16:
+                    p = p.bfloat16()
 
     def _configure_distributed_model(self, model):
         self.module = model
         if self.fp16_enabled():
-            self.module.half()
+            self.module.to(self.precision())
 
         if not self.dont_change_device:
             self.module.to(self.device)
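This hunk round-trips bf16 parameters through fp32 for the broadcast, presumably to work around distributed backends without bfloat16 collective support, and casts the module to the configured dtype instead of unconditionally calling half(). Note that p = p.float() rebinds the loop variable to a temporary copy, so as written the broadcast result does not appear to make it back into the original bf16 parameter; the standalone sketch below adds an explicit copy-back to show the presumably intended behavior. The names broadcast_src_rank and data_parallel_group mirror the engine attributes; everything else is illustrative.

import torch
import torch.distributed as dist

def broadcast_param(param, src_rank, group=None):
    # Minimal sketch, assuming the process group is already initialized.
    if param.dtype == torch.bfloat16:
        # Broadcast an fp32 copy for backends without bf16 collective support.
        buf = param.data.float()
        dist.broadcast(buf, src_rank, group=group)
        # Copy the received values back into the bf16 parameter (cast implied).
        param.data.copy_(buf)
    else:
        dist.broadcast(param.data, src_rank, group=group)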
@@ -714,7 +721,8 @@ def _configure_fp16_optimizer(self, optimizer):
         initial_dynamic_scale = self.initial_dynamic_scale()
         dynamic_loss_args = self.dynamic_loss_scale_args()
         clip_grad = self.gradient_clipping()
-        if is_fp16_fused_supported_optimizer(optimizer):
+        if isinstance(optimizer,
+                      FusedAdam) or self.optimizer_name() == ONEBIT_ADAM_OPTIMIZER:
             if self.dynamic_loss_scale():
                 log_dist('Creating fp16 optimizer with dynamic loss scale', ranks=[0])
                 timers = self.timers if self.wall_clock_breakdown() else None
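With the is_fp16_fused_supported_optimizer helper dropped from the imports above, the fused-vs-unfused decision is now spelled out inline: the fused FP16_Optimizer wrapper is selected when the wrapped optimizer is a FusedAdam instance or when the configured optimizer name is OneBit Adam. A rough sketch of that dispatch, trimmed to the condition this hunk changes; the FusedAdam import path is an assumption, the other imports match the ones visible in engine.py:

# Illustrative dispatch only; requires a DeepSpeed install to run.
from deepspeed.ops.adam import FusedAdam  # assumed import path
from deepspeed.runtime.config import ONEBIT_ADAM_OPTIMIZER
from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer

def pick_fp16_wrapper(optimizer, optimizer_name):
    if isinstance(optimizer, FusedAdam) or optimizer_name == ONEBIT_ADAM_OPTIMIZER:
        return FP16_Optimizer        # fused path
    return FP16_UnfusedOptimizer     # generic path for arbitrary torch optimizers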
@@ -772,7 +780,8 @@ def _configure_zero_optimizer(self, optimizer):
                 max_elements_per_comm=self.zero_reduce_bucket_size(),
                 dp_process_group=self.data_parallel_group,
                 elastic_checkpoint=self.zero_elastic_checkpoint(),
-                mpu=self.mpu)
+                mpu=self.mpu,
+                precision=self.precision())
         elif zero_stage == ZERO_OPTIMIZATION_GRADIENTS:
             optimizer = FP16_DeepSpeedZeroOptimizer(
                 optimizer,
@@ -791,7 +800,8 @@ def _configure_zero_optimizer(self, optimizer):
                 mpu=self.mpu,
                 postscale_gradients=self.postscale_gradients(),
                 gradient_predivide_factor=self.gradient_predivide_factor(),
-                gradient_accumulation_steps=self.gradient_accumulation_steps())
+                gradient_accumulation_steps=self.gradient_accumulation_steps(),
+                precision=self.precision())
         elif zero_stage == ZERO_OPTIMIZATION_WEIGHTS:
             print("Initializing ZeRO Stage 3") if dist.get_rank() == 0 else None
             from deepspeed.runtime.zero.stage3 import FP16_DeepSpeedZeroOptimizer_Stage3
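Both the stage-1 and stage-2 ZeRO optimizer constructors now receive the configured precision. The receiving side is not part of this commit, but presumably the keyword decides the dtype of the partitioned parameter/gradient buffers. Purely a guess at what the consuming constructor might look like; the class name, signature, and buffer logic below are illustrative, not DeepSpeed's actual code:

import torch

class ZeroOptimizerSketch:
    # Guessed signature: the real ones live in deepspeed/runtime/zero/stage1.py
    # and stage2.py and are not shown in this commit.
    def __init__(self, init_optimizer, mpu=None, precision=torch.float16, **kwargs):
        self.optimizer = init_optimizer
        # Low-precision dtype used for the flattened communication buffers.
        self.dtype = precision
        self.flat_buffer = torch.zeros(1024, dtype=self.dtype)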
@@ -979,6 +989,7 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
 
         # Communicate only at gradient accumulation boundaries
         elif self.is_gradient_accumulation_boundary():
+            # TODO: communication in fp16 / fp32
             if self.zero_optimization_stage(
             ) == ZERO_OPTIMIZATION_OPTIMIZER_STATES and self.zero_reduce_scatter():
                 self.optimizer.reduce_scatter_gradients(
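The new TODO marks where the gradient-communication dtype would become configurable. A sketch of that idea, not code from this commit: cast gradients to a chosen communication dtype before the collective and cast back afterwards.

import torch
import torch.distributed as dist

def allreduce_grad(grad, comm_dtype=torch.float32, group=None):
    # Illustrative only; assumes the process group is already initialized.
    buf = grad.to(comm_dtype)
    dist.all_reduce(buf, group=group)
    # Average across data-parallel ranks, then cast back to the gradient dtype.
    buf.div_(dist.get_world_size(group=group))
    grad.copy_(buf)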
