Skip to content

Commit 7ccbcc5

Browse files
[feat] support prefix cache clearing when /clear_load_weight is called (#4091)
* [feat] support clearing prefix cache (cherry-picked from release/2.1)
* [fix] fix IPC suffix, use port instead
* [fix] fix prefix caching not enabled
* [fix] fix code style
* [fix] wait for rank0 to update weight status
1 parent fbb4e0f commit 7ccbcc5

17 files changed

Lines changed: 624 additions & 181 deletions
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "helper.h"
16+
#include "cuda_multiprocess.h"
17+
18+
#if !defined(_WIN32)
19+
#include <errno.h>
20+
#include <string.h>
21+
#include <fcntl.h>
22+
#include <sys/mman.h>
23+
#include <sys/stat.h>
24+
#endif
25+
26+
// Best-effort removal of a named shared-memory object. Works from the name
// alone and does not need the previously saved address/fd.
// Returns 0 on success (a missing object also counts as success) or the
// errno value from shm_unlink.
static inline int sharedMemoryUnlinkByName(const char* name) {
#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
  // Windows has no shm_unlink semantics: a named mapping disappears once the
  // last handle to it is closed. Best effort here: open a handle and close it
  // immediately to drop one reference; an already-missing object is success.
  HANDLE mapping = OpenFileMappingA(FILE_MAP_ALL_ACCESS, FALSE, name);
  if (mapping != NULL) {
    CloseHandle(mapping);
  }
  return 0;
#else
  // POSIX: unlinking removes the name so it cannot be opened again; existing
  // mappings stay alive until they are munmap'ed.
  if (shm_unlink(name) == 0 || errno == ENOENT) {
    return 0;  // a non-existent object is treated as success
  }
  return errno;
#endif
}
47+
48+
// Tear down IPC sharing that was previously set up for `tmp_input`.
//
// tmp_input:  tensor whose data pointer may hold an imported IPC mapping.
// shm_name:   name of the shared-memory object that carried the IPC handle.
// close_ipc:  when true, close the CUDA IPC mapping for tmp_input's pointer
//             (only valid if that pointer came from cudaIpcOpenMemHandle).
// unlink_shm: when true, unlink the named shared-memory object. This removes
//             the name only; existing mappings are not torn down here.
// Throws via PD_THROW if unlinking the shared memory fails.
void UnsetDataIpc(const paddle::Tensor& tmp_input,
                  const std::string& shm_name,
                  bool close_ipc,
                  bool unlink_shm) {
  // Step 1: release the consumer-side CUDA IPC mapping.
  if (close_ipc) {
    checkCudaErrors(
        cudaIpcCloseMemHandle(const_cast<void*>(tmp_input.data())));
  }

  // Step 2: unlink the shared-memory name (best effort by name only).
  if (!unlink_shm) {
    return;
  }
  const int rc = sharedMemoryUnlinkByName(shm_name.c_str());
  if (rc != 0) {
    PD_THROW("Unlink shared memory failed: name=%s, err=%d",
             shm_name.c_str(), rc);
  }
}
67+
68+
// Register `unset_data_ipc` as a Paddle custom op backed by UnsetDataIpc;
// attrs mirror the UnsetDataIpc parameters (shm_name / close_ipc / unlink_shm).
PD_BUILD_STATIC_OP(unset_data_ipc)
69+
.Inputs({"tmp_input"})
70+
.Attrs({"shm_name: std::string", "close_ipc: bool", "unlink_shm: bool"})
71+
.SetKernelFn(PD_KERNEL(UnsetDataIpc));

custom_ops/setup_ops.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def find_end_files(directory, end_str):
208208
"gpu_ops/rebuild_padding.cu",
209209
"gpu_ops/step.cu",
210210
"gpu_ops/set_data_ipc.cu",
211+
"gpu_ops/unset_data_ipc.cu",
211212
"gpu_ops/moe/tritonmoe_preprocess.cu",
212213
"gpu_ops/step_system_cache.cu",
213214
"gpu_ops/get_output_ep.cc",
@@ -278,6 +279,7 @@ def find_end_files(directory, end_str):
278279
"gpu_ops/beam_search_softmax.cu",
279280
"gpu_ops/rebuild_padding.cu",
280281
"gpu_ops/set_data_ipc.cu",
282+
"gpu_ops/unset_data_ipc.cu",
281283
"gpu_ops/read_data_ipc.cu",
282284
"gpu_ops/enforce_generation.cu",
283285
"gpu_ops/dequant_int8.cu",

fastdeploy/cache_manager/cache_messager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def __init__(
9898
cache_v = []
9999
self.messager = {}
100100
for layer_idx in range(self.num_layers):
101-
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
102-
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}_device{gpu_id}"]
101+
key_cache = self.gpu_cache_kvs[f"key_caches_{layer_idx}_rank{self.rank}.device{gpu_id}"]
102+
val_cache = self.gpu_cache_kvs[f"value_caches_{layer_idx}_rank{self.rank}.device{gpu_id}"]
103103
cache_k.append(key_cache)
104104
cache_v.append(val_cache)
105105
cache_k_ptr_list.append(key_cache.data_ptr())

fastdeploy/cache_manager/cache_transfer_manager.py

Lines changed: 150 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,27 @@
1616

1717
import argparse
1818
import concurrent.futures
19+
import gc
1920
import json
2021
import queue
22+
import threading
2123
import time
2224
import traceback
2325

2426
import numpy as np
2527
import paddle
2628

29+
from fastdeploy import envs
2730
from fastdeploy.cache_manager.cache_data import CacheStatus
2831
from fastdeploy.config import SpeculativeConfig
29-
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal
32+
from fastdeploy.inter_communicator import EngineCacheQueue, IPCSignal, KVCacheStatus
3033
from fastdeploy.model_executor.ops.gpu import (
3134
cuda_host_alloc,
35+
cuda_host_free,
3236
set_data_ipc,
37+
share_external_data,
3338
swap_cache_all_layers,
39+
unset_data_ipc,
3440
)
3541
from fastdeploy.utils import get_logger
3642

@@ -93,6 +99,7 @@ def parse_args():
9399
help="speculative config",
94100
)
95101
parser.add_argument("--local_data_parallel_id", type=int, default=0)
102+
parser.add_argument("--create_cache_tensor", action="store_true")
96103

