2626import paddle .distributed as dist
2727import paddle .distributed .fleet as fleet
2828from paddlenlp .trl .llm_utils import get_rotary_position_embedding
29- from paddlenlp_ops import step_paddle
29+ from paddlenlp_ops import step_paddle , speculate_step_paddle
3030from server .data .processor import DataProcessor
3131from server .engine .config import Config
32+ from server .engine .proposers import InferenceWithReferenceProposer
3233from server .utils import get_logger
3334from task_queue_manager import TaskQueueManager
3435
@@ -67,6 +68,15 @@ def __init__(self, args):
6768 self .cache_kvs = {}
6869 self .init_inputs ()
6970
71+ # Whether to use speculative decoding.
72+ if self .config .speculate_method is not None and self .config .speculate_method == "inference_with_reference" :
73+ self .proposer = InferenceWithReferenceProposer (
74+ self .config .speculate_max_draft_token_num ,
75+ self .config .speculate_max_ngram_size ,
76+ self .args .max_batch_size )
77+ else :
78+ self .proposer = None
79+
7080 self .infer_queue = TaskQueueManager (rank = self .rank , mp_num = self .nranks , port = self .config .infer_port )
7181
7282 model_rank_path = os .path .join (self .args .model_dir , f"rank_{ self .rank } " )
@@ -263,6 +273,20 @@ def init_inputs(self):
263273 shape = [self .args .max_batch_size , 1 ], fill_value = - 1 , dtype = "int64" )
264274 self .share_inputs ["ori_seq_lens_encoder" ] = paddle .full (
265275 shape = [self .args .max_batch_size , 1 ], fill_value = 0 , dtype = "int32" )
276+ # Speculative-decoding inputs.
277+ if self .config .speculate_method is not None :
278+ self .share_inputs ["input_ids_cpu" ] = paddle .full (
279+ shape = [self .args .max_batch_size , self .args .max_seq_len ], fill_value = 1 , dtype = 'int64' ).cpu ()
280+ self .share_inputs ["accept_tokens" ] = paddle .full (
281+ shape = [self .args .max_batch_size , self .config .speculate_max_draft_token_num + 1 ], fill_value = 0 , dtype = "int64"
282+ )
283+ self .share_inputs ["accept_num" ] = paddle .full (shape = [self .args .max_batch_size ], fill_value = 0 , dtype = "int32" )
284+ self .share_inputs ["draft_tokens" ] = paddle .full (
285+ shape = [self .args .max_batch_size , self .config .speculate_max_draft_token_num + 1 ], fill_value = 0 , dtype = "int64"
286+ )
287+ self .share_inputs ["actual_draft_token_num" ] = paddle .full (
288+ shape = [self .args .max_batch_size ], fill_value = self .config .speculate_max_draft_token_num , dtype = "int32"
289+ )
266290
267291 def dy_input_preprocess (self , tasks ):
268292 """
@@ -318,23 +342,46 @@ def dy_input_preprocess(self, tasks):
318342 task ["stop_seqs_len" ], dtype = "int32" )
319343 self .share_inputs ['stop_seqs' ][:stop_seqs_num , :len (task ['stop_seqs' ][0 ])] = np .array (
320344 task ["stop_seqs" ], dtype = "int64" )
345+ if self .proposer is not None :
346+ if self .config .speculate_method == "inference_with_reference" :
347+ speculate_update_input_ids_cpu (self .share_inputs ['input_ids_cpu' ], task ['input_ids' ], idx , self .args .max_seq_len )
348+ self .share_inputs ["draft_tokens" ][idx :idx + 1 ] = np .zeros ([self .config .speculate_max_draft_token_num + 1 ])
349+ self .share_inputs ["actual_draft_token_num" ][idx :idx + 1 ] = np .array ([self .config .speculate_max_draft_token_num ])
350+ self .proposer .update (idx , length )
351+
def step_cuda(self, seq_lens_this_time):
    """Advance the block-table scheduler one decoding step on device.

    Dispatches to the ``step_paddle`` custom op in the normal path, or to
    ``speculate_step_paddle`` when speculative decoding is enabled
    (``self.config.speculate_method`` is not ``None``).  The speculative
    variant takes the same leading arguments plus one extra trailing
    argument: the maximum draft-token count.

    Args:
        seq_lens_this_time: per-request token counts for this step,
            forwarded unchanged to the custom op.
    """
    # Both ops share the same 24 leading arguments; build the tuple once
    # so the two call sites cannot silently drift apart.
    common_args = (
        self.share_inputs['stop_flags'],
        seq_lens_this_time,
        self.share_inputs['step_seq_lens_encoder'],
        self.share_inputs['seq_lens_encoder'],
        self.share_inputs['seq_lens_decoder'],
        self.share_inputs['block_tables'],
        self.share_inputs['encoder_block_lens'],
        self.share_inputs['is_block_step'],
        self.share_inputs['step_block_list'],
        self.share_inputs['step_lens'],
        self.share_inputs['recover_block_list'],
        self.share_inputs['recover_lens'],
        self.share_inputs['need_block_list'],
        self.share_inputs['need_block_len'],
        self.share_inputs['used_list_len'],
        self.share_inputs['free_list'],
        self.share_inputs['free_list_len'],
        self.share_inputs['input_ids'],
        self.share_inputs['pre_ids'],
        self.share_inputs['step_idx'],
        self.share_inputs['next_tokens'],
        self.args.block_size,
        self.args.enc_dec_block_num,
        self.args.first_token_id,
    )
    if self.config.speculate_method is None:
        step_paddle(*common_args)
    else:
        # Speculative decoding additionally needs the draft-token budget.
        speculate_step_paddle(*common_args,
                              self.config.speculate_max_draft_token_num)
338385
339386 def initialize_engine_ready_check_flag (self ):
340387 """
@@ -434,6 +481,9 @@ def run(self):
434481 self .share_inputs ["seq_lens_this_time" ][:real_bsz ] = seq_lens_this_time
435482
436483 tasks , read_finish = self .infer_queue .get ()
484+ logger .info (f'tasks: { tasks } ' )
485+ logger .info (f'read_finish: { read_finish } ' )
486+
437487 if read_finish :
438488 flag_broadcast_array [0 ] = 0
439489
@@ -442,7 +492,7 @@ def run(self):
442492 real_bsz = int (bsz )
443493 req_dicts .extend (req_dict )
444494 logger .info (
445- f'rank: { self .rank } , real_bsz: { real_bsz } , query_num: { len (req_dicts )} '
495+ f'req_dict: { req_dict } rank: { self .rank } , real_bsz: { real_bsz } , query_num: { len (req_dicts )} '
446496 )
447497
448498 self .dy_input_preprocess (req_dicts )
@@ -459,10 +509,36 @@ def run(self):
459509 time .sleep (0.001 )
460510 continue
461511
512+ if self .proposer is not None :
513+ logger .info ("start run proposer" )
514+ logger .info (f'before draft_tokens: { self .share_inputs ["draft_tokens" ]} ' )
515+ logger .info (f'before accept_tokens: { self .share_inputs ["accept_tokens" ]} ' )
516+
517+ self .proposer .run (
518+ self .share_inputs ,
519+ real_batch_size = self .args .max_batch_size ,
520+ seq_lens_this_time = self .share_inputs ["seq_lens_this_time" ],
521+ )
522+ logger .info (f'after draft_tokens: { self .share_inputs ["draft_tokens" ]} ' )
523+ logger .info ("finish run proposer" )
524+ logger .info (f'input_ids: { self .share_inputs ["input_ids" ]} ' )
525+ logger .info (f'input_ids_cpu: { self .share_inputs ["input_ids_cpu" ]} ' )
526+ logger .info (f'seq_lens_this_time: { self .share_inputs ["seq_lens_this_time" ]} ' )
527+ logger .info (f'seq_lens_encoder: { self .share_inputs ["seq_lens_encoder" ]} ' )
528+ logger .info (f'seq_lens_decoder: { self .share_inputs ["seq_lens_decoder" ]} ' )
529+ logger .info (f'step_idx: { self .share_inputs ["step_idx" ]} ' )
530+ logger .info (f'next_tokens: { self .share_inputs ["next_tokens" ]} ' )
531+ logger .info (f'before block_tables: { self .share_inputs ["block_tables" ]} ' )
532+
462533 self .infer_engine .predictor .run ()
534+ logger .info (f'after accept_tokens: { self .share_inputs ["accept_tokens" ]} ' )
535+ logger .info (f'after accept_num: { self .share_inputs ["accept_num" ]} ' )
536+ logger .info (f'after block_tables: { self .share_inputs ["block_tables" ]} ' )
537+
463538 self .share_inputs ['infer_seed' ].add_ (infer_seed_increment )
464539 self .share_inputs ['infer_seed' ][:] %= self .MAX_INFER_SEED
465540 if self .free_list_len > 0 :
541+ logger .info (f'free_list_len > 0' )
466542 self .step_cuda (seq_lens_this_time )
467543
468544
0 commit comments