@@ -54,6 +54,9 @@ def __init__(self, args):
 
         self.reduce_dialogue_repetition = int(os.environ.get("REDUCE_DIALOGUE_REPETITION", 0))
 
+        self.max_stop_seqs_num = int(os.getenv("MAX_STOP_SEQS_NUM", 5))
+        self.stop_seqs_max_len = int(os.getenv("STOP_SEQS_MAX_LEN", 8))
+
         self.nranks = dist.get_world_size()
         self.init_dist_env()
         self.rank = fleet.worker_index()
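These two environment variables cap how many stop sequences a single request may carry (default 5) and how many tokens each sequence may contain (default 8). A minimal sketch of validating a request against those caps before it reaches the engine; the helper name and signature are hypothetical, not part of this PR:

```python
import os

# Caps mirroring what __init__ reads above.
MAX_STOP_SEQS_NUM = int(os.getenv("MAX_STOP_SEQS_NUM", 5))
STOP_SEQS_MAX_LEN = int(os.getenv("STOP_SEQS_MAX_LEN", 8))

def validate_stop_seqs(stop_seqs):
    """Hypothetical helper: reject stop-sequence lists that exceed the caps."""
    if len(stop_seqs) > MAX_STOP_SEQS_NUM:
        raise ValueError(f"at most {MAX_STOP_SEQS_NUM} stop sequences are supported")
    for seq in stop_seqs:
        if len(seq) > STOP_SEQS_MAX_LEN:
            raise ValueError(f"each stop sequence is capped at {STOP_SEQS_MAX_LEN} tokens")
```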
@@ -248,6 +251,13 @@ def init_inputs(self):
         self.share_inputs['free_list_len'] = paddle.full(
             shape=[1], fill_value=self.free_list_len, dtype="int32")
 
+        self.share_inputs['stop_seqs_len'] = paddle.full(shape=[self.max_stop_seqs_num],
+                                                         fill_value=0,
+                                                         dtype="int32")
+        self.share_inputs['stop_seqs'] = paddle.full(shape=[self.max_stop_seqs_num, self.stop_seqs_max_len],
+                                                     fill_value=-1,
+                                                     dtype="int64")
+
         if self.reduce_dialogue_repetition:
             self.share_inputs["first_token_ids"] = paddle.full(
                 shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
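Both buffers are allocated with fixed shapes so they can be fed to the static graph alongside the other share_inputs: a length of 0 marks an unused slot and -1 marks padding token ids. A standalone sketch of the resulting layout, with numpy standing in for paddle purely for illustration:

```python
import numpy as np

# Fixed-shape stop-sequence buffers, mirroring the paddle.full calls above.
max_stop_seqs_num, stop_seqs_max_len = 5, 8
stop_seqs_len = np.zeros([max_stop_seqs_num], dtype="int32")                    # 0 => slot unused
stop_seqs = np.full([max_stop_seqs_num, stop_seqs_max_len], -1, dtype="int64")  # -1 => padding
print(stop_seqs.shape)  # (5, 8): room for 5 sequences of up to 8 tokens each
```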
@@ -300,6 +310,14 @@ def dy_input_preprocess(self, tasks):
                 self.share_inputs["block_tables"][idx:idx + 1, :encoder_block_num] = np.array(
                     task['block_tables'], dtype="int32")
 
+            if "stop_seqs_len" in task:
+                stop_seqs_num = len(task["stop_seqs_len"])
+                for i in range(stop_seqs_num, self.max_stop_seqs_num):
+                    task["stop_seqs_len"].append(0)
+                self.share_inputs['stop_seqs_len'][:] = np.array(
+                    task["stop_seqs_len"], dtype="int32")
+                self.share_inputs['stop_seqs'][:stop_seqs_num, :len(task['stop_seqs'][0])] = np.array(
+                    task["stop_seqs"], dtype="int64")
     def step_cuda(self, seq_lens_this_time):
         """
         step cuda
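Per request, the length list is zero-padded out to the fixed buffer size and the token matrix is copied into the top-left corner of `stop_seqs`; note the code assumes every row of `task['stop_seqs']` has already been padded to the same width. A sketch of that fill step with made-up token ids (numpy again stands in for the paddle tensors):

```python
import numpy as np

max_stop_seqs_num, stop_seqs_max_len = 5, 8
stop_seqs_len = np.zeros([max_stop_seqs_num], dtype="int32")
stop_seqs = np.full([max_stop_seqs_num, stop_seqs_max_len], -1, dtype="int64")

# Illustrative request: two stop sequences, rows pre-padded to equal width with -1.
task = {"stop_seqs": [[13, 198], [50256, -1]], "stop_seqs_len": [2, 1]}

stop_seqs_num = len(task["stop_seqs_len"])
task["stop_seqs_len"] += [0] * (max_stop_seqs_num - stop_seqs_num)  # pad the lengths
stop_seqs_len[:] = np.array(task["stop_seqs_len"], dtype="int32")
stop_seqs[:stop_seqs_num, :len(task["stop_seqs"][0])] = np.array(
    task["stop_seqs"], dtype="int64")
# Slots 2..4 keep length 0 and token id -1, so the kernel skips them.
```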
@@ -486,6 +504,11 @@ def _init_predictor(self):
         config.switch_ir_optim(False)
         config.enable_use_gpu(100, device_id)
 
+        pir_flag = int(os.environ.get("FLAGS_enable_pir_api", 0))
+        if pir_flag == 1:
+            config.enable_new_executor()
+            config.enable_new_ir()
+
         # distributed config
         if self.mp_degree > 1:
             trainer_endpoints = fleet.worker_endpoints()
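The PIR execution path is opt-in: it only activates when the process sees `FLAGS_enable_pir_api=1` in its environment before the predictor config is built. A minimal standalone sketch, assuming the same flag convention; the model and params paths are placeholders:

```python
import os
import paddle.inference as paddle_infer

# Placeholder paths; in the predictor these come from the model directory.
config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
if int(os.environ.get("FLAGS_enable_pir_api", 0)) == 1:
    config.enable_new_executor()  # switch to the new executor
    config.enable_new_ir()        # run inference on the PIR representation
```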