97104
args = parser.parse_args()
98105
return args
@@ -110,7 +117,6 @@ def __init__(self, args):
110117

111118
device = args.device_id
112119
rank = args.rank
113-
paddle.set_device(f"gpu:{device}")
114120
self.gpu_cache_kvs = {}
115121
self.cpu_cache_kvs = {}
116122
self.gpu_cache_k_tensors = []
@@ -126,6 +132,7 @@ def __init__(self, args):
126132
self.n_ranks = args.mp_num
127133
self.rank = rank
128134
self.device = device
135+
self.engine_pid = args.engine_pid
129136

130137
address = (args.pod_ip, args.cache_queue_port)
131138
self.cache_task_queue = EngineCacheQueue(
@@ -136,70 +143,27 @@ def __init__(self, args):
136143
local_data_parallel_id=args.local_data_parallel_id,
137144
)
138145

139-
self.num_cpu_blocks = args.num_cpu_blocks
140-
141-
cache_type = args.cache_dtype
142-
for i in range(args.num_layers + self.num_extra_layers):
143-
num_gpu_blocks = args.num_gpu_blocks if i < args.num_layers else self.num_extra_layer_gpu_blocks
144-
145-
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"] = paddle.full(
146-
shape=[
147-
num_gpu_blocks,
148-
args.kv_num_head,
149-
args.block_size,
150-
args.head_dim,
151-
],
152-
fill_value=0,
153-
dtype=cache_type,
154-
)
155-
self.gpu_cache_k_tensors.append(self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"])
156-
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"] = paddle.full(
157-
shape=[
158-
num_gpu_blocks,
159-
args.kv_num_head,
160-
args.block_size,
161-
args.head_dim,
162-
],
163-
fill_value=0,
164-
dtype=cache_type,
165-
)
166-
self.gpu_cache_v_tensors.append(self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"])
167-
168-
set_data_ipc(
169-
self.gpu_cache_kvs[f"key_caches_{i}_rank{rank}_device{device}"],
170-
f"key_caches_{i}_rank{rank}.device{device}",
171-
)
172-
set_data_ipc(
173-
self.gpu_cache_kvs[f"value_caches_{i}_rank{rank}_device{device}"],
174-
f"value_caches_{i}_rank{rank}.device{device}",
175-
)
176-
cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
177-
logger.info(f"device :{self.device}")
178-
logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
179-
logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
180-
181-
paddle.set_device("cpu")
182-
self.k_dst_ptrs = []
183-
self.v_dst_ptrs = []
184-
for i in range(args.num_layers + self.num_extra_layers):
185-
self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"] = cuda_host_alloc(
186-
args.num_cpu_blocks * args.bytes_per_layer_per_block
187-
)
188-
self.k_dst_ptrs.append(self.cpu_cache_kvs[f"key_caches_{i}_rank{rank}"])
189-
self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"] = cuda_host_alloc(
190-
args.num_cpu_blocks * args.bytes_per_layer_per_block
191-
)
192-
self.v_dst_ptrs.append(self.cpu_cache_kvs[f"value_caches_{i}_rank{rank}"])
193-
194146
cache_ready_signal_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
195147
self.cache_ready_signal = IPCSignal(
196148
name="cache_ready_signal",
197149
array=cache_ready_signal_data,
198150
dtype=np.int32,
199-
suffix=args.engine_pid,
151+
suffix=self.engine_pid,
152+
create=False,
153+
)
154+
swap_space_ready_data = np.zeros(shape=[args.mp_num], dtype=np.int32)
155+
self.swap_space_ready_signal = IPCSignal(
156+
name="swap_space_ready_signal",
157+
array=swap_space_ready_data,
158+
dtype=np.int32,
159+
suffix=self.engine_pid,
200160
create=False,
201161
)
202-
self.cache_ready_signal.value[self.rank] = 1
162+
163+
self.num_cpu_blocks = args.num_cpu_blocks
164+
165+
self._init_cpu_cache(args)
166+
self._init_gpu_cache(args)
203167

