@@ -47,6 +47,7 @@ def __init__(self, args):
4747
4848 self .config = Config ()
4949 self .model_cfg = self .config .get_model_config ()
50+ self .is_speculate_decoding = self .model_cfg .get ("speculate_method" ) is not None
5051 self .format_print_configuration ()
5152
5253 self .args .num_layers = self .get_value (self .model_cfg , ["num_hidden_layers" , "num_layers" ])
@@ -68,16 +69,16 @@ def __init__(self, args):
6869 self .cache_kvs = {}
6970 self .init_inputs ()
7071
71- # whether use speculate decoding
72- if self . config . speculate_method is not None :
73- if self .config . speculate_method == "inference_with_reference" :
72+ if self . is_speculate_decoding :
73+ logger . info ( f'Using speculating decoding, method: { self . model_cfg [ "speculate_method" ] } .' )
74+ if self .model_cfg [ "speculate_method" ] == "inference_with_reference" :
7475 self .proposer = InferenceWithReferenceProposer (
7576 self .model_cfg ["speculate_max_draft_token_num" ],
7677 self .model_cfg ["speculate_max_ngram_size" ],
7778 self .args .max_batch_size ,
7879 self .args .max_seq_len )
7980 else :
80- raise NotImplementedError (f'Not support { self .config . speculate_method } , only support inference_with_reference now.' )
81+ raise NotImplementedError (f'Not support { self .model_cfg [ "speculate_method" ] } , only support inference_with_reference now.' )
8182 else :
8283 self .proposer = None
8384
@@ -278,7 +279,7 @@ def init_inputs(self):
278279 self .share_inputs ["ori_seq_lens_encoder" ] = paddle .full (
279280 shape = [self .args .max_batch_size , 1 ], fill_value = 0 , dtype = "int32" )
280281 # speculate decoding input
281- if self .config . speculate_method is not None :
282+ if self .is_speculate_decoding :
282283 self .share_inputs ["accept_tokens" ] = paddle .full (
283284 shape = [self .args .max_batch_size , self .model_cfg ["speculate_max_draft_token_num" ] + 1 ], fill_value = 0 , dtype = "int64"
284285 )
@@ -344,16 +345,16 @@ def dy_input_preprocess(self, tasks):
344345 task ["stop_seqs_len" ], dtype = "int32" )
345346 self .share_inputs ['stop_seqs' ][:stop_seqs_num , :len (task ['stop_seqs' ][0 ])] = np .array (
346347 task ["stop_seqs" ], dtype = "int64" )
347- if self . proposer is not None :
348- if self .config . speculate_method == "inference_with_reference" :
349- self .share_inputs ["draft_tokens" ][idx :idx + 1 ] = np .zeros ([self .model_cfg ["speculate_max_draft_token_num" ] + 1 ])
350- self .share_inputs ["actual_draft_token_num" ][idx :idx + 1 ] = np .array ([self .model_cfg ["speculate_max_draft_token_num" ]])
348+
349+ if self .is_speculate_decoding :
350+ self .share_inputs ["draft_tokens" ][idx :idx + 1 ] = np .zeros ([self .model_cfg ["speculate_max_draft_token_num" ] + 1 ])
351+ self .share_inputs ["actual_draft_token_num" ][idx :idx + 1 ] = np .array ([self .model_cfg ["speculate_max_draft_token_num" ]])
351352
352353 def step_cuda (self , seq_lens_this_time ):
353354 """
354355 step cuda
355356 """
356- if self .config . speculate_method is None :
357+ if not self .is_speculate_decoding :
357358 step_paddle (self .share_inputs ['stop_flags' ], seq_lens_this_time ,
358359 self .share_inputs ['step_seq_lens_encoder' ],
359360 self .share_inputs ['seq_lens_encoder' ],
0 commit comments