@@ -383,15 +383,18 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
383383
384384 req_len = len (req_dicts )
385385 has_prefill_task = False
386+ has_decode_task = False
386387 for i in range (req_len ):
387388 request = req_dicts [i ]
388389 idx = request .idx
389390 if request .task_type .value == RequestType .PREFILL .value : # prefill task
390- logger .debug (f"Handle prefill request { request } at idx { idx } " )
391391 prefill_start_index = request .prefill_start_index
392392 prefill_end_index = request .prefill_end_index
393393 length = prefill_end_index - prefill_start_index
394394 input_ids = request .prompt_token_ids + request .output_token_ids
395+ logger .debug (
396+ f"Handle prefill request { request } at idx { idx } prefill_start_index { prefill_start_index } prefill_end_index { prefill_end_index } need_prefilled_token_num { len (input_ids )} "
397+ )
395398 self .share_inputs ["input_ids" ][idx : idx + 1 , :length ] = np .array (
396399 input_ids [prefill_start_index :prefill_end_index ]
397400 )
@@ -420,6 +423,8 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
420423 self .share_inputs ["block_tables" ][idx : idx + 1 , :encoder_block_num ] = np .array (
421424 request .block_tables , dtype = "int32"
422425 )
426+ if self .share_inputs ["is_block_step" ][idx ]: # has tasks to continue to decode
427+ has_decode_task = True
423428 continue
424429 else : # preempted task
425430 logger .debug (f"Handle preempted request { request } at idx { idx } " )
@@ -460,7 +465,7 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
460465 self .share_inputs ["stop_seqs" ][:stop_seqs_num , : len (request .get ("stop_token_ids" )[0 ])] = np .array (
461466 request .get ("stop_token_ids" ), dtype = "int64"
462467 )
463- if has_prefill_task :
468+ if has_prefill_task or has_decode_task :
464469 self .share_inputs ["not_need_stop" ][0 ] = True
465470
466471 def process_prefill_inputs (self , req_dicts : List [Request ]):
0 commit comments