204168
paddle.set_device(f"gpu:{device}")
205169
if args.enable_splitwise:
@@ -232,6 +196,72 @@ def __init__(self, args):
232196
create=False,
233197
)
234198

199+
threading.Thread(target=self.clear_or_update_caches, args=[args], daemon=True).start()
200+
201+
def _init_gpu_cache(self, args):
    """Create or attach the per-layer GPU kv-cache tensors.

    When ``args.create_cache_tensor`` is set this process allocates the
    tensors and exports them over CUDA IPC (``set_data_ipc``), then flips
    its slot in ``cache_ready_signal``. Otherwise it waits for that flag
    and maps the already-exported tensors via ``share_external_data``.
    """
    creating = args.create_cache_tensor

    # Attach mode: block until the process that owns the allocation has
    # published the kv cache for our rank.
    if not creating:
        logger.info("Waiting for runners to create kv cache.")
        while self.cache_ready_signal.value[self.rank] != 1:
            time.sleep(1)
        logger.info("OK! Stop waiting.")

    logger.info("Initializing kv cache for all layers.")
    paddle.set_device(f"gpu:{self.device}")
    for layer_idx in range(args.num_layers + self.num_extra_layers):
        # Extra trailing layers may run with a different block budget.
        if layer_idx < args.num_layers:
            num_gpu_blocks = args.num_gpu_blocks
        else:
            num_gpu_blocks = self.num_extra_layer_gpu_blocks
        cache_shape = [num_gpu_blocks, args.kv_num_head, args.block_size, args.head_dim]
        key_name = f"key_caches_{layer_idx}_rank{self.rank}.device{self.device}"
        val_name = f"value_caches_{layer_idx}_rank{self.rank}.device{self.device}"

        if creating:
            logger.info(f"..creating kv cache for layer {layer_idx}: {cache_shape}")
            key_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
            val_cache = paddle.full(shape=cache_shape, fill_value=0, dtype=args.cache_dtype)
            set_data_ipc(key_cache, key_name)
            set_data_ipc(val_cache, val_name)
        else:
            logger.info(f"..attaching kv cache for layer {layer_idx}: {cache_shape}")
            key_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
            val_cache = paddle.empty(shape=[], dtype=args.cache_dtype)
            key_cache = share_external_data(key_cache, key_name, cache_shape)
            val_cache = share_external_data(val_cache, val_name, cache_shape)

        self.gpu_cache_kvs[key_name] = key_cache
        self.gpu_cache_kvs[val_name] = val_cache
        self.gpu_cache_k_tensors.append(key_cache)
        self.gpu_cache_v_tensors.append(val_cache)

    # Only the creator signals readiness, and only after every layer exists.
    if creating:
        logger.info("✅ kv cache is ready!")
        self.cache_ready_signal.value[self.rank] = 1

    cache_kv_size_byte = sum([tmp.numel() * 1 for key, tmp in self.gpu_cache_kvs.items()])
    logger.info(f"device :{self.device}")
    logger.info(f"cache_kv_size_byte : {cache_kv_size_byte}")
    logger.info(f"done init cache (full) gmem alloc : {paddle.device.cuda.memory_allocated()}")
