@@ -183,7 +183,9 @@ def __init__(self, *super_args, **super_kwargs):
 
         if self.is_last_stage():
             self.loss_model = self.module.loss_fn
-
+
+        self.has_attention_mask = self.module.__class__.__name__ == 'GPT2ModelPipe'
+
         # Initialize pipeline communicators. Just send a 0.
         if is_even(self.stage_id):
             if not self.is_last_stage():
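The context lines here also show the deadlock-avoiding handshake that the "just send a 0" comment describes: even stages send before they receive, odd stages receive before they send, so every neighbouring pair has one sender and one receiver at any moment. A minimal sketch of that ordering, assuming blocking `p2p.send(tensor, stage)`/`p2p.recv(tensor, stage)` helpers as used in the diff (the function name and parameters below are illustrative, not part of this change):

    import torch
    from deepspeed.runtime.pipe import p2p  # assumed import path

    def init_pipe_comm(stage_id, num_stages, prev_stage, next_stage):
        dummy = torch.zeros(1)                 # "just send a 0"
        if stage_id % 2 == 0:
            if stage_id != num_stages - 1:
                p2p.send(dummy, next_stage)    # even stage: send first...
            if stage_id != 0:
                p2p.recv(dummy, prev_stage)    # ...then receive
        else:
            if stage_id != 0:
                p2p.recv(dummy, prev_stage)    # odd stage: receive first...
            if stage_id != num_stages - 1:
                p2p.send(dummy, next_stage)    # ...then send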
@@ -196,6 +198,10 @@ def __init__(self, *super_args, **super_kwargs):
             if not self.is_last_stage():
                 p2p.send(self.loss, self.next_stage)
 
+    def set_has_attention_mask(self, value):
+        assert isinstance(value, bool)
+        self.has_attention_mask = value
+
     def _build_data_iter(self, dataset):
         sampler = torch.utils.data.distributed.DistributedSampler(
             dataset,
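With the new setter, a pipeline module that is not named `GPT2ModelPipe` but still passes a boolean attention mask as the last element of its activation tuple can opt in explicitly instead of relying on the class-name check in `__init__`. A hypothetical usage sketch; `MyMaskedPipe`, `ds_config`, and the exact `deepspeed.initialize` arguments are placeholders, not part of this change:

    import deepspeed

    model = MyMaskedPipe(num_layers=24)    # hypothetical PipelineModule subclass
    engine, _, _, _ = deepspeed.initialize(model=model,
                                           config=ds_config,
                                           model_parameters=model.parameters())
    engine.set_has_attention_mask(True)    # enable the mask cast/restore path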
@@ -919,7 +925,7 @@ def _exec_send_activations(self, buffer_id):
         # NCCL does not like to send torch.BoolTensor types, so cast the mask to half().
         # We could do char, but with half() we can eventually flatten with other fp16
         # messages (TODO)
-        if self.module.__class__.__name__ == 'GPT2ModelPipe':
+        if self.has_attention_mask:
             outputs = list(outputs)
             outputs[-1] = outputs[-1].half()
             outputs = tuple(outputs)
@@ -938,7 +944,7 @@ def _exec_send_activations(self, buffer_id):
                                       f'{type(outputs)}')
 
         # Restore the boolean tensor
-        if self.module.__class__.__name__ == 'GPT2ModelPipe':
+        if self.has_attention_mask:
             outputs = list(outputs)
             outputs[-1] = outputs[-1].bool()
             outputs = tuple(outputs)
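The cast-and-restore pair above works because a boolean mask only holds 0/1 values, and those round-trip exactly through fp16. A standalone sketch of the trick with plain torch, no pipeline engine needed:

    import torch

    mask = torch.tensor([True, False, True])   # what the model produces
    wire = mask.half()                         # what NCCL actually sends
    restored = wire.bool()                     # what the receiver works with
    assert torch.equal(mask, restored)         # 0.0 and 1.0 are exact in fp16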
@@ -968,7 +974,7 @@ def _exec_send_grads(self, buffer_id):
         # a grad that needs to be communicated. We free the buffer immediately
         # after, so no need to restore it. The receiver also has a hack that skips
         # the recv. This is because NCCL does not let us send torch.BoolTensor :-(.
-        if self.module.__class__.__name__ == 'GPT2ModelPipe':
+        if self.has_attention_mask:
             inputs = list(inputs)
             inputs.pop()
             inputs = tuple(inputs)
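Dropping the mask here is safe because a boolean tensor can never require grad, so there is nothing for the previous stage to accumulate. A minimal sketch of the pop, with an illustrative activation tuple:

    import torch

    acts = torch.randn(4, 8, requires_grad=True)
    mask = torch.zeros(4, 8, dtype=torch.bool)  # bool tensors cannot require
                                                # grad; only float tensors can
    inputs = (acts, mask)

    send_list = list(inputs)
    send_list.pop()              # strip the mask before communicating grads
    inputs = tuple(send_list)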
@@ -1030,7 +1036,7 @@ def _exec_recv_activations(self, buffer_id):
 
         # NCCL does not like to send torch.BoolTensor types, so un-cast the
         # attention mask
-        if self.module.__class__.__name__ == 'GPT2ModelPipe':
+        if self.has_attention_mask:
             recvd[-1] = recvd[-1].bool()
 
         recvd = tuple(recvd)