|
26 | 26 | from deepspeed.runtime.config import DeepSpeedConfig, DEEPSPEED_OPTIMIZERS, \ |
27 | 27 | ADAM_OPTIMIZER, ADAMW_OPTIMIZER, LAMB_OPTIMIZER, ONEBIT_ADAM_OPTIMIZER, ONEBIT_LAMB_OPTIMIZER, \ |
28 | 28 | TORCH_ADAM_PARAM, ADAM_W_MODE, ADAM_W_MODE_DEFAULT |
29 | | -from deepspeed.runtime.comm import compressed_all_reduce |
30 | 29 |
|
31 | 30 | from deepspeed.runtime.dataloader import DeepSpeedDataLoader |
32 | 31 | from deepspeed.runtime.constants import \ |
@@ -139,7 +138,6 @@ def __init__(self, |
139 | 138 | self.store_gradients = False |
140 | 139 | self.store_gradients_cpu = False |
141 | 140 | self.stored_gradients = None |
142 | | - self.bf16_compressed_allreduce = False # hardcode for now - it's not really working |
143 | 141 |
|
144 | 142 | if dist_init_required is None: |
145 | 143 | dist_init_required = not dist.is_initialized() |
@@ -1292,29 +1290,23 @@ def allreduce_bucket(self, bucket): |
1292 | 1290 |
|
1293 | 1291 | tensor_to_allreduce = tensor |
1294 | 1292 |
|
1295 | | - if self.allreduce_always_fp32() and not self.bf16_compressed_allreduce: |
| 1293 | + if self.allreduce_always_fp32(): |
1296 | 1294 | tensor_to_allreduce = tensor.float() |
1297 | 1295 |
|
1298 | 1296 | if self.postscale_gradients(): |
1299 | 1297 | if self.gradient_predivide_factor() != 1.0: |
1300 | 1298 | tensor_to_allreduce.mul_(1. / self.gradient_predivide_factor()) |
1301 | | - if self.bf16_compressed_allreduce and self.precision() == torch.bfloat16: |
1302 | | - compressed_all_reduce(tensor_to_allreduce, group=self.data_parallel_group) |
1303 | | - else: |
1304 | | - dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group) |
| 1299 | + dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group) |
1305 | 1300 |
|
1306 | 1301 | if self.gradient_average: |
1307 | 1302 | if self.gradient_predivide_factor() != self.dp_world_size: |
1308 | 1303 | tensor_to_allreduce.mul_(self.gradient_predivide_factor() / |
1309 | 1304 | self.dp_world_size) |
1310 | 1305 | else: |
1311 | 1306 | tensor_to_allreduce.div_(self.dp_world_size) |
1312 | | - if self.bf16_compressed_allreduce and self.precision() == torch.bfloat16: |
1313 | | - compressed_all_reduce(tensor_to_allreduce, group=self.data_parallel_group) |
1314 | | - else: |
1315 | | - dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group) |
| 1307 | + dist.all_reduce(tensor_to_allreduce, group=self.data_parallel_group) |
1316 | 1308 |
|
1317 | | - if self.allreduce_always_fp32() and tensor is not tensor_to_allreduce and not self.bf16_compressed_allreduce: |
| 1309 | + if self.allreduce_always_fp32() and tensor is not tensor_to_allreduce: |
1318 | 1310 | tensor.copy_(tensor_to_allreduce) |
1319 | 1311 |
|
1320 | 1312 | return tensor |
|
0 commit comments