Commit 1016058
[XPU] Fixed the issue of performance degradation caused by enabling ENABLE_V1_KVCACHE_SCHEDULER (#3393)
* fix v1 schedule oom bug
1 parent 2891870 commit 1016058

4 files changed: 22 additions & 9 deletions


fastdeploy/engine/args_utils.py

Lines changed: 8 additions & 3 deletions
@@ -15,10 +15,12 @@
 """
 
 import json
+import os
 from dataclasses import asdict, dataclass
 from dataclasses import fields as dataclass_fields
 from typing import Any, Dict, List, Optional
-import os
+
+import paddle
 
 from fastdeploy.config import (
     CacheConfig,
@@ -866,10 +868,13 @@ def create_engine_config(self) -> Config:
             if self.enable_chunked_prefill:
                 self.max_num_batched_tokens = 2048
             else:
-                if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+                if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                     self.max_num_batched_tokens = self.max_model_len
                 else:
-                    self.max_num_batched_tokens = 8192
+                    if paddle.is_compiled_with_xpu():
+                        self.max_num_batched_tokens = self.max_model_len
+                    else:
+                        self.max_num_batched_tokens = 8192
 
         all_dict = asdict(self)
         all_dict["model_cfg"] = model_cfg
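On XPU builds the hunk above keeps the old default of max_model_len for max_num_batched_tokens when the V1 KV-cache scheduler is enabled, instead of the 8192-token cap used on other devices. Below is a minimal standalone sketch of the resulting selection logic; the helper name and signature are made up for illustration, and only the branch structure follows the patch:

import os

import paddle


def default_max_num_batched_tokens(max_model_len: int, enable_chunked_prefill: bool) -> int:
    # Hypothetical helper mirroring the branch structure of the patched code.
    if enable_chunked_prefill:
        # Chunked prefill processes prompts in fixed-size chunks.
        return 2048
    if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
        # Legacy scheduler: a full prompt must fit into a single batch.
        return max_model_len
    if paddle.is_compiled_with_xpu():
        # XPU with the V1 scheduler keeps the max_model_len default to avoid
        # the performance regression this commit addresses.
        return max_model_len
    # Other devices with the V1 scheduler keep the 8192-token default.
    return 8192

The same branch appears in Config.postprocess in fastdeploy/engine/config.py below.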

fastdeploy/engine/config.py

Lines changed: 6 additions & 3 deletions
@@ -236,10 +236,13 @@ def postprocess(self):
             if self.cache_config.enable_chunked_prefill:
                 self.max_num_batched_tokens = 2048
             else:
-                if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+                if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                     self.max_num_batched_tokens = self.max_model_len
                 else:
-                    self.max_num_batched_tokens = 8192
+                    if paddle.is_compiled_with_xpu():
+                        self.max_num_batched_tokens = self.max_model_len
+                    else:
+                        self.max_num_batched_tokens = 8192
 
         if self.long_prefill_token_threshold == 0:
             self.long_prefill_token_threshold = int(self.max_model_len * 0.04)
@@ -287,7 +290,7 @@ def check(self):
             )
 
         if not self.cache_config.enable_chunked_prefill:
-            if not int(os.getenv('ENABLE_V1_KVCACHE_SCHEDULER', '0')):
+            if not int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
                 assert self.max_num_batched_tokens >= self.max_model_len, (
                     f"max_num_batched_tokens: {self.max_num_batched_tokens} "
                     f"should be larger than or equal to max_model_len: {self.max_model_len}"
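The check() hunk only changes the quoting of the environment lookup, but the guard it sits in is worth spelling out: the max_num_batched_tokens >= max_model_len assertion is enforced only when both chunked prefill and the V1 KV-cache scheduler are disabled, because either feature lets a prompt be split across scheduling steps. A standalone sketch of that validation; the function form is illustrative, while the condition and message follow the source:

import os


def check_batched_tokens(max_num_batched_tokens: int, max_model_len: int, enable_chunked_prefill: bool) -> None:
    # Illustrative stand-in for the guard inside Config.check().
    if enable_chunked_prefill:
        return  # chunked prefill never needs the whole prompt in one batch
    if int(os.getenv("ENABLE_V1_KVCACHE_SCHEDULER", "0")):
        return  # the V1 scheduler can split a prompt across scheduling steps
    assert max_num_batched_tokens >= max_model_len, (
        f"max_num_batched_tokens: {max_num_batched_tokens} "
        f"should be larger than or equal to max_model_len: {max_model_len}"
    )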

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 1 addition & 1 deletion
@@ -289,7 +289,7 @@ def schedule(self):
         while self.waiting and token_budget > 0:
             if len(self.running) == self.max_num_seqs:
                 break
-            if self.config.enable_mm and self.exist_prefill(scheduled_reqs):
+            if (self.config.enable_mm or paddle.is_compiled_with_xpu()) and self.exist_prefill(scheduled_reqs):
                 break
             request = self.waiting[0]
             if request.status == RequestStatus.WAITING:
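This one-line scheduler change makes XPU follow the existing multimodal path: once a prefill request has been scheduled in the current step, no further requests are pulled from the waiting queue. A simplified sketch of that admission gate; config, scheduled_reqs, and exist_prefill stand in for the real ResourceManagerV1 attributes and method, and only the condition mirrors the diff:

import paddle


def can_admit_more_waiting(config, scheduled_reqs, exist_prefill) -> bool:
    # Simplified stand-in for the gate inside ResourceManagerV1.schedule().
    if (config.enable_mm or paddle.is_compiled_with_xpu()) and exist_prefill(scheduled_reqs):
        # A prefill is already scheduled this step; multimodal models, and now
        # XPU builds, stop mixing more waiting requests into the same step.
        return False
    return True

Note that this reuses the behavior already in place for multimodal models rather than introducing a new scheduling policy.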

fastdeploy/worker/xpu_model_runner.py

Lines changed: 7 additions & 2 deletions
@@ -383,15 +383,18 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
 
         req_len = len(req_dicts)
         has_prefill_task = False
+        has_decode_task = False
         for i in range(req_len):
             request = req_dicts[i]
             idx = request.idx
             if request.task_type.value == RequestType.PREFILL.value:  # prefill task
-                logger.debug(f"Handle prefill request {request} at idx {idx}")
                 prefill_start_index = request.prefill_start_index
                 prefill_end_index = request.prefill_end_index
                 length = prefill_end_index - prefill_start_index
                 input_ids = request.prompt_token_ids + request.output_token_ids
+                logger.debug(
+                    f"Handle prefill request {request} at idx {idx} prefill_start_index {prefill_start_index} prefill_end_index {prefill_end_index} need_prefilled_token_num {len(input_ids)}"
+                )
                 self.share_inputs["input_ids"][idx : idx + 1, :length] = np.array(
                     input_ids[prefill_start_index:prefill_end_index]
                 )
@@ -420,6 +423,8 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
                 self.share_inputs["block_tables"][idx : idx + 1, :encoder_block_num] = np.array(
                     request.block_tables, dtype="int32"
                 )
+                if self.share_inputs["is_block_step"][idx]:  # has tasks to continue to decode
+                    has_decode_task = True
                 continue
             else:  # preempted task
                 logger.debug(f"Handle preempted request {request} at idx {idx}")
@@ -460,7 +465,7 @@ def insert_tasks_v1(self, req_dicts: List[Request]):
             self.share_inputs["stop_seqs"][:stop_seqs_num, : len(request.get("stop_token_ids")[0])] = np.array(
                 request.get("stop_token_ids"), dtype="int64"
             )
-        if has_prefill_task:
+        if has_prefill_task or has_decode_task:
            self.share_inputs["not_need_stop"][0] = True
 
     def process_prefill_inputs(self, req_dicts: List[Request]):
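The new has_decode_task flag keeps not_need_stop set when a batch contains only resumed decode tasks (slots where is_block_step is set) and no fresh prefill, so the worker keeps stepping instead of stalling those requests. A trimmed-down, self-contained sketch of that bookkeeping; the Request dataclass, the RequestType stand-in, and the assumed DECODE branch are illustrative, and only the flag logic follows the diff:

from dataclasses import dataclass
from enum import Enum
from typing import Dict, List


class RequestType(Enum):
    # Stand-in for FastDeploy's RequestType; only the members used in this sketch.
    PREFILL = 0
    DECODE = 1


@dataclass
class Request:
    idx: int
    task_type: RequestType


def refresh_not_need_stop(share_inputs: Dict[str, list], req_dicts: List[Request]) -> None:
    # Trimmed-down sketch of the flag handling added to insert_tasks_v1.
    has_prefill_task = False
    has_decode_task = False
    for request in req_dicts:
        if request.task_type.value == RequestType.PREFILL.value:
            has_prefill_task = True
        elif request.task_type.value == RequestType.DECODE.value:
            if share_inputs["is_block_step"][request.idx]:
                # This slot still has a decode task to resume (block step), so
                # the worker must keep stepping even without new prefills.
                has_decode_task = True
    if has_prefill_task or has_decode_task:
        share_inputs["not_need_stop"][0] = True


# Example: a batch holding only a resumed decode task still keeps the worker running.
share_inputs = {"not_need_stop": [False], "is_block_step": [False, True]}
refresh_not_need_stop(share_inputs, [Request(idx=1, task_type=RequestType.DECODE)])
assert share_inputs["not_need_stop"][0] is True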
