@@ -42,6 +42,7 @@ class LayerSpec:
             LayerSpec(torch.nn.Linear, self.hidden_hidden, self.out_dim)]
         ]
     """
+
     def __init__(self, typename, *module_args, **module_kwargs):
         self.typename = typename
         self.module_args = module_args
@@ -187,10 +188,10 @@ def forward(self, inputs):
         self.tied_weight_attrs = {}
 
         # Offset the random seed by the stage ID.
-        #newseed = torch.cuda.initial_seed() + self._grid.get_stage_id()
-        #ds_utils.set_random_seed(newseed)
+        # newseed = torch.cuda.initial_seed() + self._grid.get_stage_id()
+        # ds_utils.set_random_seed(newseed)
 
-        #with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
+        # with torch.random.fork_rng(devices=[torch.cuda.current_device()]):
         self._build()
         self.to(f'cuda:{self.local_rank}')
 
@@ -199,7 +200,9 @@ def forward(self, inputs):
 
         self.activation_checkpoint_interval = activation_checkpoint_interval
         self.activation_checkpoint_func = activation_checkpoint_func
-        self.set_checkpointable_layers(checkpointable_layers)
+        if checkpointable_layers is not None:
+            assert isinstance(checkpointable_layers, list)
+        self.checkpointable_layers = checkpointable_layers
 
     def _build(self):
         specs = self._layer_specs
@@ -341,7 +344,7 @@ def exec_func(*inputs):
                 # Since we either pass tensors or tuples of tensors without unpacking, we
                 # need to be careful not to double-wrap tensors with tuple.
                 if not isinstance(x, tuple):
-                    x = (x, )
+                    x = (x,)
 
                 if self._is_checkpointable(funcs):
                     x = self.activation_checkpoint_func(
@@ -400,7 +403,7 @@ def _partition_layers(self, method='uniform'):
                             name = layer.__name__
                         except AttributeError:
                             pass
-                print(f'    {idx + start:2d}: {name}')
+                    print(f'    {idx + start:2d}: {name}')
             if self.loss_fn:
                 try:
                     print(f'  loss: {self.loss_fn.__name__}')
@@ -564,28 +567,20 @@ def load_state_dir(self, load_dir, strict=True):
             model_ckpt_path = self.ckpt_layer_path(load_dir, idx)
             layer.load_state_dict(torch.load(model_ckpt_path,
                                              map_location=lambda storage,
-                                         loc: storage),
+                                             loc: storage),
                                   strict=strict)
             if self._grid.data_parallel_id == 0:
                 logger.info(
-                f'RANK={self.global_rank} Loaded layer={idx + layer_offset} file={model_ckpt_path}'
+                    f'RANK={self.global_rank} Loaded layer={idx + layer_offset} file={model_ckpt_path}'
                 )
 
         self._synchronize_tied_weights()
 
-    def set_checkpointable_layers(self, string):
-        """
-        Allows you to pass a string which defines which layers are checkpointable
-        """
-        self.checkpointable_layers = string
-
     def _is_checkpointable(self, funcs):
-        if self.__class__.__name__ == 'GPT2ModelPipe':
+        if self.checkpointable_layers is not None:
+            return all(f.__class__.__name__ in self.checkpointable_layers for f in funcs)
+        elif self.__class__.__name__ == 'GPT2ModelPipe':
             return all('ParallelTransformerLayerPipe' in f.__class__.__name__
                        for f in funcs)
-        elif self.checkpointable_layers is not None:
-            ret = all(self.checkpointable_layers in f.__class__.__name__
-                      for f in funcs)
-            return ret
         params = [f.parameters() for f in funcs if isinstance(f, torch.nn.Module)]
         return any(len(list(p)) > 0 for p in params)
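Note on the _is_checkpointable change above: checkpointable_layers is now treated as a list of layer class names checked by exact membership, rather than a single substring matched against each class name. Below is a minimal, self-contained sketch of the new check; the ParallelTransformerLayerPipe stand-in class and the sample funcs list are illustrative assumptions, not code from this file.

    import torch

    # Hypothetical stand-in for the real pipeline layer class (illustrative only).
    class ParallelTransformerLayerPipe(torch.nn.Module):
        pass

    # New semantics: a list of class names, not a substring.
    checkpointable_layers = ['ParallelTransformerLayerPipe']

    # A slice of layers as _is_checkpointable would receive them.
    funcs = [ParallelTransformerLayerPipe(), torch.nn.LayerNorm(8)]

    # Every layer's class name must appear in the list for the range to be checkpointed.
    print(all(f.__class__.__name__ in checkpointable_layers for f in funcs))   # False: LayerNorm is not listed
    print(all(f.__class__.__name__ in checkpointable_layers
              for f in [ParallelTransformerLayerPipe()]))                      # True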