@@ -84,7 +84,6 @@ def allocated_slots(self, request: Request):
8484 return len (request .block_tables ) * self .config .cache_config .block_size
8585
8686 def get_new_block_nums (self , request : Request , num_new_tokens : int ):
87- self .check_and_free_block_tables ()
8887 return (
8988 request .num_computed_tokens + num_new_tokens + self .config .cache_config .block_size - 1
9089 ) // self .config .cache_config .block_size - len (request .block_tables )
@@ -119,7 +118,7 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
119118 preempted_req .status = RequestStatus .PREEMPTED
120119 preempted_req .num_computed_tokens = 0
121120 self ._free_blocks (preempted_req )
122- preempted_req .prefill_block_num = None
121+ preempted_req .cached_block_num = 0
123122 self .to_be_rescheduled_request_id_set .add (preempted_req .request_id )
124123 preempted_reqs .append (preempted_req )
125124 scheduled_reqs .append (self ._prepare_preempt_task (preempted_req ))
@@ -282,14 +281,6 @@ def schedule(self):
282281 if request .num_computed_tokens >= request .need_prefill_tokens : # to be decoding
283282 if request .num_total_tokens > request .need_prefill_tokens : # has generated tokens
284283 request .num_computed_tokens = request .num_total_tokens - 1
285- else : # prefill finished
286- if (
287- self .config .cache_config .enable_prefix_caching
288- and request .get ("prefill_block_num" , None ) is None
289- ):
290- # update prefill cache blocks for prefix caching
291- request .prefill_block_num = len (request .block_tables )
292- self .cache_manager .update_cache_blocks (request , self .config .cache_config .block_size )
293284 if (
294285 self .allocated_slots (request ) - request .num_total_tokens
295286 <= self .config .cache_config .prealloc_dec_block_slot_num_threshold
@@ -339,6 +330,10 @@ def schedule(self):
339330 scheduled_reqs .append (self ._prepare_prefill_task (request , num_new_tokens ))
340331 token_budget -= num_new_tokens
341332 request .num_computed_tokens += num_new_tokens
333+ if self .config .cache_config .enable_prefix_caching :
334+ self .cache_manager .update_cache_blocks (
335+ request , self .config .cache_config .block_size , request .num_computed_tokens
336+ )
342337 req_index += 1
343338 # schedule the WAITING requests.
344339 if not preempted_reqs :
@@ -371,6 +366,10 @@ def schedule(self):
371366 request .schedule_start_time = time .time ()
372367 token_budget -= num_new_tokens
373368 request .num_computed_tokens += num_new_tokens
369+ if self .config .cache_config .enable_prefix_caching :
370+ self .cache_manager .update_cache_blocks (
371+ request , self .config .cache_config .block_size , request .num_computed_tokens
372+ )
374373 request .status = RequestStatus .RUNNING
375374 main_process_metrics .num_requests_waiting .dec (1 )
376375 main_process_metrics .num_requests_running .inc (1 )
@@ -403,6 +402,10 @@ def schedule(self):
403402 scheduled_reqs .append (self ._prepare_prefill_task (request , num_new_tokens ))
404403 token_budget -= num_new_tokens
405404 request .num_computed_tokens += num_new_tokens
405+ if self .config .cache_config .enable_prefix_caching :
406+ self .cache_manager .update_cache_blocks (
407+ request , self .config .cache_config .block_size , request .num_computed_tokens
408+ )
406409 request .status = RequestStatus .RUNNING
407410 main_process_metrics .num_requests_waiting .dec (1 )
408411 main_process_metrics .num_requests_running .inc (1 )
@@ -447,7 +450,7 @@ def get_prefix_cached_blocks(self, request: Request):
447450
448451 matched_block_num = len (common_block_ids )
449452 no_cache_block_num = self .cache_manager .get_required_block_num (
450- request .prompt_token_ids_len - matched_token_num ,
453+ request .need_prefill_tokens - matched_token_num ,
451454 self .config .cache_config .block_size ,
452455 )
453456
@@ -463,7 +466,7 @@ def get_prefix_cached_blocks(self, request: Request):
463466 main_process_metrics .prefix_gpu_cache_token_num .inc (request .gpu_cache_token_num )
464467 main_process_metrics .prefix_cpu_cache_token_num .inc (request .cpu_cache_token_num )
465468
466- if matched_token_num == request .prompt_token_ids_len :
469+ if matched_token_num == request .need_prefill_tokens :
467470 request .num_computed_tokens = matched_token_num - self .config .cache_config .block_size
468471 request .skip_allocate = True
469472 else :
@@ -481,16 +484,8 @@ def add_request(self, request: Request) -> None:
481484
482485 def _free_blocks (self , request : Request ):
483486 if self .config .cache_config .enable_prefix_caching :
484- # TODO(chengyanfu): support cache ouput blocks for prefix caching
485- if request .get ("prefill_block_num" , None ) is None :
486- leaf_node = self .cache_manager .req_leaf_map [request .request_id ]
487- self .cache_manager .decrease_request_share_count (request .request_id )
488- self .cache_manager .free_nodes_directly (leaf_node )
489- self .cache_manager .recycle_gpu_blocks (request .block_tables [request .cache_info [0 ] :])
490-
491- else :
492- self .cache_manager .release_block_ids_async (request )
493- self .cache_manager .recycle_gpu_blocks (request .block_tables [request .prefill_block_num :])
487+ self .cache_manager .release_block_ids (request )
488+ self .cache_manager .recycle_gpu_blocks (request .block_tables [request .cached_block_num :])
494489 else :
495490 self .cache_manager .recycle_gpu_blocks (request .block_tables )
496491 request .block_tables = []