@@ -383,6 +383,7 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
383383
384384 req_len = len (req_dicts )
385385 has_prefill_task = False
386+ has_decode_task = False
386387 for i in range (req_len ):
387388 request = req_dicts [i ]
388389 idx = request .idx
@@ -392,6 +393,9 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
392393 prefill_end_index = request .prefill_end_index
393394 length = prefill_end_index - prefill_start_index
394395 input_ids = request .prompt_token_ids + request .output_token_ids
396+ logger .debug (
397+ f"Handle prefill request { request } at idx { idx } prefill_start_index { prefill_start_index } prefill_end_index { prefill_end_index } need_prefilled_token_num { len (input_ids )} "
398+ )
395399 self .share_inputs ["input_ids" ][idx : idx + 1 , :length ] = np .array (
396400 input_ids [prefill_start_index :prefill_end_index ]
397401 )
@@ -401,6 +405,8 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
401405 self .share_inputs ["block_tables" ][idx : idx + 1 , :encoder_block_num ] = np .array (
402406 request .block_tables , dtype = "int32"
403407 )
408+ if self .share_inputs ["is_block_step" ][idx ]: # has tasks to continue to decode
409+ has_decode_task = True
404410 self .share_inputs ["stop_flags" ][idx : idx + 1 ] = False
405411 self .share_inputs ["seq_lens_decoder" ][idx : idx + 1 ] = prefill_start_index
406412 self .share_inputs ["seq_lens_this_time" ][idx : idx + 1 ] = length
@@ -474,7 +480,7 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
474480 self .share_inputs ["stop_seqs" ][:stop_seqs_num , : len (request .get ("stop_token_ids" )[0 ])] = np .array (
475481 request .get ("stop_token_ids" ), dtype = "int64"
476482 )
477- if has_prefill_task :
483+ if has_prefill_task or has_decode_task :
478484 self .share_inputs ["not_need_stop" ][0 ] = True
479485
480486 def process_prefill_inputs (self , req_dicts : List [Request ]):
0 commit comments