We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent b467d11 commit 9a384c9Copy full SHA for 9a384c9
1 file changed
deepspeed/runtime/pipe/module.py
@@ -575,5 +575,5 @@ def _is_checkpointable(self, funcs):
575
if self.__class__.__name__ == 'GPT2ModelPipe':
576
return all('ParallelTransformerLayerPipe' in f.__class__.__name__
577
for f in funcs)
578
- return all('TransformerBlock' in f.__class__.__name__
579
- for f in funcs)
+ params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
+ return any(len(list(p)) > 0 for p in params)
0 commit comments