2626import paddle .distributed as dist
2727import paddle .distributed .fleet as fleet
2828from paddlenlp .trl .llm_utils import get_rotary_position_embedding
29- from paddlenlp_ops import step_paddle , speculate_step_paddle
29+ from paddlenlp_ops import step_paddle
3030from server .data .processor import DataProcessor
3131from server .engine .config import Config
3232from paddlenlp .experimental .transformers import InferenceWithReferenceProposer
@@ -47,7 +47,8 @@ def __init__(self, args):
4747
4848 self .config = Config ()
4949 self .model_cfg = self .config .get_model_config ()
50- self .is_speculate_decoding = self .model_cfg .get ("speculate_method" ) is not None
50+ self .speculate_config = self .config .get_speculate_config ()
51+ self .is_speculate_decoding = self .speculate_config .speculate_method is not None
5152 self .format_print_configuration ()
5253
5354 self .args .num_layers = self .get_value (self .model_cfg , ["num_hidden_layers" , "num_layers" ])
@@ -70,15 +71,13 @@ def __init__(self, args):
7071 self .init_inputs ()
7172
7273 if self .is_speculate_decoding :
73- logger .info (f'Using speculating decoding, method: { self .model_cfg [ " speculate_method" ] } .' )
74- if self .model_cfg [ " speculate_method" ] == "inference_with_reference" :
74+ logger .info (f'Using speculating decoding, method: { self .speculate_config . speculate_method } .' )
75+ if self .speculate_config . speculate_method == "inference_with_reference" :
7576 self .proposer = InferenceWithReferenceProposer (
76- self .model_cfg [ " speculate_max_draft_token_num" ] ,
77- self .model_cfg [ " speculate_max_ngram_size" ] ,
77+ self .speculate_config . speculate_max_draft_token_num ,
78+ self .speculate_config . speculate_max_ngram_size ,
7879 self .args .max_batch_size ,
7980 self .args .max_seq_len )
80- else :
81- raise NotImplementedError (f'Not support { self .model_cfg ["speculate_method" ]} , only support inference_with_reference now.' )
8281 else :
8382 self .proposer = None
8483
@@ -281,14 +280,14 @@ def init_inputs(self):
281280 # speculate decoding input
282281 if self .is_speculate_decoding :
283282 self .share_inputs ["accept_tokens" ] = paddle .full (
284- shape = [self .args .max_batch_size , self .model_cfg [ " speculate_max_draft_token_num" ] + 1 ], fill_value = 0 , dtype = "int64"
283+ shape = [self .args .max_batch_size , self .speculate_config . speculate_max_draft_token_num + 1 ], fill_value = 0 , dtype = "int64"
285284 )
286285 self .share_inputs ["accept_num" ] = paddle .full (shape = [self .args .max_batch_size ], fill_value = 0 , dtype = "int32" )
287286 self .share_inputs ["draft_tokens" ] = paddle .full (
288- shape = [self .args .max_batch_size , self .model_cfg [ " speculate_max_draft_token_num" ] + 1 ], fill_value = 0 , dtype = "int64"
287+ shape = [self .args .max_batch_size , self .speculate_config . speculate_max_draft_token_num + 1 ], fill_value = 0 , dtype = "int64"
289288 )
290289 self .share_inputs ["actual_draft_token_num" ] = paddle .full (
291- shape = [self .args .max_batch_size ], fill_value = self .model_cfg [ " speculate_max_draft_token_num" ] , dtype = "int32"
290+ shape = [self .args .max_batch_size ], fill_value = self .speculate_config . speculate_max_draft_token_num , dtype = "int32"
292291 )
293292
294293 def dy_input_preprocess (self , tasks ):
@@ -347,42 +346,33 @@ def dy_input_preprocess(self, tasks):
347346 task ["stop_seqs" ], dtype = "int64" )
348347
349348 if self .is_speculate_decoding :
350- self .share_inputs ["draft_tokens" ][idx :idx + 1 ] = np .zeros ([self .model_cfg [ " speculate_max_draft_token_num" ] + 1 ])
351- self .share_inputs ["actual_draft_token_num" ][idx :idx + 1 ] = np .array ([self .model_cfg [ " speculate_max_draft_token_num" ] ])
349+ self .share_inputs ["draft_tokens" ][idx :idx + 1 ] = np .zeros ([self .speculate_config . speculate_max_draft_token_num + 1 ])
350+ self .share_inputs ["actual_draft_token_num" ][idx :idx + 1 ] = np .array ([self .speculate_config . speculate_max_draft_token_num ])
352351
def step_cuda(self, seq_lens_this_time):
    """
    Run one scheduling step by invoking the ``step_paddle`` op.

    The op receives the shared input tensors held in ``self.share_inputs``
    together with the block-size settings from ``self.args``. A single call
    serves both normal and speculative decoding; the two modes differ only
    in the last argument.

    Args:
        seq_lens_this_time: forwarded unchanged as the second positional
            argument of ``step_paddle``.
    """
    # The last argument of step_paddle selects the decoding mode:
    # 0 for normal decoding, otherwise max draft tokens + 1.
    # NOTE(review): the "+ 1" presumably accounts for the one token produced
    # by the target model on top of the draft tokens - confirm against the op.
    if self.is_speculate_decoding:
        speculate_step_token_num = self.speculate_config.speculate_max_draft_token_num + 1
    else:
        speculate_step_token_num = 0

    inputs = self.share_inputs
    step_paddle(
        inputs['stop_flags'], seq_lens_this_time,
        inputs['step_seq_lens_encoder'],
        inputs['seq_lens_encoder'],
        inputs['seq_lens_decoder'], inputs["block_tables"],
        inputs['encoder_block_lens'],
        inputs["is_block_step"], inputs['step_block_list'],
        inputs['step_lens'], inputs['recover_block_list'],
        inputs['recover_lens'], inputs['need_block_list'],
        inputs['need_block_len'], inputs['used_list_len'],
        inputs['free_list'], inputs['free_list_len'],
        inputs['input_ids'], inputs['pre_ids'],
        inputs['step_idx'], inputs['next_tokens'],
        # The comma after first_token_id was missing, which made the whole
        # call a SyntaxError.
        self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id,
        speculate_step_token_num,
    )
386376
387377 def initialize_engine_ready_check_flag (self ):
388378 """
0 commit comments