Skip to content

Commit d29f7fc

Browse files
authored
feat: conditional disagg based on prefill queue size (NVIDIA#303)
1 parent d716514 commit d29f7fc

6 files changed

Lines changed: 45 additions & 4 deletions

File tree

examples/llm/components/disagg_router.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,23 @@ def __init__(
2323
runtime,
2424
served_model_name,
2525
max_local_prefill_length=1000,
26+
max_prefill_queue_size=2,
2627
):
2728
self.runtime = runtime
2829
self.served_model_name = served_model_name
2930
self.max_local_prefill_length = max_local_prefill_length
31+
self.max_prefill_queue_size = max_prefill_queue_size
3032

31-
def prefill_remote(self, prompt_length: int, prefix_hit_rate: float):
33+
def prefill_remote(
34+
self, prompt_length: int, prefix_hit_rate: float, queue_size: int
35+
):
3236
absolute_prefill_length = int(prompt_length * (1 - prefix_hit_rate))
37+
# TODO: consider size of each request in the queue when making the decision
38+
decision = (
39+
absolute_prefill_length > self.max_local_prefill_length
40+
and queue_size < self.max_prefill_queue_size
41+
)
3342
vllm_logger.info(
34-
f"Remote prefill: {absolute_prefill_length > self.max_local_prefill_length} (prefill length: {absolute_prefill_length}/{prompt_length})"
43+
f"Remote prefill: {decision} (prefill length: {absolute_prefill_length}/{prompt_length}, prefill queue size: {queue_size}/{self.max_prefill_queue_size})"
3544
)
36-
return absolute_prefill_length > self.max_local_prefill_length
45+
return decision

examples/llm/components/worker.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ async def async_init(self):
125125
runtime,
126126
self.model_name,
127127
max_local_prefill_length=self.engine_args.max_local_prefill_length,
128+
max_prefill_queue_size=self.engine_args.max_prefill_queue_size,
128129
)
129130
else:
130131
self.disaggregated_router = None
@@ -148,9 +149,17 @@ async def callback(request: RemotePrefillRequest):
148149
@dynamo_endpoint()
149150
async def generate(self, request: vLLMGenerateRequest):
150151
# TODO: consider prefix hit when deciding prefill locally or remotely
152+
151153
if self.disaggregated_router is not None:
154+
async with PrefillQueue.get_instance(
155+
nats_server=self._prefill_queue_nats_server,
156+
stream_name=self._prefill_queue_stream_name,
157+
) as prefill_queue:
158+
prefill_queue_size = await prefill_queue.get_queue_size()
152159
disagg_router_decision = self.disaggregated_router.prefill_remote(
153-
len(request.engine_prompt["prompt_token_ids"]), request.prefix_hit_rate
160+
len(request.engine_prompt["prompt_token_ids"]),
161+
request.prefix_hit_rate,
162+
prefill_queue_size,
154163
)
155164
else:
156165
# always prefill remotely if no disaggregated router is provided

examples/llm/configs/disagg.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ VllmWorker:
3030
remote-prefill: true
3131
conditional-disagg: true
3232
max-local-prefill-length: 10
33+
max-prefill-queue-size: 2
3334
ServiceArgs:
3435
workers: 1
3536
resources:

examples/llm/configs/disagg_router.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ VllmWorker:
3636
max-model-len: 16384
3737
max-num-batched-tokens: 16384
3838
conditional-disagg: true
39+
max-local-prefill-length: 10
40+
max-prefill-queue-size: 2
3941
tensor-parallel-size: 1
4042
router: kv
4143
enable-prefix-caching: true

examples/llm/utils/nats_queue.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,3 +140,16 @@ async def dequeue_task(self) -> Optional[bytes]:
140140
return None
141141
except NatsError as e:
142142
raise RuntimeError(f"Failed to dequeue task: {e}")
143+
144+
async def get_queue_size(self) -> int:
145+
"""Get the number of messages currently in the queue"""
146+
await self.ensure_connection()
147+
try:
148+
# Get consumer info to get pending messages count
149+
consumer_info = await self._js.consumer_info( # type: ignore
150+
self._stream_name, "worker-group"
151+
)
152+
# Return number of pending messages (real-time queue size)
153+
return consumer_info.num_pending
154+
except NatsError as e:
155+
raise RuntimeError(f"Failed to get queue size: {e}")

examples/llm/utils/vllm.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,18 @@ def parse_vllm_args(service_name, prefix) -> AsyncEngineArgs:
4545
default=1000,
4646
help="Maximum length of local prefill",
4747
)
48+
parser.add_argument(
49+
"--max-prefill-queue-size",
50+
type=int,
51+
default=3,
52+
help="Prefill locally (do not send remote prefill requests) if the prefill queue size is greater than or equal to this value",
53+
)
4854
parser = AsyncEngineArgs.add_cli_args(parser)
4955
args = parser.parse_args(vllm_args)
5056
engine_args = AsyncEngineArgs.from_cli_args(args)
5157
engine_args.router = args.router
5258
engine_args.remote_prefill = args.remote_prefill
5359
engine_args.conditional_disagg = args.conditional_disagg
5460
engine_args.max_local_prefill_length = args.max_local_prefill_length
61+
engine_args.max_prefill_queue_size = args.max_prefill_queue_size
5562
return engine_args

0 commit comments

Comments (0)