@@ -338,13 +338,37 @@ def get_pad_id(self):
338338 return self .tokenizer .eos_token
339339 return self .tokenizer .pad_token_id
340340
def pad_batch_data(self, insts, pad_id=0, return_seq_len=False, return_array=True, pad_style="right"):
    """Pad every instance to the length of the longest instance in the batch.

    Args:
        insts: Sized collection of sequences (each supporting ``len`` and
            ``list(...)``) holding the values to pad.
        pad_id: Value used to fill the padded positions.
        return_seq_len: When true, also return the unpadded length of
            every instance.
        return_array: When true, results are ``np.int64`` arrays;
            otherwise plain Python lists are returned.
        pad_style: ``"left"`` puts the padding before the data; any other
            value puts the padding after the data.

    Returns:
        The padded batch, or a ``(padded_batch, seq_len)`` tuple when
        ``return_seq_len`` is true. As arrays, the batch has shape
        ``(len(insts), max_len)`` and ``seq_len`` has shape
        ``(len(insts), 1)``. For an empty ``insts`` the batch is
        ``[[]]`` (array shape ``(1, 0)``) and ``seq_len`` is empty.
    """
    # Kept as an explicit length check (not truthiness) so array-like
    # inputs with more than one element do not raise on bool().
    if len(insts) == 0:
        padded_insts = np.array([[]], dtype=np.int64) if return_array else [[]]
        if return_seq_len:
            seq_len = np.array([], dtype=np.int64) if return_array else []
            return padded_insts, seq_len
        return padded_insts

    # Measure each instance once; reused for padding and for seq_len.
    lengths = [len(inst) for inst in insts]
    max_len = max(lengths)

    if pad_style == "left":
        padded_insts = [
            [pad_id] * (max_len - length) + list(inst)
            for inst, length in zip(insts, lengths)
        ]
    else:
        padded_insts = [
            list(inst) + [pad_id] * (max_len - length)
            for inst, length in zip(insts, lengths)
        ]

    if return_array:
        # Use the explicit row count rather than -1: numpy cannot infer -1
        # for a zero-size array, so a batch made only of empty instances
        # (max_len == 0) would otherwise raise ValueError.
        padded_insts = np.array(padded_insts, dtype=np.int64).reshape([len(insts), max_len])

    if not return_seq_len:
        return padded_insts

    if return_array:
        seq_len = np.array(lengths, dtype=np.int64).reshape(-1, 1)
    else:
        seq_len = lengths
    return padded_insts, seq_len
364+
341365 def update_stop_seq (self , request ):
342366 """
343367 Update stop sequences from request.
344368 """
345- stop_seqs = [[ 2 ], [ 100273 ] ]
369+ stop_seqs = []
346370 for seq in request .get ("stop_sequences" , []):
347- if seq != self ._get_eos_token_id () :
371+ if seq != self .tokenizer . eos_token_id :
348372 stop_seqs .append (self .tokenizer .convert_tokens_to_ids (self .tokenizer .tokenize (seq )))
349373 request ["stop_seqs" ], request ["stop_seqs_len" ] = self .pad_batch_data (
350374 stop_seqs ,
0 commit comments