
Commit 2d6a0d3

make checkpointable layers a list

1 parent 2d78b37 commit 2d6a0d3

1 file changed: deepspeed/runtime/pipe/module.py (14 additions & 19 deletions)
@@ -42,6 +42,7 @@ class LayerSpec:
             LayerSpec(torch.nn.Linear, self.hidden_hidden, self.out_dim)]
         ]
     """
+
     def __init__(self, typename, *module_args, **module_kwargs):
         self.typename = typename
         self.module_args = module_args
@@ -187,10 +188,10 @@ def forward(self, inputs):
         self.tied_weight_attrs = {}
 
         # Offset the random seed by the stage ID.
-        #newseed = torch.cuda.initial_seed() + self._grid.get_stage_id()
-        #ds_utils.set_random_seed(newseed)
+        # newseed = torch.cuda.initial_seed() + self._grid.get_stage_id()
+        # ds_utils.set_random_seed(newseed)
 
-        #with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
+        # with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
         self._build()
         self.to(f'cuda:{self.local_rank}')
 

@@ -199,7 +200,9 @@ def forward(self, inputs):
 
         self.activation_checkpoint_interval = activation_checkpoint_interval
         self.activation_checkpoint_func = activation_checkpoint_func
-        self.set_checkpointable_layers(checkpointable_layers)
+        if checkpointable_layers is not None:
+            assert isinstance(checkpointable_layers, list)
+        self.checkpointable_layers = checkpointable_layers
 
     def _build(self):
         specs = self._layer_specs
@@ -341,7 +344,7 @@ def exec_func(*inputs):
                 # Since we either pass tensors or tuples of tensors without unpacking, we
                 # need to be careful not to double-wrap tensors with tuple.
                 if not isinstance(x, tuple):
-                    x = (x, )
+                    x = (x,)
 
                 if self._is_checkpointable(funcs):
                     x = self.activation_checkpoint_func(
@@ -400,7 +403,7 @@ def _partition_layers(self, method='uniform'):
                    name = layer.__name__
                 except AttributeError:
                     pass
-                print(f'    {idx+start:2d}: {name}')
+                print(f'    {idx + start:2d}: {name}')
             if self.loss_fn:
                 try:
                     print(f'  loss: {self.loss_fn.__name__}')
@@ -564,28 +567,20 @@ def load_state_dir(self, load_dir, strict=True):
             model_ckpt_path = self.ckpt_layer_path(load_dir, idx)
             layer.load_state_dict(torch.load(model_ckpt_path,
                                              map_location=lambda storage,
-                                             loc: storage),
+                                                       loc: storage),
                                   strict=strict)
             if self._grid.data_parallel_id == 0:
                 logger.info(
-                    f'RANK={self.global_rank} Loaded layer={idx+layer_offset} file={model_ckpt_path}'
+                    f'RANK={self.global_rank} Loaded layer={idx + layer_offset} file={model_ckpt_path}'
                 )
 
         self._synchronize_tied_weights()
 
-    def set_checkpointable_layers(self, string):
-        """
-        Allows you to pass a string which defines which layers are checkpointable
-        """
-        self.checkpointable_layers = string
-
     def _is_checkpointable(self, funcs):
-        if self.__class__.__name__ == 'GPT2ModelPipe':
+        if self.checkpointable_layers is not None:
+            return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)
+        elif self.__class__.__name__ == 'GPT2ModelPipe':
             return all('ParallelTransformerLayerPipe' in f.__class__.__name__
                        for f in funcs)
-        elif self.checkpointable_layers is not None:
-            ret = all(self.checkpointable_layers in f.__class__.__name__
-                      for f in funcs)
-            return ret
         params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
         return any(len(list(p)) > 0 for p in params)
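
Net effect of this commit: `checkpointable_layers` is now an explicit list of layer class names rather than a single substring matched against each class name, and `_is_checkpointable` approves a run of layers only when every layer's class name appears in that list. A minimal usage sketch follows, assuming a DeepSpeed distributed environment has already been initialized; the layer sizes and stage count are illustrative, not from this commit:

    # Sketch only: PipelineModule needs torch.distributed initialized first
    # (e.g. via deepspeed.init_distributed()) across the participating ranks.
    import torch
    from deepspeed.pipe import PipelineModule, LayerSpec

    specs = [LayerSpec(torch.nn.Linear, 1024, 1024) for _ in range(8)]
    model = PipelineModule(
        layers=specs,
        num_stages=2,
        activation_checkpoint_interval=1,
        # After this commit: a list of exact class names. Passing a bare
        # string now fails the isinstance(..., list) assert in __init__.
        checkpointable_layers=['Linear'],
    )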
