
Commit 01f6934
[Executor] Adjust signal sending order in RL training (#3773) (#4066) (#4178)
* Adjust processing order
* fix bug
* fix update_parameters bug
* refine code
1 parent: 7bdc6f4

3 files changed: 20 additions & 22 deletions


fastdeploy/model_executor/graph_optimization/cudagraph_piecewise_backend.py

Lines changed: 8 additions & 11 deletions
@@ -18,6 +18,7 @@
 from dataclasses import dataclass
 from typing import Callable, Dict, Optional

+import paddle.jit.dy2static.utils as jit_utils
 import paddle.nn.layer
 from paddle.device.cuda import graphs
@@ -51,27 +52,24 @@ class ConcreteSizeEntry:

 class Dy2StCudaGraphManager:
     def __init__(self):
-        # NOTE(gongshaotian): Use local import to avoid RLHF version problems
-        from paddle.jit.dy2static.utils import CUDAGraphState

-        self.state = CUDAGraphState.DISABLE
+        self.state = jit_utils.CUDAGraphState.DISABLE
         self.captured_batch_size = set()
         self.batch_size = -1

     def run_impl(self, original_run_impl, inputs, parameters, attrs):
-        from paddle.jit.dy2static.utils import CUDAGraphState

         run_state = self.state
         prog_attrs, cuda_graph_attrs = attrs
-        if run_state == CUDAGraphState.REPLAY:
+        if run_state == jit_utils.CUDAGraphState.REPLAY:
             if self.batch_size not in self.captured_batch_size:
-                run_state = CUDAGraphState.DISABLE
-        elif run_state == CUDAGraphState.CAPTURE:
+                run_state = jit_utils.CUDAGraphState.DISABLE
+        elif run_state == jit_utils.CUDAGraphState.CAPTURE:
             self.captured_batch_size.add(self.batch_size)

         cuda_graph_attrs |= {
             "cuda_graph_state": run_state,
-            "cuda_graph_dispatch_key": self.batch_size if run_state != CUDAGraphState.DISABLE else 0,
+            "cuda_graph_dispatch_key": self.batch_size if run_state != jit_utils.CUDAGraphState.DISABLE else 0,
         }
         return original_run_impl(inputs, parameters, (prog_attrs, cuda_graph_attrs))

@@ -104,7 +102,6 @@ def __init__(
         self.cuda_graph_manager = Dy2StCudaGraphManager()

     def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
-        from paddle.jit.dy2static.utils import CUDAGraphState

         if not entry.captured:
             # Warmup the model
@@ -121,14 +118,14 @@ def run_static_model(self, entry: ConcreteSizeEntry, **kwargs):
             entry.input_addresses = input_addresses

             # Capture
-            self.cuda_graph_manager.state = CUDAGraphState.CAPTURE
+            self.cuda_graph_manager.state = jit_utils.CUDAGraphState.CAPTURE
             self.cuda_graph_manager.batch_size = entry.real_shape
             entry.captured = True
             with self.cuda_graph_manager.run_impl_guard():
                 entry.runnable(**kwargs)

         # Replay
-        self.cuda_graph_manager.state = CUDAGraphState.REPLAY
+        self.cuda_graph_manager.state = jit_utils.CUDAGraphState.REPLAY
         self.cuda_graph_manager.batch_size = entry.real_shape
         with self.cuda_graph_manager.run_impl_guard():
             return entry.runnable(**kwargs)
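For readers skimming this file's diff: the only behavioral logic here is the capture/replay dispatch in run_impl, which now references the enum through the module-level jit_utils alias instead of re-importing it inside every method. A minimal standalone sketch of that state machine (the enum below is a stand-in for paddle.jit.dy2static.utils.CUDAGraphState, not the real class):

from enum import Enum

class CUDAGraphState(Enum):  # stand-in for paddle.jit.dy2static.utils.CUDAGraphState
    DISABLE = 0
    CAPTURE = 1
    REPLAY = 2

class GraphManagerSketch:
    def __init__(self):
        self.state = CUDAGraphState.DISABLE
        self.captured_batch_size = set()
        self.batch_size = -1

    def resolve_run_state(self):
        run_state = self.state
        if run_state == CUDAGraphState.REPLAY:
            # A graph can only be replayed for a batch size that was captured;
            # otherwise fall back to an ordinary (DISABLE) run.
            if self.batch_size not in self.captured_batch_size:
                run_state = CUDAGraphState.DISABLE
        elif run_state == CUDAGraphState.CAPTURE:
            self.captured_batch_size.add(self.batch_size)
        return run_state

mgr = GraphManagerSketch()
mgr.state, mgr.batch_size = CUDAGraphState.REPLAY, 8
assert mgr.resolve_run_state() is CUDAGraphState.DISABLE  # size 8 was never captured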

fastdeploy/rl/dynamic_weight_manager.py

Lines changed: 2 additions & 3 deletions
@@ -45,6 +45,7 @@ def __init__(self, fd_config: FDConfig, model: nn.Layer):
         self.model: nn.Layer = model
         self._capture_model_state()
         self.update_parameters()
+        self.finalize_update()

         logger.info(
             f"✅ DynamicLoad model built successfully by {self.load_config.load_strategy}, "
@@ -81,8 +82,6 @@ def update_parameters(self, pid: int = 0) -> None:

         logger.info(f"Update parameters in {time.perf_counter()-start_time:.2f}s")

-        self._finalize_update(pid)
-
     def _update_ipc_snapshot(self):
         """Update using IPC snapshot strategy for elastic recovery."""
         model_path = os.path.join(
@@ -146,7 +145,7 @@ def _validate_parameter_match(self, name: str, src: paddle.Tensor, dst: paddle.T
         if src.shape != dst.shape:
             raise ValueError(f"Shape mismatch for {name}: {src.shape} vs {dst.shape}")

-    def _finalize_update(self, pid: int):
+    def finalize_update(self, pid: int = 0):
         """Finalize update process with verification."""
         self._verify_parameters("update")
         if self.parallel_config.tensor_parallel_size > 1:
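The net effect in this file: update_parameters() no longer signals completion itself, and the private _finalize_update was promoted to a public finalize_update(pid: int = 0) so the caller can run post-update work before signaling. A hedged sketch of the new contract (stub class, not FastDeploy's actual implementation):

class DynamicWeightManagerSketch:
    """Illustrative stub mirroring the decoupled update/finalize contract."""

    def update_parameters(self, pid: int = 0) -> None:
        # Loads the new weights; deliberately does NOT send the readiness signal.
        print(f"[pid={pid}] parameters updated")

    def finalize_update(self, pid: int = 0) -> None:
        # Public entry point: invoked by the owner once KV-cache setup and
        # CUDAGraph recapture have finished, so the signal arrives last.
        print(f"[pid={pid}] update finalized, signal sent")

mgr = DynamicWeightManagerSketch()
mgr.update_parameters()
# ... caller re-initializes the KV cache and recaptures CUDAGraph here ...
mgr.finalize_update()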

fastdeploy/worker/gpu_model_runner.py

Lines changed: 10 additions & 8 deletions
@@ -1705,25 +1705,27 @@ def clear_cache(self):
         paddle.device.cuda.empty_cache()

     def clear_parameters(self, pid):
-        """ " Dynamic model loader use to clear parameters use for RL"""
+        """Dynamic model loader used to clear parameters for RL"""
+        # Clear CUDAGraph
+        if self.use_cudagraph:
+            self.model.clear_grpah_opt_backend()
+        # Clear parameters and send signal
         self.dynamic_weight_manager.clear_parameters(pid)
         self.clear_cache()
         paddle.device.cuda.empty_cache()

-        # Clear CudaGraph
-        if self.use_cudagraph:
-            self.model.clear_grpah_opt_backend()
-
         self.dynamic_weight_manager._log_memory("dynamic weight manager clear all memory")

     def update_parameters(self, pid):
-        """ " Dynamic model loader use to update parameters use for RL"""
+        """Dynamic model loader used to update parameters for RL"""
+        # Update parameters
         self.dynamic_weight_manager.update_parameters(pid)
         self.initialize_kv_cache()
-
-        # Recapture CudaGraph
+        # Recapture CUDAGraph
         if self.use_cudagraph:
             self.capture_model()
+        # Send signal
+        self.dynamic_weight_manager.finalize_update(pid)

         self.dynamic_weight_manager._log_memory("dynamic weight manager update all memory")
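Taken together, the runner now tears the CUDAGraph backend down before dropping weights, and signals readiness only after the graph has been recaptured on the new weights. A minimal sketch of the reordered flow (method bodies are hypothetical stand-ins for the GPUModelRunner calls shown above):

class RunnerFlowSketch:
    use_cudagraph = True

    def clear_parameters(self, pid: int) -> None:
        if self.use_cudagraph:
            self.clear_graph_backend()  # drop captured graphs first
        self.clear_weights(pid)         # then free weights and caches

    def update_parameters(self, pid: int) -> None:
        self.load_weights(pid)          # dynamic_weight_manager.update_parameters(pid)
        self.init_kv_cache()            # initialize_kv_cache()
        if self.use_cudagraph:
            self.capture_model()        # recapture graphs on the new weights
        self.finalize(pid)              # send the signal LAST, after recapture

    # Hypothetical stubs standing in for the real FastDeploy calls.
    def clear_graph_backend(self): print("cudagraph backend cleared")
    def clear_weights(self, pid): print("weights cleared, signal sent")
    def load_weights(self, pid): print("weights updated")
    def init_kv_cache(self): print("kv cache initialized")
    def capture_model(self): print("cudagraph recaptured")
    def finalize(self, pid): print("finalize_update -> signal sent")

runner = RunnerFlowSketch()
runner.clear_parameters(pid=0)
runner.update_parameters(pid=0)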
17291731
