Commit 077ec83

[RL] Adapt async rollout checkpoint update flow (#7042) (#7084)

* update checkpoint-transfer flow and control update_weights params
* test: add update_weights route validation

(cherry picked from commit 05f2d95)

1 parent 971fc7c

9 files changed: 58 additions & 88 deletions
docs/features/weight_update.md (4 additions & 8 deletions)

@@ -50,7 +50,7 @@ In FastDeploy >= 2.6, the underlying control-signal communication path is optimized
 | `/v1/is_paused` | `GET` | none | Return `{"is_paused": bool}`. |
 | `/v1/sleep` | `POST` | `?tags=weight,kv_cache` | Offload selected GPU memory objects. Supported tags are `weight` and `kv_cache`. If omitted, both are used. |
 | `/v1/wakeup` | `POST` | `?tags=weight,kv_cache` | Reload previously offloaded weights and/or KV cache. On success, the engine resumes automatically. |
-| `/v1/update_weights` | `POST` | JSON `{"version":"...", "rsync_config": {...}}` | Refresh weights in place through the worker control path. This API is intended for remote versioned updates, especially `load_strategy=rsync`. |
+| `/v1/update_weights` | `POST` | JSON `{"version":"...", "verify_checksum": false}` | Refresh weights in place through the worker control path. This API is intended for remote versioned updates, especially `load_strategy=rsync`. |

 ### Compatibility Notes

@@ -114,7 +114,7 @@ After `wakeup` succeeds, FastDeploy automatically calls `resume`.
 Current request fields:

 - `version`: optional string. Used to choose a target checkpoint version.
-- `rsync_config`: optional dictionary. Must contain `etcd_server` when provided.
+- `verify_checksum`: optional boolean. Defaults to `false`. Set to `true` to verify data integrity during weight synchronization.

 Important semantics:

@@ -186,9 +186,7 @@ curl -X POST http://127.0.0.1:8000/v1/update_weights \
   -H "Content-Type: application/json" \
   -d '{
     "version": "global_step_1200",
-    "rsync_config": {
-      "etcd_server": "127.0.0.1:2379"
-    }
+    "verify_checksum": false
   }'
 ```

@@ -261,9 +259,7 @@ curl -X POST http://127.0.0.1:8000/v1/update_weights \
   -H "Content-Type: application/json" \
   -d '{
     "version": "global_step_1200",
-    "rsync_config": {
-      "etcd_server": "127.0.0.1:2379"
-    }
+    "verify_checksum": false
   }'

 # Resume the service after the update completes
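The curl calls above can also be issued programmatically. The helper below is a minimal sketch (not part of FastDeploy; the function name is illustrative) that builds the documented JSON body and enforces the same field types the route checks:

```python
import json


def build_update_weights_payload(version=None, verify_checksum=False):
    """Build the JSON body for POST /v1/update_weights.

    Mirrors the documented fields: an optional string `version` and an
    optional boolean `verify_checksum` (defaults to false).
    """
    if version is not None and not isinstance(version, str):
        raise TypeError("version must be a string")
    if not isinstance(verify_checksum, bool):
        raise TypeError("verify_checksum must be a boolean")
    payload = {"verify_checksum": verify_checksum}
    if version:
        payload["version"] = version
    return json.dumps(payload)
```

Checking the types client-side avoids a round trip that would end in an HTTP 400 from the server's own validation.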

docs/zh/features/weight_update.md (4 additions & 8 deletions) — the Chinese counterpart of the documentation change above, translated:

@@ -50,7 +50,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
 | `/v1/is_paused` | `GET` | none | Returns `{"is_paused": bool}`. |
 | `/v1/sleep` | `POST` | `?tags=weight,kv_cache` | Offload the specified GPU memory objects. Supports `weight` and `kv_cache`; if omitted, both are handled by default. |
 | `/v1/wakeup` | `POST` | `?tags=weight,kv_cache` | Reload previously offloaded weights and/or KV cache. On success, `resume` is called automatically. |
-| `/v1/update_weights` | `POST` | JSON `{"version":"...", "rsync_config": {...}}` | Refresh model weights in place through the worker control path. This API mainly targets remote versioned updates with `load_strategy=rsync`. |
+| `/v1/update_weights` | `POST` | JSON `{"version":"...", "verify_checksum": false}` | Refresh model weights in place through the worker control path. This API mainly targets remote versioned updates with `load_strategy=rsync`. |

 ### Compatibility Notes

@@ -113,7 +113,7 @@ python -m fastdeploy.entrypoints.openai.api_server \
 Currently supported request fields:

 - `version`: optional string, used to specify the target checkpoint version.
-- `rsync_config`: optional dictionary; when provided, it must contain `etcd_server`.
+- `verify_checksum`: optional boolean, defaults to `false`. When set to `true`, data integrity is verified during weight synchronization.

 Key semantics:

@@ -185,9 +185,7 @@ curl -X POST http://127.0.0.1:8000/v1/update_weights \
   -H "Content-Type: application/json" \
   -d '{
     "version": "global_step_1200",
-    "rsync_config": {
-      "etcd_server": "127.0.0.1:2379"
-    }
+    "verify_checksum": false
   }'
 ```

@@ -260,9 +258,7 @@ curl -X POST http://127.0.0.1:8000/v1/update_weights \
   -H "Content-Type: application/json" \
   -d '{
     "version": "global_step_1200",
-    "rsync_config": {
-      "etcd_server": "127.0.0.1:2379"
-    }
+    "verify_checksum": false
   }'

 # Resume the service after the update completes

fastdeploy/entrypoints/openai/api_server.py (5 additions & 10 deletions)

@@ -459,19 +459,14 @@ async def update_weights(request: Request) -> Response:
             )
         args["version"] = request_data["version"]

-    # Validate and extract rsync_config parameter
-    if "rsync_config" in request_data and request_data["rsync_config"] is not None:
-        if not isinstance(request_data["rsync_config"], dict):
+    # Validate and extract verify_checksum parameter
+    if "verify_checksum" in request_data and request_data["verify_checksum"] is not None:
+        if not isinstance(request_data["verify_checksum"], bool):
             return JSONResponse(
                 status_code=400,
-                content={"error": "Invalid parameter type", "message": "rsync_config must be a dictionary"},
+                content={"error": "Invalid parameter type", "message": "verify_checksum must be a boolean"},
             )
-        if "etcd_server" not in request_data["rsync_config"]:
-            return JSONResponse(
-                status_code=400,
-                content={"error": "Invalid parameter type", "message": "rsync_config must contain etcd_server"},
-            )
-        args["rsync_config"] = request_data["rsync_config"]
+        args["verify_checksum"] = request_data["verify_checksum"]

     control_request = ControlRequest(request_id, "update_weights", args)
     control_response = await app.state.engine_client.run_control_method(control_request)
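Stripped of the FastAPI machinery, the validation branch above amounts to the following standalone sketch (the function name and tuple return are illustrative, not the actual route code):

```python
def extract_update_weights_args(request_data):
    """Collect update_weights args the way the route does: version, when
    present, must be a string; verify_checksum, when present, must be a
    boolean. Returns (args, None) on success or (None, error_message)
    where the real route would answer HTTP 400."""
    args = {}
    version = request_data.get("version")
    if version is not None:
        if not isinstance(version, str):
            return None, "version must be a string"
        args["version"] = version
    verify_checksum = request_data.get("verify_checksum")
    if verify_checksum is not None:
        # The bool check must be exact: the JSON string "true" is invalid.
        if not isinstance(verify_checksum, bool):
            return None, "verify_checksum must be a boolean"
        args["verify_checksum"] = verify_checksum
    return args, None
```

Note that replacing the old `rsync_config` dict check with a single `isinstance(..., bool)` check removes the nested `etcd_server` validation entirely; connection details now come from the server-side `rsync_config` load configuration rather than the request body.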

fastdeploy/rl/dynamic_weight_manager.py (25 additions & 42 deletions)

@@ -16,7 +16,6 @@
 import gc
 import glob
-import io
 import os
 import re
 import time

@@ -31,30 +30,6 @@
 from fastdeploy.inter_communicator import KVCacheStatus, ModelWeightsStatus


-def sync_weights_by_rdma(config, step, rank):
-    from checkpoint_transfer.core import RDMAWeightsDownloader
-
-    downloader = RDMAWeightsDownloader(config)
-    downloader.initialize()
-    logger.info(f"Fetching weights for step:{step}, rank:{rank}...")
-    data = downloader.get_weights(step, rank)
-    if data is None:
-        logger.error("Failed to get weights!")
-        raise Exception("Failed to rsync weights through checkpoint_transfer")
-    logger.info(f"Successfully retrieved data. Type: {type(data)}")
-    if isinstance(data, np.ndarray):
-        data_bytes = data.tobytes()
-    elif isinstance(data, (bytes, bytearray)):
-        data_bytes = data
-    else:
-        data_bytes = bytes(data)
-    logger.info(f"Data size: {len(data_bytes)} bytes")
-
-    buffer = io.BytesIO(data_bytes)
-    new_state_dict = paddle.load(buffer)
-    return new_state_dict
-
-
 class DynamicWeightManager:
     """Manages model weights loading, updating and shared state across processes."""

@@ -75,6 +50,7 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int):
         else:
             self.model_list = models
         self._capture_model_state()
+        self.rdma_handle = None
         if self.load_config.load_strategy == "rsync":
             self.update_weights_by_rdma()
         else:

@@ -91,10 +67,12 @@ def _capture_model_state(self):
         """Capture and store initial model parameters state."""
         for model in self.model_list:
             for name, param in model.state_dict().items():
-                logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}")
+                if hasattr(param, "_is_initialized") and not param._is_initialized():
+                    param.initialize()
+                logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}, place={param.place}")
                 self.state_dict[name] = param

-    def update_weights_by_rdma(self, version: str = None, rsync_config: Dict[str, Any] = None):
+    def update_weights_by_rdma(self, version: str = None, verify_checksum: bool = False):
         def valid_parameters(old_state_dict, new_state_dict):
             is_valid = True
             for key in old_state_dict:

@@ -110,17 +88,11 @@ def valid_parameters(old_state_dict, new_state_dict):
                     )
                 elif old_state_dict[key].dtype != new_state_dict[key].dtype:
                     is_valid = False
-                    logger.error(f"Invalid parameter: {key} dtype mismatch")
+                    logger.error(
+                        f"Invalid parameter: {key} dtype mismatch, old:{old_state_dict[key].dtype}, new:{new_state_dict[key].dtype}"
+                    )
             return is_valid

-        if rsync_config is None:
-            rsync_config = self.fd_config.load_config.rsync_config
-        if rsync_config is None or len(rsync_config) == 0:
-            raise Exception(
-                "rsync config not set, please set it in 1) launch arguments '--rsync-config' "
-                "or 2) interface arguments 'rsync_config'"
-            )
-
         if version is None or version == "":
             version = self.read_model_version_from_file()
         if version is None or version == "":

@@ -129,11 +101,23 @@ def valid_parameters(old_state_dict, new_state_dict):
                 "or 2) interface arguments 'version'"
             )

-        logger.info(f"START update_weights_by_rdma, version:{version}, rsync_config:{rsync_config}")
-        rank = self.local_rank
+        logger.info(
+            f"START rank:{self.local_rank}/{self.nranks} update_weights_by_rdma, "
+            f"version:{version}, verify_checksum:{verify_checksum}"
+        )
+
+        if self.rdma_handle is None:
+            from checkpoint_transfer import CheckpointTransfer
+
+            config = self.fd_config.load_config.rsync_config
+            logger.info(f"CheckpointTransfer rsync config:{config}")
+            self.rdma_handle = CheckpointTransfer(**config, local_rank=self.local_rank, group_size=self.nranks)
+            self.rdma_handle.initialize()

         sync_start = time.perf_counter()
-        new_state_dict = sync_weights_by_rdma(rsync_config, version, rank)
+        new_state_dict = dict()
+        for key, param in self.rdma_handle.receive_stream(step_id=version, verify_checksum=verify_checksum):
+            new_state_dict[key] = param
         sync_cost = time.perf_counter() - sync_start
         logger.info(f"weights sync cost {sync_cost:.2f} seconds")

@@ -148,18 +132,17 @@ def valid_parameters(old_state_dict, new_state_dict):
             param.set_value(new_state_dict[name])
         update_cost = time.perf_counter() - update_start
         logger.info(f"params set value cost {update_cost:.2f} seconds")
-
         total_cost = time.perf_counter() - sync_start
         logger.info(
             f"END update_weights_by_rdma, cost {total_cost:.2f} seconds"
-            f" version:{version}, rsync_config: {rsync_config}",
+            f" version:{version}, verify_checksum: {verify_checksum}, local_rank: {self.local_rank}",
         )
         return {
             "sync_cost": sync_cost,
             "update_cost": update_cost,
             "total_cost": total_cost,
             "version": version,
-            "rank": rank,
+            "rank": self.local_rank,
         }

     def update_parameters(self, pid: int = 0, restart_process_group=False) -> None:
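The shape/dtype guard that `valid_parameters` runs before any `set_value` can be reproduced standalone with NumPy arrays (a sketch under that assumption; the real code compares Paddle tensors and writes to a logger):

```python
import numpy as np


def valid_parameters(old_state_dict, new_state_dict):
    """Return True only if every old parameter has a counterpart in the
    freshly synced weights with an identical shape and dtype."""
    is_valid = True
    for key, old in old_state_dict.items():
        new = new_state_dict.get(key)
        if new is None:
            is_valid = False
            print(f"Invalid parameter: {key} missing from new weights")
        elif tuple(old.shape) != tuple(new.shape):
            is_valid = False
            print(f"Invalid parameter: {key} shape mismatch, old:{old.shape}, new:{new.shape}")
        elif old.dtype != new.dtype:
            is_valid = False
            print(f"Invalid parameter: {key} dtype mismatch, old:{old.dtype}, new:{new.dtype}")
    return is_valid
```

Logging both the old and new dtype on mismatch, as this commit adds, turns a bare "dtype mismatch" into an actionable message when a checkpoint was exported with the wrong precision.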

fastdeploy/worker/gpu_model_runner.py (3 additions & 3 deletions)

@@ -20,7 +20,7 @@
 import time
 from concurrent.futures import Future
 from threading import Thread
-from typing import Any, Dict, List, Optional, cast
+from typing import List, Optional, cast

 import numpy as np
 import paddle

@@ -2866,8 +2866,8 @@ def update_parameters(self, pid):
         self.dynamic_weight_manager.finalize_update(pid)
         self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")

-    def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
-        return self.dynamic_weight_manager.update_weights_by_rdma(version, rsync_config)
+    def update_weights(self, version: str = None, verify_checksum: bool = False):
+        return self.dynamic_weight_manager.update_weights_by_rdma(version, verify_checksum)

     def sleep(self, tags):

fastdeploy/worker/gpu_worker.py (3 additions & 3 deletions)

@@ -16,7 +16,7 @@
 import gc
 import time
-from typing import Any, Dict, List, Optional
+from typing import List, Optional

 import paddle
 import pynvml

@@ -192,9 +192,9 @@ def initialize_cache(self, num_gpu_blocks: int) -> None:
         if self.fd_config.routing_replay_config.enable_routing_replay:
             self.model_runner.initialize_routing_replay_manager()

-    def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
+    def update_weights(self, version: str = None, verify_checksum: bool = False):
         """update weights in place"""
-        return self.model_runner.update_weights(version, rsync_config)
+        return self.model_runner.update_weights(version, verify_checksum)

     def sleep(self, **kwargs) -> None:
         """Offload memory from GPU"""

fastdeploy/worker/metax_model_runner.py (3 additions & 3 deletions)

@@ -20,7 +20,7 @@
 import time
 from concurrent.futures import Future
 from threading import Thread
-from typing import Any, Dict, List, Optional, cast
+from typing import List, Optional, cast

 import numpy as np
 import paddle

@@ -2769,8 +2769,8 @@ def update_parameters(self, pid):

         self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")

-    def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
-        return self.dynamic_weight_manager.update_weights_by_rdma(version, rsync_config)
+    def update_weights(self, version: str = None, verify_checksum: bool = False):
+        return self.dynamic_weight_manager.update_weights_by_rdma(version, verify_checksum)

     def padding_cudagraph_inputs(self) -> None:
         """

fastdeploy/worker/metax_worker.py (3 additions & 3 deletions)

@@ -17,7 +17,7 @@
 import gc
 import os
 import time
-from typing import Any, Dict, List, Optional
+from typing import List, Optional

 import paddle
 from paddle import nn

@@ -191,9 +191,9 @@ def initialize_cache(self, num_gpu_blocks: int) -> None:
         # accurate cache size
         self.model_runner.update_share_input_block_num(num_gpu_blocks=num_gpu_blocks)

-    def update_weights(self, version: str = None, rsync_config: Dict[str, Any] = None):
+    def update_weights(self, version: str = None, verify_checksum: bool = False):
         """update weights in place"""
-        return self.model_runner.update_weights(version, rsync_config)
+        return self.model_runner.update_weights(version, verify_checksum)

     def execute_model(
         self,

tests/entrypoints/openai/test_api_server.py (8 additions & 8 deletions)

@@ -604,25 +604,25 @@ async def test_update_weights_route_validation():
     api_server.app.state.engine_client.run_control_method = AsyncMock(return_value=mock_control_response)

     valid_req = MagicMock()
-    valid_req.body = AsyncMock(return_value=b'{"version":"v2","rsync_config":{"etcd_server":"127.0.0.1"}}')
-    valid_req.json = AsyncMock(return_value={"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}})
+    valid_req.body = AsyncMock(return_value=b'{"version":"v2","verify_checksum":true}')
+    valid_req.json = AsyncMock(return_value={"version": "v2", "verify_checksum": True})
     valid_resp = await api_server.update_weights(valid_req)
     assert valid_resp.status_code == 200
     control_request = api_server.app.state.engine_client.run_control_method.await_args.args[0]
     assert control_request.method == "update_weights"
-    assert control_request.args == {"version": "v2", "rsync_config": {"etcd_server": "127.0.0.1"}}
+    assert control_request.args == {"version": "v2", "verify_checksum": True}

     invalid_version_req = MagicMock()
     invalid_version_req.body = AsyncMock(return_value=b'{"version":1}')
     invalid_version_req.json = AsyncMock(return_value={"version": 1})
     invalid_version_resp = await api_server.update_weights(invalid_version_req)
     assert invalid_version_resp.status_code == 400

-    invalid_rsync_req = MagicMock()
-    invalid_rsync_req.body = AsyncMock(return_value=b'{"rsync_config":{"user":"u"}}')
-    invalid_rsync_req.json = AsyncMock(return_value={"rsync_config": {"user": "u"}})
-    invalid_rsync_resp = await api_server.update_weights(invalid_rsync_req)
-    assert invalid_rsync_resp.status_code == 400
+    invalid_checksum_req = MagicMock()
+    invalid_checksum_req.body = AsyncMock(return_value=b'{"verify_checksum":"true"}')
+    invalid_checksum_req.json = AsyncMock(return_value={"verify_checksum": "true"})
+    invalid_checksum_resp = await api_server.update_weights(invalid_checksum_req)
+    assert invalid_checksum_resp.status_code == 400


 @pytest.mark.asyncio
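The new negative case hinges on a JSON detail worth spelling out: the literal `true` parses to Python's `True`, while the quoted `"true"` parses to a string and must be rejected by the route's `isinstance(..., bool)` check:

```python
import json

# JSON `true` becomes the Python bool True...
parsed = json.loads('{"verify_checksum": true}')

# ...but the quoted string "true" stays a str, which isinstance(..., bool)
# correctly rejects -- exactly what the new 400-status test exercises.
quoted = json.loads('{"verify_checksum": "true"}')

print(type(parsed["verify_checksum"]).__name__)  # bool
print(type(quoted["verify_checksum"]).__name__)  # str
```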