243+
244+
def _init_cpu_cache(self, args):
    """Allocate pinned-host swap space (cpu cache) for every layer.

    Marks this rank as ready in ``swap_space_ready_signal`` when done; when
    no cpu blocks are requested the rank is marked ready immediately.
    """
    if args.num_cpu_blocks == 0:
        logger.info("💡 no swap space (cpu cache) is specified.")
        self.swap_space_ready_signal.value[self.rank] = 1
        return

    logger.info("Initializing swap space (cpu cache) for all layers.")
    paddle.set_device("cpu")
    self.k_dst_ptrs = []
    self.v_dst_ptrs = []
    # Same allocation size for every layer (key and value each).
    need_to_allocate_bytes = args.num_cpu_blocks * args.bytes_per_layer_per_block
    for layer_idx in range(args.num_layers + self.num_extra_layers):
        key_name = f"key_caches_{layer_idx}_rank{self.rank}"
        val_name = f"value_caches_{layer_idx}_rank{self.rank}"
        logger.info(f"..creating cpu cache for layer {layer_idx}: {2 * need_to_allocate_bytes / 1024 ** 3:.2f}GB")
        self.cpu_cache_kvs[key_name] = cuda_host_alloc(need_to_allocate_bytes)
        self.k_dst_ptrs.append(self.cpu_cache_kvs[key_name])
        self.cpu_cache_kvs[val_name] = cuda_host_alloc(need_to_allocate_bytes)
        self.v_dst_ptrs.append(self.cpu_cache_kvs[val_name])
    logger.info("✅ swap space (cpu cache) is ready!")
    self.swap_space_ready_signal.value[self.rank] = 1
264+
235265
def _do_swap_to_cpu_task(
236266
self,
237267
swap_node_ids,
@@ -429,6 +459,67 @@ def _transfer_data(
429459
transfer_task_id,
430460
)
431461

462+
def clear_or_update_caches(self, args):
    """Background loop that clears / restores the kv cache around weight updates.

    Polls the ``kv_cache_status`` IPC signal written by the engine:

    - ``CLEARING``: free the cpu swap space (when swap clearing is enabled),
      release the GPU kv-cache IPC handles, wait for all ranks to reset their
      ready flags, then acknowledge with ``CLEARED``.
    - ``UPDATING``: re-create the cpu/gpu caches, wait for all ranks to become
      ready again, then acknowledge with ``NORMAL``.

    Runs forever; intended to be started as a daemon thread.
    """
    logger.info("Start a thread to clear/restore kv cache when model weights are cleared/updated.")
    logger.info(f"FD_ENABLE_SWAP_SPACE_CLEARING={envs.FD_ENABLE_SWAP_SPACE_CLEARING}")
    kv_cache_status = np.zeros([1], dtype=np.int32)
    kv_cache_status_signal = IPCSignal(
        name="kv_cache_status",
        array=kv_cache_status,
        dtype=np.int32,
        suffix=self.engine_pid,
        create=False,
    )
    while True:
        if kv_cache_status_signal.value[0] == KVCacheStatus.CLEARING:
            try:
                if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
                    paddle.set_device("cpu")
                    for ptr in self.k_dst_ptrs + self.v_dst_ptrs:
                        cuda_host_free(ptr)
                    self.cpu_cache_kvs.clear()
                    self.k_dst_ptrs.clear()
                    self.v_dst_ptrs.clear()
                    gc.collect()
                    # Reset this rank's swap flag, then wait until every rank
                    # has reset theirs before touching the GPU cache.
                    self.swap_space_ready_signal.value[self.rank] = 0
                    while np.sum(self.swap_space_ready_signal.value) != 0:
                        time.sleep(0.1)

                paddle.set_device(f"gpu:{self.device}")
                # Close the CUDA IPC handles for every kv tensor; keep the shm
                # names (unlink_shm=False) so UPDATING can re-export them.
                for name, tensor in self.gpu_cache_kvs.items():
                    unset_data_ipc(tensor, name, True, False)
                self.gpu_cache_kvs.clear()
                self.gpu_cache_k_tensors.clear()
                self.gpu_cache_v_tensors.clear()
                # Reset this rank's cache flag and wait for all ranks, mirroring
                # the swap-space handshake above. BUG FIX: this was
                # `if np.sum(...) == 0: time.sleep(0.1)`, which never actually
                # waited for the other ranks before signalling CLEARED.
                self.cache_ready_signal.value[self.rank] = 0
                while np.sum(self.cache_ready_signal.value) != 0:
                    time.sleep(0.1)

                kv_cache_status_signal.value[0] = KVCacheStatus.CLEARED

            except Exception as e:
                logger.error(f"Failed to clear caches: {e}")

        elif kv_cache_status_signal.value[0] == KVCacheStatus.UPDATING:
            try:
                if envs.FD_ENABLE_SWAP_SPACE_CLEARING:
                    self._init_cpu_cache(args)
                    # Wait until every rank has re-allocated its swap space.
                    while np.sum(self.swap_space_ready_signal.value) != args.mp_num:
                        time.sleep(0.1)

                self._init_gpu_cache(args)
                # Wait until every rank has re-created its GPU kv cache.
                while np.sum(self.cache_ready_signal.value) != args.mp_num:
                    time.sleep(0.1)

                kv_cache_status_signal.value[0] = KVCacheStatus.NORMAL

            except Exception as e:
                logger.error(f"Failed to restore caches: {e}")

        time.sleep(0.1)
522+
432523

433524
def main():
434525
"""

0 commit comments

Comments
 (0)