Skip to content

Commit 9a384c9

Browse files
authored
update 'is_checkpointable'
1 parent b467d11 commit 9a384c9

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

deepspeed/runtime/pipe/module.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -575,5 +575,5 @@ def _is_checkpointable(self, funcs):
575575
if self.__class__.__name__ == 'GPT2ModelPipe':
576576
return all('ParallelTransformerLayerPipe' in f.__class__.__name__
577577
for f in funcs)
578-
return all('TransformerBlock' in f.__class__.__name__
579-
for f in funcs)
578+
params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
579+
return any(len(list(p)) > 0 for p in params)

0 commit comments

Comments
 (0)