@@ -52,6 +52,8 @@ def __init__(self, args):
5252 self .args .num_attention_heads = self .get_value (self .model_cfg , ["num_attention_heads" , "n_head" ])
5353 self .args .hidden_size = self .model_cfg ["hidden_size" ]
5454
55+ self .reduce_dialogue_repetition = int (os .environ .get ("REDUCE_DIALOGUE_REPETITION" , 0 ))
56+
5557 self .nranks = dist .get_world_size ()
5658 self .init_dist_env ()
5759 self .rank = fleet .worker_index ()
@@ -246,6 +248,12 @@ def init_inputs(self):
246248 self .share_inputs ['free_list_len' ] = paddle .full (
247249 shape = [1 ], fill_value = self .free_list_len , dtype = "int32" )
248250
251+ if self .reduce_dialogue_repetition :
252+ self .share_inputs ["first_token_ids" ] = paddle .full (
253+ shape = [self .args .max_batch_size , 1 ], fill_value = - 1 , dtype = "int64" )
254+ self .share_inputs ["ori_seq_lens_encoder" ] = paddle .full (
255+ shape = [self .args .max_batch_size , 1 ], fill_value = 0 , dtype = "int32" )
256+
249257 def dy_input_preprocess (self , tasks ):
250258 """
251259 dynamic insertion
@@ -279,6 +287,10 @@ def dy_input_preprocess(self, tasks):
279287 self .share_inputs ['max_length' ][idx :idx + 1 ] = max_dec_len
280288 self .share_inputs ['stop_flags' ][idx :idx + 1 ] = False
281289
290+ if self .reduce_dialogue_repetition :
291+ self .share_inputs ['first_token_ids' ][idx :idx + 1 ] = self .share_inputs ['input_ids' ][idx :idx + 1 , :1 ]
292+ self .share_inputs ["ori_seq_lens_encoder" ][idx :idx + 1 ] = length
293+
282294 if "infer_seed" in task :
283295 self .share_inputs ['infer_seed' ][idx :idx + 1 ] = task ['infer_seed' ]
284296
0 commit comments