Skip to content

Commit 3b49e7f

Browse files
authored
make attention mask + pipe parallel more flexible
1 parent 9a384c9 commit 3b49e7f

1 file changed

Lines changed: 11 additions & 5 deletions

File tree

deepspeed/runtime/pipe/engine.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,9 @@ def __init__(self, *super_args, **super_kwargs):
183183

184184
if self.is_last_stage():
185185
self.loss_model = self.module.loss_fn
186-
186+
187+
self.has_attention_mask = self.module.__class__.__name__ == 'GPT2ModelPipe'
188+
187189
# Initialize pipeline communicators. Just send a 0.
188190
if is_even(self.stage_id):
189191
if not self.is_last_stage():
@@ -196,6 +198,10 @@ def __init__(self, *super_args, **super_kwargs):
196198
if not self.is_last_stage():
197199
p2p.send(self.loss, self.next_stage)
198200

201+
def set_has_attention_mask(self, value):
    """Explicitly enable/disable the attention-mask send/recv handling.

    The pipeline engine normally enables this only for ``GPT2ModelPipe``
    modules; this setter lets other pipeline models opt in so their last
    tensor is treated as a boolean attention mask (cast to half for NCCL
    sends and restored to bool on receive).

    Args:
        value (bool): whether the module's outputs carry an attention mask.

    Raises:
        AssertionError: if ``value`` is not a ``bool``.
    """
    # BUG FIX: the original used `isinstance(value, boolean)` — `boolean`
    # is not a Python builtin, so the check raised NameError instead of
    # validating. The correct type name is `bool`.
    assert isinstance(value, bool)
    self.has_attention_mask = value
204+
199205
def _build_data_iter(self, dataset):
200206
sampler = torch.utils.data.distributed.DistributedSampler(
201207
dataset,
@@ -919,7 +925,7 @@ def _exec_send_activations(self, buffer_id):
919925
# NCCL does not like to send torch.BoolTensor types, so cast the mask to half().
920926
# We could do char, but with half() we can eventually flatten with other fp16
921927
# messages (TODO)
922-
if self.module.__class__.__name__ == 'GPT2ModelPipe':
928+
if self.has_attention_mask:
923929
outputs = list(outputs)
924930
outputs[-1] = outputs[-1].half()
925931
outputs = tuple(outputs)
@@ -938,7 +944,7 @@ def _exec_send_activations(self, buffer_id):
938944
f'{type(outputs)}')
939945

940946
# Restore the boolean tensor
941-
if self.module.__class__.__name__ == 'GPT2ModelPipe':
947+
if self.has_attention_mask:
942948
outputs = list(outputs)
943949
outputs[-1] = outputs[-1].bool()
944950
outputs = tuple(outputs)
@@ -968,7 +974,7 @@ def _exec_send_grads(self, buffer_id):
968974
# a grad that needs to be communicated. We free the buffer immediately
969975
# after, so no need to restore it. The receiver also has a hack that skips
970976
# the recv. This is because NCCL does not let us send torch.BoolTensor :-(.
971-
if self.module.__class__.__name__ == 'GPT2ModelPipe':
977+
if self.has_attention_mask:
972978
inputs = list(inputs)
973979
inputs.pop()
974980
inputs = tuple(inputs)
@@ -1030,7 +1036,7 @@ def _exec_recv_activations(self, buffer_id):
10301036

10311037
# NCCL does not like to send torch.BoolTensor types, so un-cast the
10321038
# attention mask
1033-
if self.module.__class__.__name__ == 'GPT2ModelPipe':
1039+
if self.has_attention_mask:
10341040
recvd[-1] = recvd[-1].bool()
10351041

10361042
recvd = tuple(recvd)

0 commit comments

Comments
 (0)