Skip to content

Commit b467d11

Browse files
authored
Update compressed_ar.py
1 parent 996951e commit b467d11

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

deepspeed/runtime/comm/compressed_ar.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,6 @@
66
from torch.utils.dlpack import to_dlpack
77
from torch.utils.dlpack import from_dlpack
88

9-
version = torch.__version__.split('.')
10-
TORCH_VERSION_MAJOR = int(version[0])
11-
TORCH_VERSION_MINOR = int(version[1])
12-
if TORCH_VERSION_MAJOR < 1 or (TORCH_VERSION_MAJOR >= 1 and TORCH_VERSION_MINOR < 9):
13-
compressed_all_reduce = compressed_all_reduce_cupy
14-
else:
15-
compressed_all_reduce = compressed_all_reduce_torch
16-
179
def torch2cupy(tensor):
1810
return cupy.fromDlpack(to_dlpack(tensor))
1911

@@ -52,3 +44,11 @@ def compressed_all_reduce_cupy(tensor, op=torch.distributed.ReduceOp.SUM, group=
5244
torch.distributed.all_reduce(m, op=op, group=group, async_op=async_op)
5345
torch.distributed.all_reduce(e, op=op, group=group, async_op=async_op)
5446
return reconstruct(m, e, original_dtype)
47+
48+
version = torch.__version__.split('.')
49+
TORCH_VERSION_MAJOR = int(version[0])
50+
TORCH_VERSION_MINOR = int(version[1])
51+
if TORCH_VERSION_MAJOR < 1 or (TORCH_VERSION_MAJOR >= 1 and TORCH_VERSION_MINOR < 9):
52+
compressed_all_reduce = compressed_all_reduce_cupy
53+
else:
54+
compressed_all_reduce = compressed_all_reduce_torch

0 commit comments

Comments
 (0)