@@ -69,11 +69,16 @@ def __init__(self, args):
6969 self .init_inputs ()
7070
7171 # whether use speculate decoding
72- if self .config .speculate_method is not None and self .config .speculate_method == "inference_with_reference" :
73- self .proposer = InferenceWithReferenceProposer (
74- self .config .speculate_max_draft_token_num ,
75- self .config .speculate_max_ngram_size ,
76- self .args .max_batch_size )
72+ logger .info (f'speculate_method: { self .config .speculate_method } ' )
73+ if self .config .speculate_method is not None :
74+ if self .config .speculate_method == "inference_with_reference" :
75+ self .proposer = InferenceWithReferenceProposer (
76+ self .model_cfg ["speculate_max_draft_token_num" ],
77+ self .model_cfg ["speculate_max_ngram_size" ],
78+ self .args .max_batch_size ,
79+ self .args .max_seq_len )
80+ else :
81+ raise NotImplementedError (f'Not support { self .config .speculate_method } , only support inference_with_reference now.' )
7782 else :
7883 self .proposer = None
7984
@@ -274,18 +279,17 @@ def init_inputs(self):
274279 self .share_inputs ["ori_seq_lens_encoder" ] = paddle .full (
275280 shape = [self .args .max_batch_size , 1 ], fill_value = 0 , dtype = "int32" )
276281 # speculate decoding input
282+ logger .info (f'Speculative method: { self .config .speculate_method } ' )
277283 if self .config .speculate_method is not None :
278- self .share_inputs ["input_ids_cpu" ] = paddle .full (
279- shape = [self .args .max_batch_size , self .args .max_seq_len ], fill_value = 1 , dtype = 'int64' ).cpu ()
280284 self .share_inputs ["accept_tokens" ] = paddle .full (
281- shape = [self .args .max_batch_size , self .config . speculate_max_draft_token_num + 1 ], fill_value = 0 , dtype = "int64"
285+ shape = [self .args .max_batch_size , self .model_cfg [ " speculate_max_draft_token_num" ] + 1 ], fill_value = 0 , dtype = "int64"
282286 )
283287 self .share_inputs ["accept_num" ] = paddle .full (shape = [self .args .max_batch_size ], fill_value = 0 , dtype = "int32" )
284288 self .share_inputs ["draft_tokens" ] = paddle .full (
285- shape = [self .args .max_batch_size , self .config . speculate_max_draft_token_num + 1 ], fill_value = 0 , dtype = "int64"
289+ shape = [self .args .max_batch_size , self .model_cfg [ " speculate_max_draft_token_num" ] + 1 ], fill_value = 0 , dtype = "int64"
286290 )
287291 self .share_inputs ["actual_draft_token_num" ] = paddle .full (
288- shape = [self .args .max_batch_size ], fill_value = self .config . speculate_max_draft_token_num , dtype = "int32"
292+ shape = [self .args .max_batch_size ], fill_value = self .model_cfg [ " speculate_max_draft_token_num" ] , dtype = "int32"
289293 )
290294
291295 def dy_input_preprocess (self , tasks ):
@@ -344,10 +348,8 @@ def dy_input_preprocess(self, tasks):
344348 task ["stop_seqs" ], dtype = "int64" )
345349 if self .proposer is not None :
346350 if self .config .speculate_method == "inference_with_reference" :
347- speculate_update_input_ids_cpu (self .share_inputs ['input_ids_cpu' ], task ['input_ids' ], idx , self .args .max_seq_len )
348- self .share_inputs ["draft_tokens" ][idx :idx + 1 ] = np .zeros ([self .config .speculate_max_draft_token_num + 1 ])
349- self .share_inputs ["actual_draft_token_num" ][idx :idx + 1 ] = np .array ([self .config .speculate_max_draft_token_num ])
350- self .proposer .update (idx , length )
351+ self .share_inputs ["draft_tokens" ][idx :idx + 1 ] = np .zeros ([self .model_cfg ["speculate_max_draft_token_num" ] + 1 ])
352+ self .share_inputs ["actual_draft_token_num" ][idx :idx + 1 ] = np .array ([self .model_cfg ["speculate_max_draft_token_num" ]])
351353
352354 def step_cuda (self , seq_lens_this_time ):
353355 """
@@ -381,7 +383,7 @@ def step_cuda(self, seq_lens_this_time):
381383 self .share_inputs ['input_ids' ], self .share_inputs ['pre_ids' ],
382384 self .share_inputs ['step_idx' ], self .share_inputs ['next_tokens' ],
383385 self .args .block_size , self .args .enc_dec_block_num , self .args .first_token_id ,
384- self .config . speculate_max_draft_token_num )
386+ self .model_cfg [ " speculate_max_draft_token_num" ] )
385387
386388 def initialize_engine_ready_check_flag (self ):
387389 """
@@ -512,7 +514,6 @@ def run(self):
512514 if self .proposer is not None :
513515 logger .info ("start run proposer" )
514516 logger .info (f'before draft_tokens: { self .share_inputs ["draft_tokens" ]} ' )
515- logger .info (f'before accept_tokens: { self .share_inputs ["accept_tokens" ]} ' )
516517
517518 self .proposer .run (
518519 self .share_inputs ,
@@ -521,19 +522,19 @@ def run(self):
521522 )
522523 logger .info (f'after draft_tokens: { self .share_inputs ["draft_tokens" ]} ' )
523524 logger .info ("finish run proposer" )
524- logger .info (f'input_ids: { self .share_inputs ["input_ids" ]} ' )
525- logger .info (f'input_ids_cpu: { self .share_inputs ["input_ids_cpu" ]} ' )
526- logger .info (f'seq_lens_this_time: { self .share_inputs ["seq_lens_this_time" ]} ' )
527- logger .info (f'seq_lens_encoder: { self .share_inputs ["seq_lens_encoder" ]} ' )
528- logger .info (f'seq_lens_decoder: { self .share_inputs ["seq_lens_decoder" ]} ' )
529- logger .info (f'step_idx: { self .share_inputs ["step_idx" ]} ' )
530- logger .info (f'next_tokens: { self .share_inputs ["next_tokens" ]} ' )
531- logger .info (f'before block_tables: { self .share_inputs ["block_tables" ]} ' )
525+ # logger.info(f'input_ids: {self.share_inputs["input_ids"]}')
526+ # logger.info(f'input_ids_cpu: {self.share_inputs["input_ids_cpu"]}')
527+ # logger.info(f'seq_lens_this_time: {self.share_inputs["seq_lens_this_time"]}')
528+ # logger.info(f'seq_lens_encoder: {self.share_inputs["seq_lens_encoder"]}')
529+ # logger.info(f'seq_lens_decoder: {self.share_inputs["seq_lens_decoder"]}')
530+ # logger.info(f'step_idx: {self.share_inputs["step_idx"]}')
531+ # logger.info(f'next_tokens: {self.share_inputs["next_tokens"]}')
532+ # logger.info(f'before block_tables: {self.share_inputs["block_tables"]}')
532533
533534 self .infer_engine .predictor .run ()
534535 logger .info (f'after accept_tokens: { self .share_inputs ["accept_tokens" ]} ' )
535536 logger .info (f'after accept_num: { self .share_inputs ["accept_num" ]} ' )
536- logger .info (f'after block_tables: { self .share_inputs ["block_tables" ]} ' )
537+ # logger.info(f'after block_tables: {self.share_inputs["block_tables"]}')
537538
538539 self .share_inputs ['infer_seed' ].add_ (infer_seed_increment )
539540 self .share_inputs ['infer_seed' ][:] %= self .MAX_INFER_SEED
0 commit comments