Skip to content

Commit 996951e

Browse files
authored
* bf16 initial commit * Update engine.py * update split_half_float_double_csr dtypes * update to bf16 communication (make flag optional) * Update requirements-sparse_attn.txt * add compressed bf16 allreduce * add compressed bf16 allreduce * Update __init__.py * Update engine.py * Update __init__.py * Update engine.py * zero1 + bf16 * zero 2 + bf16 * pipe parallel + bf16 * pipe parallel + bf16 * partition activations + bf16
1 parent 3389e4f commit 996951e

10 files changed

Lines changed: 268 additions & 106 deletions

File tree

deepspeed/runtime/activation_checkpointing/checkpointing.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def get_partition_size(item):
253253
return int(partition_size)
254254

255255

256-
def get_full_inputs(tensors, device=None):
256+
def get_full_inputs(tensors, device=None, fp32_comm=False):
257257
inputs = []
258258
num_args = int(len(tensors) / 2)
259259
for i in range(num_args - 1):
@@ -274,9 +274,14 @@ def get_full_inputs(tensors, device=None):
274274
part_i = flat_tensor.narrow(0, partition_size * i, partition_size)
275275
if i == mp_rank:
276276
part_i.copy_(item)
277+
if fp32_comm:
278+
part_i = part_i.float()
277279
partitions.append(part_i)
278280
if mp_group is not None:
279281
dist.all_gather(partitions, partitions[mp_rank], group=mp_group)
282+
if fp32_comm:
283+
for i in range(mp_size):
284+
partitions[i] = partitions[i].to(item.dtype)
280285
input_tensor = flat_tensor.view(list(size.numpy()))
281286
item.data = input_tensor.data
282287

@@ -599,9 +604,14 @@ def backward(ctx, *grads):
599604
global cuda_device, transport_stream, PARTITION_ACTIVATIONS
600605

601606
if PARTITION_ACTIVATIONS:
607+
if ctx.saved_tensors and ctx.saved_tensors[0].dtype == torch.bfloat16:
608+
FP32_COMM = True
609+
else:
610+
FP32_COMM = False
602611
# with torch.cuda.stream(transport_stream):
603612
inputs = get_full_inputs(ctx.saved_tensors,
604-
device=cuda_device if PA_TO_CPU else None)
613+
device=cuda_device if PA_TO_CPU else None,
614+
fp32_comm=FP32_COMM)
605615
detached_inputs = detach_variable(inputs)
606616
else:
607617
inputs = ctx.saved_tensors

deepspeed/runtime/comm/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .compressed_ar import compressed_all_reduce
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
# python -m torch.distributed.launch --nproc_per_node=1 24_bit_allreduce.py
"""Compressed ("24-bit") all-reduce helpers.

Each tensor is split with ``frexp`` into an fp16 mantissa and an int8
exponent, the two parts are all-reduced separately, and the result is
rebuilt with ``ldexp``. ``torch.frexp``/``torch.ldexp`` exist only from
torch 1.9; older versions use a cupy-based fallback (cupy is imported
lazily so it is only required on that path).

NOTE(review): all-reducing mantissas and exponents independently is not
mathematically equal to all-reducing the original values; callers appear
to use this as a lossy, bandwidth-saving approximation -- confirm against
call sites.
"""
import os

import torch
from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack

version = torch.__version__.split('.')
TORCH_VERSION_MAJOR = int(version[0])
TORCH_VERSION_MINOR = int(version[1])
# True when torch predates 1.9 (no frexp/ldexp). The original check used
# ``MAJOR >= 1 and MINOR < 9`` which also (wrongly) routed torch 2.x to
# the fallback; fixed to test the 1.x series only.
_PRE_TORCH_1_9 = TORCH_VERSION_MAJOR < 1 or (TORCH_VERSION_MAJOR == 1
                                             and TORCH_VERSION_MINOR < 9)


def torch2cupy(tensor):
    """Return a zero-copy cupy view of a CUDA tensor (via DLPack)."""
    import cupy  # lazy: only needed on the pre-1.9 fallback path
    return cupy.fromDlpack(to_dlpack(tensor))


def cupy2torch(cupy_tensor):
    """Return a zero-copy torch view of a cupy array (via DLPack)."""
    return from_dlpack(cupy_tensor.toDlpack())


def decompose_cupy(tensor):
    """Split ``tensor`` into (fp16 mantissa, int8 exponent) using cupy.frexp."""
    import cupy  # lazy: only needed on the pre-1.9 fallback path
    mantissa, exponent = cupy.frexp(torch2cupy(tensor.float()))
    return cupy2torch(mantissa).half(), cupy2torch(exponent).to(torch.int8)


