Skip to content

Commit b41340f

Browse files
committed
Fix for pytorch 2.0 compatibility
1 parent 810d4a5 commit b41340f

3 files changed

Lines changed: 7 additions & 5 deletions

File tree

deepspeed/runtime/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,13 @@
1515

1616
import torch
1717
import torch.distributed as dist
18-
from torch._six import inf
1918
import torch.distributed as dist
2019

20+
try:
21+
from torch._six import inf as inf
22+
except ModuleNotFoundError:
23+
from torch import inf as inf
24+
2125
from deepspeed.utils import logger
2226
from numpy import prod
2327

deepspeed/runtime/zero/stage2.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,12 @@
66
from torch.distributed.distributed_c10d import _get_global_rank
77
import torch.distributed as dist
88
import math
9-
from torch._six import inf
109
from torch.autograd import Variable
1110

1211
import collections
1312

1413
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
15-
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
14+
from deepspeed.runtime.utils import inf, see_memory_usage, is_model_parallel_parameter
1615
from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
1716
from deepspeed.ops.adam import DeepSpeedCPUAdam
1817
from deepspeed.ops.op_builder import UtilsBuilder

deepspeed/runtime/zero/stage3.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111
from torch.distributed.distributed_c10d import _get_global_rank
1212
import torch.distributed as dist
1313
import math
14-
from torch._six import inf
1514
from torch.autograd import Variable
1615

1716
from deepspeed.utils.logging import logger
1817
from deepspeed.runtime.fp16.loss_scaler import LossScaler, DynamicLossScaler
19-
from deepspeed.runtime.utils import see_memory_usage, is_model_parallel_parameter
18+
from deepspeed.runtime.utils import inf, see_memory_usage, is_model_parallel_parameter
2019
from deepspeed.runtime.zero.partition_parameters import *
2120
from deepspeed.runtime.zero.partition_parameters import _init_external_params
2221
from deepspeed.runtime.zero.constants import ZERO_OPTIMIZATION_WEIGHTS

0 commit comments

Comments
 (0)