|
43 | 43 | from fastdeploy.cache_manager.cache_data import CacheStatus |
44 | 44 | from fastdeploy.config import FDConfig |
45 | 45 | from fastdeploy.engine.request import ( |
| 46 | + CompletionOutput, |
46 | 47 | ControlRequest, |
47 | 48 | ControlResponse, |
48 | 49 | Request, |
| 50 | + RequestMetrics, |
49 | 51 | RequestOutput, |
50 | 52 | RequestStatus, |
51 | 53 | RequestType, |
@@ -1413,6 +1415,139 @@ def _control_update_weights(self, control_request: ControlRequest) -> Optional[d |
1413 | 1415 | raise Exception(error_msg) |
1414 | 1416 | return self._call_worker(control_request, 60) |
1415 | 1417 |
|
| 1418 | + def _control_abort_requests(self, control_req: ControlRequest): |
| 1419 | + if not envs.ENABLE_V1_KVCACHE_SCHEDULER: |
| 1420 | + raise Exception("abort_requests is only supported when ENABLE_V1_KVCACHE_SCHEDULER is enabled")
| 1421 | + args = control_req.get_args() |
| 1422 | + abort_all = args.get("abort_all", False) |
| 1423 | + req_ids = args.get("req_ids", []) |
| 1424 | + matched_input_ids = set() |
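| | + # union of request IDs known to the resource manager and the scheduler (running + waiting)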
| 1425 | + now_reqs = set(self.resource_manager.requests.keys()) | set(self.scheduler.requests.keys())
| 1426 | + |
| 1427 | + # Step 1: Determine target request list |
| 1428 | + if abort_all: |
| 1429 | + # all requests in running + waiting |
| 1430 | + target_req_ids = now_reqs |
| 1431 | + else: |
| 1432 | + # keep only the request IDs that actually exist
| 1433 | + target_req_ids = [] |
| 1434 | + for rid in req_ids: |
| 1435 | + if rid in now_reqs: |
| 1436 | + target_req_ids.append(rid) |
| 1437 | + matched_input_ids.add(rid) |
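| | + # a client-supplied ID may have been expanded into an indexed sub-request internally (hence probing "<rid>_0")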
| 1438 | + elif f"{rid}_0" in now_reqs: |
| 1439 | + target_req_ids.append(f"{rid}_0") |
| 1440 | + matched_input_ids.add(rid) |
| 1441 | + |
| 1442 | + if not target_req_ids: |
| 1443 | + return {"aborted": [], "not_found": req_ids if not abort_all else []} |
| 1444 | + |
| 1445 | + # Step 2: Collect partial results |
| 1446 | + aborted_info = [] |
| 1447 | + results = [] |
| 1448 | + for req_id in target_req_ids: |
| 1449 | + request = self.resource_manager.requests.get(req_id) |
| 1450 | + if request is None: |
| 1451 | + scheduled_req = self.scheduler.requests.get(req_id) |
| 1452 | + if scheduled_req is None: |
| 1453 | + continue |
| 1454 | + request = scheduled_req.raw |
| 1455 | + |
| 1456 | + partial_token_ids = list(request.output_token_ids) |
| 1457 | + |
| 1458 | + # Construct finished response with partial results |
| 1459 | + now = time.time() |
| 1460 | + abort_metrics = RequestMetrics( |
| 1461 | + arrival_time=request.metrics.arrival_time if request.metrics else now, |
| 1462 | + inference_start_time=request.metrics.inference_start_time if request.metrics else now, |
| 1463 | + engine_recv_latest_token_time=now, |
| 1464 | + engine_recv_first_token_time=request.metrics.engine_recv_first_token_time if request.metrics else now, |
| 1465 | + request_start_time=request.metrics.arrival_time if request.metrics else now, |
| 1466 | + ) |
| 1467 | + result = RequestOutput( |
| 1468 | + request_id=req_id, |
| 1469 | + finished=True, |
| 1470 | + outputs=CompletionOutput( |
| 1471 | + index=0, |
| 1472 | + send_idx=len(partial_token_ids), |
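| | + # final chunk: a single EOS token marking the end of the aborted stream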
| 1473 | + token_ids=[self.data_processor.eos_token_ids[0]], |
| 1474 | + ), |
| 1475 | + metrics=abort_metrics, |
| 1476 | + error_code=200, |
| 1477 | + error_msg="Aborted", |
| 1478 | + ) |
| 1479 | + results.append(result) |
| 1480 | + aborted_info.append( |
| 1481 | + { |
| 1482 | + "request_id": req_id, |
| 1483 | + "output_token_count": len(partial_token_ids), |
| 1484 | + } |
| 1485 | + ) |
| 1486 | + |
| 1487 | + # Step 3: execute the abort by adding every target request to waiting_abort_req_id_set
| 1488 | + if envs.ENABLE_V1_KVCACHE_SCHEDULER: |
| 1489 | + for req_id in target_req_ids: |
| 1490 | + self.resource_manager.add_abort_req_ids(req_id) |
| 1491 | + time.sleep(0.0001)  # brief yield before checking abort progress
| 1492 | + if self.cfg.scheduler_config.splitwise_role != "prefill": |
| 1493 | + self._wait_abort_complete(target_req_ids) |
| 1494 | + |
| 1495 | + # Hand the results to the scheduler: an engine thread polls get_results,
| 1496 | + # performs cleanup, and calls send_response to deliver them to the client.
| 1497 | + # If the client has already disconnected, send_response silently ignores them.
| 1498 | + if self.cfg.scheduler_config.splitwise_role != "prefill": |
| 1499 | + try: |
| 1501 | + self.scheduler.put_results(results) |
| 1502 | + except Exception: |
| 1503 | + pass # client may have disconnected |
| 1504 | + |
| 1505 | + not_found = [rid for rid in req_ids if rid not in matched_input_ids] if not abort_all else [] |
| 1506 | + |
| 1507 | + return {"aborted": aborted_info, "not_found": not_found} |
| 1508 | + |
| 1509 | + def _wait_abort_complete(self, target_req_ids, stall_timeout=1): |
| 1510 | + """ |
| 1511 | + Wait for all abort requests to complete. |
| 1512 | + - Keep monitoring as long as remaining is not empty, which means cleanup is not done yet |
| 1513 | + - If no progress within stall_timeout seconds, force cleanup requests stuck in to_be_aborted_req_id_set, |
| 1514 | + reset progress state if any, then continue monitoring |
| 1515 | + """ |
| 1516 | + target_set = set(target_req_ids) |
| 1517 | + prev_remaining_count = len(target_set) |
| 1518 | + last_progress_time = time.time() |
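| | + # poll the set of requests still being aborted; exit once none of our targets remain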
| 1519 | + remaining = target_set & self.resource_manager.get_reqs_in_aborting() |
| 1520 | + while remaining: |
| 1521 | + remaining = target_set & self.resource_manager.get_reqs_in_aborting() |
| 1522 | + if not remaining: |
| 1523 | + self.llm_logger.info(f"all {len(target_set)} abort reqs cleaned") |
| 1524 | + return |
| 1525 | + |
| 1526 | + current_count = len(remaining) |
| 1527 | + if current_count < prev_remaining_count: |
| 1528 | + # progress made: recycle_abort_task was called |
| 1529 | + self.llm_logger.info(f"abort progress: {prev_remaining_count} -> {current_count}") |
| 1530 | + last_progress_time = time.time() |
| 1531 | + prev_remaining_count = current_count |
| 1532 | + |
| 1533 | + if time.time() - last_progress_time > stall_timeout: |
| 1534 | + # stall timeout hit: only clean up requests stuck in to_be_aborted (the worker has not returned -9 yet)
| 1535 | + stuck = remaining & self.resource_manager.to_be_aborted_req_id_set |
| 1536 | + if stuck: |
| 1537 | + self.llm_logger.warning( |
| 1538 | + f"no abort progress for {stall_timeout}s, " |
| 1539 | + f"force cleanup {len(stuck)} stuck requests (in to_be_aborted)" |
| 1540 | + ) |
| 1541 | + for req_id in list(stuck): |
| 1542 | + self.llm_logger.warning(f"force cleanup stuck req_id:{req_id}") |
| 1543 | + self.resource_manager.recycle_abort_task(req_id) |
| 1544 | + # reset progress state |
| 1545 | + last_progress_time = time.time() |
| 1546 | + prev_remaining_count = current_count - len(stuck) |
| 1547 | + # else: the remaining requests are all in waiting_abort_req_id_set and will be recycled by the normal flow
| 1548 | + |
| 1549 | + time.sleep(0.005)  # 5 ms polling interval
| 1550 | + |
1416 | 1551 | def _parse_tags(self, control_request: ControlRequest): |
1417 | 1552 | """ |
1418 | 1553 | Parse tags from control request. |
|
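For reference, `_wait_abort_complete` implements a stall-timeout polling pattern: poll a shrinking in-flight set, treat any shrinkage as progress, and force cleanup only after a window with no progress. Below is a minimal standalone sketch of that pattern, not FastDeploy code: `in_flight` and `force_clean` are hypothetical stand-ins for `get_reqs_in_aborting()` and `recycle_abort_task()`, and the sketch simplifies one detail by force-cleaning every stalled request rather than only those in `to_be_aborted_req_id_set`.

import time

def wait_with_stall_timeout(targets, in_flight, force_clean, stall_timeout=1.0):
    """Poll until no target remains in flight; force-clean on a stall."""
    target_set = set(targets)
    prev_count = len(target_set)
    last_progress = time.time()
    while True:
        remaining = target_set & in_flight()
        if not remaining:
            return
        if len(remaining) < prev_count:
            # progress: some requests were recycled, reset the stall clock
            last_progress = time.time()
            prev_count = len(remaining)
        elif time.time() - last_progress > stall_timeout:
            # no progress within the window: force cleanup and reset the clock
            for req_id in list(remaining):
                force_clean(req_id)
            last_progress = time.time()
        time.sleep(0.005)  # 5 ms polling interval

For example, with `pending = {"req-1"}`, calling `wait_with_stall_timeout(["req-1"], lambda: pending, pending.discard)` returns after one stall window, once `discard` has emptied the set.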