def decompose(t):
    """Split ``t`` into (fp16 mantissa, int8 exponent) using torch.frexp.

    Raises:
        Exception: on torch < 1.9, where ``torch.frexp`` does not exist.
    """
    if _PRE_TORCH_1_9:
        raise Exception('Torch version >= 1.9.0 needed for 24_bit_allreduce.decompose')
    mantissa, exponent = torch.frexp(t.float())
    return mantissa.half(), exponent.to(torch.int8)


def reconstruct(mantissa, exponent, original_dtype=torch.bfloat16):
    """Rebuild ``mantissa * 2**exponent`` and cast back to ``original_dtype``."""
    return torch.ldexp(mantissa, exponent).to(original_dtype)


def compressed_all_reduce_torch(tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
    """All-reduce ``tensor`` in compressed mantissa/exponent form (torch >= 1.9)."""
    original_dtype = tensor.dtype
    m, e = decompose(tensor)
    torch.distributed.all_reduce(m, op=op, group=group, async_op=async_op)
    torch.distributed.all_reduce(e, op=op, group=group, async_op=async_op)
    return reconstruct(m, e, original_dtype)


def compressed_all_reduce_cupy(tensor, op=torch.distributed.ReduceOp.SUM, group=None, async_op=False):
    """All-reduce ``tensor`` in compressed form via the cupy frexp fallback."""
    original_dtype = tensor.dtype
    m, e = decompose_cupy(tensor)
    torch.distributed.all_reduce(m, op=op, group=group, async_op=async_op)
    torch.distributed.all_reduce(e, op=op, group=group, async_op=async_op)
    return reconstruct(m, e, original_dtype)


# Select the implementation AFTER both functions exist. The original
# assigned these names at the top of the module, before either function
# was defined, which raised NameError as soon as the module was imported.
if _PRE_TORCH_1_9:
    compressed_all_reduce = compressed_all_reduce_cupy
else:
    compressed_all_reduce = compressed_all_reduce_torch

deepspeed/runtime/config.py

Lines changed: 57 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,18 @@ def get_fp16_enabled(param_dict):
9494
return False
9595

9696

def get_fp16_type(param_dict):
    """Return the dtype name configured under the fp16 section.

    Falls back to "fp32" whenever fp16 is not enabled in the config.
    """
    if not get_fp16_enabled(param_dict):
        return "fp32"
    return get_scalar_param(param_dict[FP16], FP16_TYPE, FP16_TYPE_DEFAULT)
103+
97104
def get_loss_scale(param_dict):
98105
if get_fp16_enabled(param_dict):
106+
if get_fp16_type(param_dict) == "bfloat16":
107+
# default loss scale to 1.0 if dtype == bf16, as loss scaling isn't needed
108+
return 1.0
99109
return get_scalar_param(param_dict[FP16],
100110
FP16_LOSS_SCALE,
101111
FP16_LOSS_SCALE_DEFAULT)
@@ -111,7 +121,7 @@ def get_initial_dynamic_scale(param_dict):
111121
else:
112122
initial_scale_power = FP16_INITIAL_SCALE_POWER_DEFAULT
113123

114-
return 2**initial_scale_power
124+
return 2 ** initial_scale_power
115125

116126

117127
def get_dynamic_loss_scale_args(param_dict):
@@ -138,7 +148,7 @@ def get_dynamic_loss_scale_args(param_dict):
138148
FP16_MIN_LOSS_SCALE,
139149
FP16_MIN_LOSS_SCALE_DEFAULT)
140150
loss_scale_args = {
141-
INITIAL_LOSS_SCALE: 2**init_scale,
151+
INITIAL_LOSS_SCALE: 2 ** init_scale,
142152
SCALE_WINDOW: scale_window,
143153
DELAYED_SHIFT: delayed_shift,
144154
MIN_LOSS_SCALE: min_loss_scale
@@ -168,6 +178,9 @@ def get_zero_reduce_scatter(param_dict):
168178

169179

170180
def get_allreduce_always_fp32(param_dict):
    """Return whether gradient allreduce should always be done in fp32.

    For a bf16 run the default flips to True, since nccl cannot
    communicate bf16 tensors directly; the user setting still wins.
    """
    fallback = (FP32_ALLREDUCE_DEFAULT_BF16
                if get_fp16_type(param_dict) == "bfloat16" else FP32_ALLREDUCE_DEFAULT)
    return get_scalar_param(param_dict, FP32_ALLREDUCE, fallback)
172185

173186

@@ -409,7 +422,7 @@ def get_optimizer_gradient_clipping(param_dict):
409422

410423
def get_optimizer_legacy_fusion(param_dict):
411424
if OPTIMIZER in param_dict.keys() and \
412-
LEGACY_FUSION in param_dict[OPTIMIZER].keys():
425+
LEGACY_FUSION in param_dict[OPTIMIZER].keys():
413426
return param_dict[OPTIMIZER][LEGACY_FUSION]
414427
else:
415428
return LEGACY_FUSION_DEFAULT
@@ -496,7 +509,7 @@ def get_checkpoint_tag_validation_mode(checkpoint_params):
496509
return tag_validation_mode
497510
else:
498511
raise DeepSpeedConfigError("Checkpoint config contains invalid tag_validation " \
499-
f"value of {tag_validation_mode}, expecting one of {CHECKPOINT_TAG_VALIDATION_MODES}")
512+
f"value of {tag_validation_mode}, expecting one of {CHECKPOINT_TAG_VALIDATION_MODES}")
500513

