Commit a544d82

chore: more Pythonic kv router cleanups in examples (NVIDIA#396)
1 parent cce0c02 commit a544d82

1 file changed: examples/llm/components/kv_router.py
Lines changed: 29 additions & 48 deletions
@@ -83,6 +83,12 @@ def __init__(self):
         vllm_logger.info("Initializing Custom Router")
         self.args = parse_args(self.__class__.__name__, "")
 
+        self.default_metrics = {
+            "gpu_cache_usage_perc": 0.0,
+            "num_requests_waiting": 0.0,
+            "gpu_prefix_cache_hit_rate": 0.0,
+        }
+
     @async_on_start
     async def async_init(self):
         self.runtime = dynamo_context["runtime"]
@@ -140,21 +146,13 @@ def _cost_function(
         )
 
         worker_metrics = {}
-        # pull metrics for each worker
         max_waiting = 0.0
         if metrics:
             for endpoint in metrics.endpoints:
                 worker_id = endpoint.worker_id
                 worker_metrics[worker_id] = {
-                    "gpu_cache_usage_perc": getattr(
-                        endpoint, "gpu_cache_usage_perc", 0.0
-                    ),
-                    "num_requests_waiting": getattr(
-                        endpoint, "num_requests_waiting", 0.0
-                    ),
-                    "gpu_prefix_cache_hit_rate": getattr(
-                        endpoint, "gpu_prefix_cache_hit_rate", 0.0
-                    ),
+                    key: getattr(endpoint, key, self.default_metrics[key])
+                    for key in self.default_metrics.keys()
                 }
                 max_waiting = max(
                     max_waiting, worker_metrics[worker_id]["num_requests_waiting"]
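
The dict comprehension above replaces three near-identical getattr calls with a single loop over self.default_metrics, so each key falls back to its own default and adding a new metric only means adding one entry to the defaults dict. A minimal standalone sketch of the pattern, using types.SimpleNamespace as a hypothetical stand-in for the real metrics endpoint (which comes from the metrics aggregator and may lack some attributes):

    from types import SimpleNamespace

    # Hypothetical endpoint: note it is missing gpu_prefix_cache_hit_rate.
    endpoint = SimpleNamespace(gpu_cache_usage_perc=0.42, num_requests_waiting=3.0)

    default_metrics = {
        "gpu_cache_usage_perc": 0.0,
        "num_requests_waiting": 0.0,
        "gpu_prefix_cache_hit_rate": 0.0,
    }

    # Each key falls back to its own default when the attribute is missing,
    # exactly what the three hand-written getattr calls did before.
    worker_metrics = {
        key: getattr(endpoint, key, default_metrics[key]) for key in default_metrics
    }
    print(worker_metrics)
    # {'gpu_cache_usage_perc': 0.42, 'num_requests_waiting': 3.0,
    #  'gpu_prefix_cache_hit_rate': 0.0}
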
@@ -168,14 +166,8 @@ def _cost_function(
         for worker_id in worker_ids:
             # Use default values if worker not in scores or metrics
             score = worker_scores.get(worker_id, 0.0)
-            metrics_dict = worker_metrics.get(
-                worker_id,
-                {
-                    "gpu_cache_usage_perc": 0.0,
-                    "num_requests_waiting": 0.0,
-                    "gpu_prefix_cache_hit_rate": 0.0,
-                },
-            )
+            metrics_dict = worker_metrics.get(worker_id, self.default_metrics)
+            gpu_cache_usage = metrics_dict["gpu_cache_usage_perc"]
 
             normalized_waiting = (
                 metrics_dict["num_requests_waiting"] / max_waiting
@@ -185,15 +177,13 @@ def _cost_function(
 
             # Have 1 metric that weights towards cache hit
             # 2 metrics that penalize overloaded worker and queuing
-            worker_logits[worker_id] = (
-                2 * score - metrics_dict["gpu_cache_usage_perc"] - normalized_waiting
-            )
+            worker_logits[worker_id] = 2 * score - gpu_cache_usage - normalized_waiting
             vllm_logger.info(
-                f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {metrics_dict['gpu_cache_usage_perc']:.3f} - {normalized_waiting:.3f}"
+                f"Formula for {worker_id}: {worker_logits[worker_id]:.3f} = 2.0 * {score:.3f} - {gpu_cache_usage:.3f} - {normalized_waiting:.3f}"
             )
 
         if not worker_logits or all(logit == 0 for logit in worker_logits.values()):
-            return ""
+            return "", 0.0
 
         # Select the worker with the highest logit
         max_logit = max(worker_logits.values())
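
Collapsing the logit into one line makes the weighting easy to read: the prefix-cache score counts double, while cache pressure and queueing each subtract once. A worked example with made-up numbers for two hypothetical workers (real values come from the KV indexer and the metrics aggregator):

    # Hypothetical inputs for illustration only.
    candidates = {
        "worker-a": {"score": 0.8, "gpu_cache_usage": 0.6, "normalized_waiting": 0.5},
        "worker-b": {"score": 0.3, "gpu_cache_usage": 0.2, "normalized_waiting": 0.0},
    }

    for worker_id, m in candidates.items():
        logit = 2 * m["score"] - m["gpu_cache_usage"] - m["normalized_waiting"]
        print(f"{worker_id}: {logit:.3f}")
    # worker-a: 0.500  -> wins: strong prefix reuse outweighs its load
    # worker-b: 0.400  -> idle, but little cached prefix to reuse
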
@@ -204,30 +194,26 @@ def _cost_function(
 
         # Log the metrics for the selected worker
         if best_worker_id:
-            vllm_logger.info(
-                f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}"
-            )
-            vllm_logger.info(
-                f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}"
-            )
+            metrics_dict = worker_metrics.get(best_worker_id, self.default_metrics)
 
-            metrics_dict = worker_metrics.get(best_worker_id, {})
-            vllm_logger.info(
-                f"GPU Cache Hit Rate: {metrics_dict.get('gpu_prefix_cache_hit_rate', 0.0):.3f}"
-            )
-            vllm_logger.info(
-                f"GPU Cache Usage: {metrics_dict.get('gpu_cache_usage_perc', 0.0):.3f}"
-            )
-            vllm_logger.info(
-                f"Requests Waiting: {metrics_dict.get('num_requests_waiting', 0.0) / max_waiting if max_waiting > 0 else 0.0:.3f}"
-            )
+            # Create log messages
+            log_messages = [
+                f"Selected worker: {best_worker_id}, logit: {worker_logits[best_worker_id]:.3f}",
+                f"Score: {scores.scores.get(best_worker_id, 0.0) if scores else 0.0:.3f}",
+                f"GPU Cache Hit Rate: {metrics_dict['gpu_prefix_cache_hit_rate']:.3f}",
+                f"GPU Cache Usage: {metrics_dict['gpu_cache_usage_perc']:.3f}",
+                f"Requests Waiting: {metrics_dict['num_requests_waiting']}",
+            ]
+
+            # Log to vllm_logger
+            for message in log_messages:
+                vllm_logger.info(message)
 
         return best_worker_id, worker_scores.get(best_worker_id, 0.0)
 
     @dynamo_endpoint()
     async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]:
         lora_id = 0
-        worker_id = ""
         try:
             scores = await self.indexer.find_matches_for_request(
                 request.tokens, lora_id
@@ -236,17 +222,12 @@ async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]:
             scores = {}
             vllm_logger.exception(f"Error finding matches: {e}")
 
-        token_length = len(request.tokens)
         metrics = await self.metrics_aggregator.get_metrics()
-        schedule_result = self._cost_function(scores, metrics, token_length)
-        if schedule_result == "":
-            worker_id = ""
-            prefix_hit_rate = 0.0
-        else:
-            worker_id, prefix_hit_rate = schedule_result
+        worker_id, prefix_hit_rate = self._cost_function(
+            scores, metrics, len(request.tokens)
+        )
 
         vllm_logger.info(
             f"Scheduling to worker_id: {worker_id} with estimated prefix hit rate: {prefix_hit_rate}"
         )
-
         yield f"{worker_id}_{prefix_hit_rate}"
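
The generate-side cleanup works because _cost_function now always returns a (worker_id, prefix_hit_rate) tuple, with ("", 0.0) in the no-candidate case, so the caller can unpack unconditionally instead of branching on a sentinel string. A small before/after sketch, with a hypothetical pick function standing in for _cost_function:

    def pick_old(found: bool):
        # Old contract: a bare "" sentinel on failure, a tuple on success,
        # forcing every caller to check the type before unpacking.
        return ("worker-1", 0.75) if found else ""

    def pick_new(found: bool) -> tuple[str, float]:
        # New contract: always a (worker_id, prefix_hit_rate) tuple.
        return ("worker-1", 0.75) if found else ("", 0.0)

    worker_id, prefix_hit_rate = pick_new(found=False)  # unpacks unconditionally
    print(f"{worker_id}_{prefix_hit_rate}")  # prints "_0.0"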
