@@ -52,6 +52,11 @@ def __init__(self, args):
5252 self .args .num_attention_heads = self .get_value (self .model_cfg , ["num_attention_heads" , "n_head" ])
5353 self .args .hidden_size = self .model_cfg ["hidden_size" ]
5454
55+ self .reduce_dialogue_repetition = int (os .environ .get ("REDUCE_DIALOGUE_REPETITION" , 0 ))
56+
57+ self .max_stop_seqs_num = int (os .getenv ("MAX_STOP_SEQS_NUM" , 5 ))
58+ self .stop_seqs_max_len = int (os .getenv ("STOP_SEQS_MAX_LEN" , 8 ))
59+
5560 self .nranks = dist .get_world_size ()
5661 self .init_dist_env ()
5762 self .rank = fleet .worker_index ()
@@ -246,6 +251,19 @@ def init_inputs(self):
246251 self .share_inputs ['free_list_len' ] = paddle .full (
247252 shape = [1 ], fill_value = self .free_list_len , dtype = "int32" )
248253
254+ self .share_inputs ['stop_seqs_len' ] = paddle .full (shape = [self .max_stop_seqs_num ,],
255+ fill_value = 0 ,
256+ dtype = "int32" )
257+ self .share_inputs ['stop_seqs' ] = paddle .full (shape = [self .max_stop_seqs_num , self .stop_seqs_max_len ],
258+ fill_value = - 1 ,
259+ dtype = "int64" )
260+
261+ if self .reduce_dialogue_repetition :
262+ self .share_inputs ["first_token_ids" ] = paddle .full (
263+ shape = [self .args .max_batch_size , 1 ], fill_value = - 1 , dtype = "int64" )
264+ self .share_inputs ["ori_seq_lens_encoder" ] = paddle .full (
265+ shape = [self .args .max_batch_size , 1 ], fill_value = 0 , dtype = "int32" )
266+
249267 def dy_input_preprocess (self , tasks ):
250268 """
251269 dynamic insertion
@@ -279,6 +297,10 @@ def dy_input_preprocess(self, tasks):
279297 self .share_inputs ['max_length' ][idx :idx + 1 ] = max_dec_len
280298 self .share_inputs ['stop_flags' ][idx :idx + 1 ] = False
281299
300+ if self .reduce_dialogue_repetition :
301+ self .share_inputs ['first_token_ids' ][idx :idx + 1 ] = self .share_inputs ['input_ids' ][idx :idx + 1 , :1 ]
302+ self .share_inputs ["ori_seq_lens_encoder" ][idx :idx + 1 ] = length
303+
282304 if "infer_seed" in task :
283305 self .share_inputs ['infer_seed' ][idx :idx + 1 ] = task ['infer_seed' ]
284306
@@ -288,6 +310,14 @@ def dy_input_preprocess(self, tasks):
288310 self .share_inputs ["block_tables" ][idx :idx + 1 , :encoder_block_num ] = np .array (
289311 task ['block_tables' ], dtype = "int32" )
290312
313+ if "stop_seqs_len" in task :
314+ stop_seqs_num = len (task ["stop_seqs_len" ])
315+ for i in range (stop_seqs_num , self .max_stop_seqs_num ):
316+ task ["stop_seqs_len" ].append (0 )
317+ self .share_inputs ['stop_seqs_len' ][:] = np .array (
318+ task ["stop_seqs_len" ], dtype = "int32" )
319+ self .share_inputs ['stop_seqs' ][:stop_seqs_num , :len (task ['stop_seqs' ][0 ])] = np .array (
320+ task ["stop_seqs" ], dtype = "int64" )
291321 def step_cuda (self , seq_lens_this_time ):
292322 """
293323 step cuda
@@ -474,6 +504,11 @@ def _init_predictor(self):
474504 config .switch_ir_optim (False )
475505 config .enable_use_gpu (100 , device_id )
476506
507+ pir_flag = int (os .environ .get ("FLAGS_enable_pir_api" , 0 ))
508+ if pir_flag == 1 :
509+ config .enable_new_executor ()
510+ config .enable_new_ir ()
511+
477512 # distributed config
478513 if self .mp_degree > 1 :
479514 trainer_endpoints = fleet .worker_endpoints ()