@@ -83,6 +83,12 @@ def __init__(self):
8383 vllm_logger .info ("Initializing Custom Router" )
8484 self .args = parse_args (self .__class__ .__name__ , "" )
8585
86+ self .default_metrics = {
87+ "gpu_cache_usage_perc" : 0.0 ,
88+ "num_requests_waiting" : 0.0 ,
89+ "gpu_prefix_cache_hit_rate" : 0.0 ,
90+ }
91+
8692 @async_on_start
8793 async def async_init (self ):
8894 self .runtime = dynamo_context ["runtime" ]
@@ -140,21 +146,13 @@ def _cost_function(
140146 )
141147
142148 worker_metrics = {}
143- # pull metrics for each worker
144149 max_waiting = 0.0
145150 if metrics :
146151 for endpoint in metrics .endpoints :
147152 worker_id = endpoint .worker_id
148153 worker_metrics [worker_id ] = {
149- "gpu_cache_usage_perc" : getattr (
150- endpoint , "gpu_cache_usage_perc" , 0.0
151- ),
152- "num_requests_waiting" : getattr (
153- endpoint , "num_requests_waiting" , 0.0
154- ),
155- "gpu_prefix_cache_hit_rate" : getattr (
156- endpoint , "gpu_prefix_cache_hit_rate" , 0.0
157- ),
154+ key : getattr (endpoint , key , self .default_metrics [key ])
155+ for key in self .default_metrics .keys ()
158156 }
159157 max_waiting = max (
160158 max_waiting , worker_metrics [worker_id ]["num_requests_waiting" ]
@@ -168,14 +166,8 @@ def _cost_function(
168166 for worker_id in worker_ids :
169167 # Use default values if worker not in scores or metrics
170168 score = worker_scores .get (worker_id , 0.0 )
171- metrics_dict = worker_metrics .get (
172- worker_id ,
173- {
174- "gpu_cache_usage_perc" : 0.0 ,
175- "num_requests_waiting" : 0.0 ,
176- "gpu_prefix_cache_hit_rate" : 0.0 ,
177- },
178- )
169+ metrics_dict = worker_metrics .get (worker_id , self .default_metrics )
170+ gpu_cache_usage = metrics_dict ["gpu_cache_usage_perc" ]
179171
180172 normalized_waiting = (
181173 metrics_dict ["num_requests_waiting" ] / max_waiting
@@ -185,15 +177,13 @@ def _cost_function(
185177
186178 # One term rewards cache hits (2 * score);
187179 # two terms penalize an overloaded worker (GPU cache usage) and queuing (normalized waiting requests).
188- worker_logits [worker_id ] = (
189- 2 * score - metrics_dict ["gpu_cache_usage_perc" ] - normalized_waiting
190- )
180+ worker_logits [worker_id ] = 2 * score - gpu_cache_usage - normalized_waiting
191181 vllm_logger .info (
192- f"Formula for { worker_id } : { worker_logits [worker_id ]:.3f} = 2.0 * { score :.3f} - { metrics_dict [ 'gpu_cache_usage_perc' ] :.3f} - { normalized_waiting :.3f} "
182+ f"Formula for { worker_id } : { worker_logits [worker_id ]:.3f} = 2.0 * { score :.3f} - { gpu_cache_usage :.3f} - { normalized_waiting :.3f} "
193183 )
194184
195185 if not worker_logits or all (logit == 0 for logit in worker_logits .values ()):
196- return ""
186+ return "" , 0.0
197187
198188 # Select the worker with the highest logit
199189 max_logit = max (worker_logits .values ())
@@ -204,30 +194,26 @@ def _cost_function(
204194
205195 # Log the metrics for the selected worker
206196 if best_worker_id :
207- vllm_logger .info (
208- f"Selected worker: { best_worker_id } , logit: { worker_logits [best_worker_id ]:.3f} "
209- )
210- vllm_logger .info (
211- f"Score: { scores .scores .get (best_worker_id , 0.0 ) if scores else 0.0 :.3f} "
212- )
197+ metrics_dict = worker_metrics .get (best_worker_id , self .default_metrics )
213198
214- metrics_dict = worker_metrics .get (best_worker_id , {})
215- vllm_logger .info (
216- f"GPU Cache Hit Rate: { metrics_dict .get ('gpu_prefix_cache_hit_rate' , 0.0 ):.3f} "
217- )
218- vllm_logger .info (
219- f"GPU Cache Usage: { metrics_dict .get ('gpu_cache_usage_perc' , 0.0 ):.3f} "
220- )
221- vllm_logger .info (
222- f"Requests Waiting: { metrics_dict .get ('num_requests_waiting' , 0.0 ) / max_waiting if max_waiting > 0 else 0.0 :.3f} "
223- )
199+ # Create log messages
200+ log_messages = [
201+ f"Selected worker: { best_worker_id } , logit: { worker_logits [best_worker_id ]:.3f} " ,
202+ f"Score: { scores .scores .get (best_worker_id , 0.0 ) if scores else 0.0 :.3f} " ,
203+ f"GPU Cache Hit Rate: { metrics_dict ['gpu_prefix_cache_hit_rate' ]:.3f} " ,
204+ f"GPU Cache Usage: { metrics_dict ['gpu_cache_usage_perc' ]:.3f} " ,
205+ f"Requests Waiting: { metrics_dict ['num_requests_waiting' ]} " ,
206+ ]
207+
208+ # Log to vllm_logger
209+ for message in log_messages :
210+ vllm_logger .info (message )
224211
225212 return best_worker_id , worker_scores .get (best_worker_id , 0.0 )
226213
227214 @dynamo_endpoint ()
228215 async def generate (self , request : Tokens ) -> AsyncIterator [WorkerId ]:
229216 lora_id = 0
230- worker_id = ""
231217 try :
232218 scores = await self .indexer .find_matches_for_request (
233219 request .tokens , lora_id
@@ -236,17 +222,12 @@ async def generate(self, request: Tokens) -> AsyncIterator[WorkerId]:
236222 scores = {}
237223 vllm_logger .exception (f"Error finding matches: { e } " )
238224
239- token_length = len (request .tokens )
240225 metrics = await self .metrics_aggregator .get_metrics ()
241- schedule_result = self ._cost_function (scores , metrics , token_length )
242- if schedule_result == "" :
243- worker_id = ""
244- prefix_hit_rate = 0.0
245- else :
246- worker_id , prefix_hit_rate = schedule_result
226+ worker_id , prefix_hit_rate = self ._cost_function (
227+ scores , metrics , len (request .tokens )
228+ )
247229
248230 vllm_logger .info (
249231 f"Scheduling to worker_id: { worker_id } with estimated prefix hit rate: { prefix_hit_rate } "
250232 )
251-
252233 yield f"{ worker_id } _{ prefix_hit_rate } "
0 commit comments