501514

502515
'''Write deepspeed config files by modifying basic templates.
@@ -568,11 +581,11 @@ def __init__(self, json_file, mpu=None, param_dict=None):
568581
]
569582
if any(map(lambda t: t in self._param_dict, batch_params)):
570583
raise ElasticityConfigError("One or more batch related parameters were found in your " \
571-
f"ds_config ({TRAIN_BATCH_SIZE}, {TRAIN_MICRO_BATCH_SIZE_PER_GPU}, and/or " \
572-
f"{GRADIENT_ACCUMULATION_STEPS}). These parameters *will not be used* since " \
573-
"elastic training is enabled, which takes control of these parameters. " \
574-
"If you want to supress this error (the parameters will be silently ignored) " \
575-
f"please set {IGNORE_NON_ELASTIC_BATCH_INFO}':true in your elasticity config.")
584+
f"ds_config ({TRAIN_BATCH_SIZE}, {TRAIN_MICRO_BATCH_SIZE_PER_GPU}, and/or " \
585+
f"{GRADIENT_ACCUMULATION_STEPS}). These parameters *will not be used* since " \
586+
"elastic training is enabled, which takes control of these parameters. " \
587+
"If you want to supress this error (the parameters will be silently ignored) " \
588+
f"please set {IGNORE_NON_ELASTIC_BATCH_INFO}':true in your elasticity config.")
576589

577590
# micro_bsz * world_size * gas = total_batch_size
578591
# gas = total_batch_size // (micro_bsz * world_size)
@@ -581,13 +594,13 @@ def __init__(self, json_file, mpu=None, param_dict=None):
581594

582595
if TRAIN_BATCH_SIZE in self._param_dict:
583596
logger.warning("[Elasticity] overriding training_batch_size: " \
584-
f"{self._param_dict[TRAIN_BATCH_SIZE]} -> {final_batch_size}")
597+
f"{self._param_dict[TRAIN_BATCH_SIZE]} -> {final_batch_size}")
585598
if TRAIN_MICRO_BATCH_SIZE_PER_GPU in self._param_dict:
586599
logger.warning("[Elasticity] overriding train_micro_batch_size_per_gpu: " \
587-
f"{self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU]} -> {micro_batch_size}")
600+
f"{self._param_dict[TRAIN_MICRO_BATCH_SIZE_PER_GPU]} -> {micro_batch_size}")
588601
if GRADIENT_ACCUMULATION_STEPS in self._param_dict:
589-
logger.warning("[Elasticity] overriding gradient_accumulation_steps: "\
590-
f"{self._param_dict[GRADIENT_ACCUMULATION_STEPS]} -> {gradient_accu_steps}")
602+
logger.warning("[Elasticity] overriding gradient_accumulation_steps: " \
603+
f"{self._param_dict[GRADIENT_ACCUMULATION_STEPS]} -> {gradient_accu_steps}")
591604

592605
logger.info(f"[Elasticity] valid GPU counts: {valid_gpus}")
593606

@@ -622,6 +635,9 @@ def _initialize_params(self, param_dict):
622635

623636
self.gradient_clipping = get_gradient_clipping(param_dict)
624637
self.fp16_enabled = get_fp16_enabled(param_dict)
638+
self.fp16_type = get_fp16_type(param_dict)
639+
self.precision = PRECISION_TYPES[self.fp16_type]
640+
625641
self.amp_enabled = get_amp_enabled(param_dict)
626642
self.amp_params = get_amp_params(param_dict)
627643
self.loss_scale = get_loss_scale(param_dict)
@@ -630,7 +646,7 @@ def _initialize_params(self, param_dict):
630646

631647
self.optimizer_name = get_optimizer_name(param_dict)
632648
if self.optimizer_name is not None and \
633-
self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS:
649+
self.optimizer_name.lower() in DEEPSPEED_OPTIMIZERS:
634650
self.optimizer_name = self.optimizer_name.lower()
635651

636652
self.optimizer_params = get_optimizer_params(param_dict)
@@ -678,54 +694,54 @@ def _batch_assertion(self):
678694
f'Gradient accumulation steps: {grad_acc} has to be greater than 0'
679695

680696
assert train_batch == micro_batch * grad_acc * self.world_size, \
681-
(f'Check batch related parameters. train_batch_size is not equal'
682-
' to micro_batch_per_gpu * gradient_acc_step * world_size'
683-
f'{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}')
697+
(f'Check batch related parameters. train_batch_size is not equal'
698+
' to micro_batch_per_gpu * gradient_acc_step * world_size'
699+
f'{train_batch} != {micro_batch} * {grad_acc} * {self.world_size}')
684700

685701
def _set_batch_related_parameters(self):
686702

687703
train_batch = self.train_batch_size
688704
micro_batch = self.train_micro_batch_size_per_gpu
689705
grad_acc = self.gradient_accumulation_steps
690706

691-
#all values are provided nothing needs to be set
707+
# all values are provided nothing needs to be set
692708
if train_batch is not None and \
693-
micro_batch is not None and \
694-
grad_acc is not None:
709+
micro_batch is not None and \
710+
grad_acc is not None:
695711
return
696712

697-
#global_accumulation_steps needs to be set
713+
# global_accumulation_steps needs to be set
698714
elif train_batch is not None and \
699-
micro_batch is not None:
715+
micro_batch is not None:
700716
grad_acc = train_batch // micro_batch
701717
grad_acc //= self.world_size
702718
self.gradient_accumulation_steps = grad_acc
703719

704-
#micro_batch_per_gpu needs to be set
720+
# micro_batch_per_gpu needs to be set
705721
elif train_batch is not None and \
706-
grad_acc is not None:
722+
grad_acc is not None:
707723
micro_batch = train_batch // self.world_size
708724
micro_batch //= grad_acc
709725
self.train_micro_batch_size_per_gpu = micro_batch
710726

711-
#train_batch_size needs to be set
727+
# train_batch_size needs to be set
712728
elif micro_batch is not None and \
713-
grad_acc is not None:
729+
grad_acc is not None:
714730
train_batch_size = micro_batch * grad_acc
715731
train_batch_size *= self.world_size
716732
self.train_batch_size = train_batch_size
717733

718-
#gradient_accumulation_steps and micro_batch_per_gpus is set
734+
# gradient_accumulation_steps and micro_batch_per_gpus is set
719735
elif train_batch is not None:
720736
self.gradient_accumulation_steps = 1
721737
self.train_micro_batch_size_per_gpu = train_batch // self.world_size
722738

723-
#train_batch_size and gradient_accumulation_step is set
739+
# train_batch_size and gradient_accumulation_step is set
724740
elif micro_batch is not None:
725741
self.train_batch_size = micro_batch * self.world_size
726742
self.gradient_accumulation_steps = 1
727743

728-
#either none of the three parameters are provided or just gradient_accumulation_step is provided
744+
# either none of the three parameters are provided or just gradient_accumulation_step is provided
729745
else:
730746
assert False, \
731747
'Either train_batch_size or micro_batch_per_gpu needs to be provided'
@@ -755,17 +771,19 @@ def print(self, name):
755771
':'))))
756772

757773
def _do_error_check(self):
758-
assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(TRAIN_MICRO_BATCH_SIZE_PER_GPU)
774+
assert self.train_micro_batch_size_per_gpu, "DeepSpeedConfig: {} is not defined".format(
775+
TRAIN_MICRO_BATCH_SIZE_PER_GPU)
759776

760777
assert self.gradient_accumulation_steps, "DeepSpeedConfig: {} is not defined".format(
761778
GRADIENT_ACCUMULATION_STEPS)
762779

763780
if self.zero_enabled:
764781
assert self.fp16_enabled, "DeepSpeedConfig: ZeRO is only supported if fp16 is enabled"
765-
assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(MAX_STAGE_ZERO_OPTIMIZATION)
766-
#if self.zero_config.cpu_offload is True:
782+
assert self.zero_optimization_stage <= MAX_STAGE_ZERO_OPTIMIZATION, "DeepSpeedConfig: Maximum supported ZeRO stage is {}".format(
783+
MAX_STAGE_ZERO_OPTIMIZATION)
784+
# if self.zero_config.cpu_offload is True:
767785
# assert self.zero_optimization_stage == ZERO_OPTIMIZATION_GRADIENTS, "DeepSpeedConfig: cpu-offload supported ZeRO stage is {}".format(ZERO_OPTIMIZATION_GRADIENTS)
768-
#assert self.gradient_accumulation_steps == 1, "DeepSpeedConfig: {}is not supported for {}".format(GRADIENT_ACCUMULATION_STEPS, ZERO_OPTIMIZATION_CPU_OFFLOAD)
786+
# assert self.gradient_accumulation_steps == 1, "DeepSpeedConfig: {}is not supported for {}".format(GRADIENT_ACCUMULATION_STEPS, ZERO_OPTIMIZATION_CPU_OFFLOAD)
769787

770788
def _do_warning_check(self):
771789
fp16_enabled = self.fp16_enabled or self.zero_enabled
@@ -774,21 +792,21 @@ def _do_warning_check(self):
774792
if vocabulary_size and vocabulary_size % TENSOR_CORE_ALIGN_SIZE != 0:
775793
logger.warning(
776794
"DeepSpeedConfig: vocabulary size {} is not aligned to {}, may import tensor core utilization."
777-
.format(vocabulary_size,
778-
TENSOR_CORE_ALIGN_SIZE))
795+
.format(vocabulary_size,
796+
TENSOR_CORE_ALIGN_SIZE))
779797

780798
if self.optimizer_params is not None and \
781-
MAX_GRAD_NORM in self.optimizer_params.keys() and \
799+
MAX_GRAD_NORM in self.optimizer_params.keys() and \
782800
self.optimizer_params[MAX_GRAD_NORM] > 0:
783801
if fp16_enabled:
784802
if self.global_rank == 0:
785803
logger.warning(
786804
'DeepSpeedConfig: In FP16 mode, DeepSpeed will pass {}:{} to FP16 wrapper'
787-
.format(MAX_GRAD_NORM,
788-
self.optimizer_params[MAX_GRAD_NORM]))
805+
.format(MAX_GRAD_NORM,
806+
self.optimizer_params[MAX_GRAD_NORM]))
789807
else:
790808
if self.global_rank == 0:
791809
logger.warning(
792810
'DeepSpeedConfig: In FP32 mode, DeepSpeed does not permit MAX_GRAD_NORM ({}) > 0, setting to zero'
793-
.format(self.optimizer_params[MAX_GRAD_NORM]))
811+
.format(self.optimizer_params[MAX_GRAD_NORM]))
794812
self.optimizer_params[MAX_GRAD_NORM] = 0.0

deepspeed/runtime/constants.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Copyright (c) Microsoft Corporation
33
Licensed under the MIT license.
44
"""
5+
import torch
56

67
#############################################
78
# Routes
@@ -128,6 +129,18 @@
128129
FP16_ENABLED = "enabled"
129130
FP16_ENABLED_DEFAULT = False
130131

132+
# Key under the fp16 config section selecting the 16-bit dtype
# ("fp16" or "bfloat16"); defaults to classic fp16.
FP16_TYPE = "type"
FP16_TYPE_DEFAULT = "fp16"
# Accepted precision-name aliases mapped to their torch dtypes;
# used to resolve the configured fp16 "type" string.
PRECISION_TYPES = {
    "fp32": torch.float32,
    "float32": torch.float32,
    "float": torch.float32,
    "fp16": torch.half,
    "float16": torch.half,
    "half": torch.half,
    "bfloat16": torch.bfloat16
}
143+
131144
# FP16 loss scale, zero means using dynamic scaling
132145
FP16_LOSS_SCALE = "loss_scale"
133146
FP16_LOSS_SCALE_DEFAULT = 0
@@ -189,6 +202,7 @@
189202
'''
190203
FP32_ALLREDUCE = "fp32_allreduce"
191204
FP32_ALLREDUCE_DEFAULT = False
205+
FP32_ALLREDUCE_DEFAULT_BF16 = True # if dtype is bf16 - default to fp32 communication
192206

193207
#########################################
194208
# Scale/predivide gradients before allreduce

0 commit comments

Comments
 (0)