Skip to content

Commit bccf388

Browse files
authored
fix oom bug, optimize async weight loading and update read step by yaml (#7171) (#7175)
1 parent 1f29389 commit bccf388

1 file changed

Lines changed: 28 additions & 12 deletions

File tree

fastdeploy/rl/dynamic_weight_manager.py

Lines changed: 28 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424

2525
import numpy as np
2626
import paddle
27+
import yaml
2728
from paddleformers.utils.log import logger
2829

2930
from fastdeploy.config import FDConfig
@@ -67,8 +68,6 @@ def _capture_model_state(self):
6768
"""Capture and store initial model parameters state."""
6869
for model in self.model_list:
6970
for name, param in model.state_dict().items():
70-
if hasattr(param, "_is_initialized") and not param._is_initialized():
71-
param.initialize()
7271
logger.info(f"Model param: {name}, shape={param.shape}, dtype={param.dtype}, place={param.place}")
7372
self.state_dict[name] = param
7473

@@ -93,17 +92,18 @@ def valid_parameters(old_state_dict, new_state_dict):
9392
)
9493
return is_valid
9594

96-
if version is None or version == "":
95+
bootstrap_load = version is None or version == ""
96+
if bootstrap_load:
9797
version = self.read_model_version_from_file()
9898
if version is None or version == "":
9999
raise Exception(
100-
"rsync model version not set, please set it in 1) {model_version}/version.txt "
100+
"rsync model version not set, please set it in 1) {model_version}/version.yaml "
101101
"or 2) interface arguments 'version'"
102102
)
103103

104104
logger.info(
105105
f"START rank:{self.local_rank}/{self.nranks} update_weights_by_rdma, "
106-
f"version:{version}, verify_checksum:{verify_checksum}"
106+
f"version:{version}, verify_checksum:{verify_checksum}, bootstrap_load:{bootstrap_load}"
107107
)
108108

109109
if self.rdma_handle is None:
@@ -128,8 +128,14 @@ def valid_parameters(old_state_dict, new_state_dict):
128128
raise ValueError(error_msg)
129129

130130
update_start = time.perf_counter()
131-
for name, param in old_state_dict.items():
132-
param.set_value(new_state_dict[name])
131+
for name, target_param in old_state_dict.items():
132+
new_param = new_state_dict[name]
133+
if bootstrap_load and not target_param._is_initialized():
134+
new_param = new_param.cuda()
135+
new_param._share_buffer_to(target_param)
136+
else:
137+
target_param.set_value(new_param)
138+
133139
update_cost = time.perf_counter() - update_start
134140
logger.info(f"params set value cost {update_cost:.2f} seconds")
135141
total_cost = time.perf_counter() - sync_start
@@ -476,13 +482,23 @@ def _update_shared_status(self, pid: int, status: int) -> None:
476482

477483
def read_model_version_from_file(self):
478484
model_dir = self.fd_config.model_config.model
479-
version_file = os.path.join(model_dir, "version.txt")
485+
version_file = os.path.join(model_dir, "version.yaml")
480486
try:
481487
with open(version_file, "r", encoding="utf-8") as f:
482-
version = f.read().strip()
483-
return version
484-
except (FileNotFoundError, OSError, IOError) as e:
485-
logger.error(f"Failed to read model version file '{version_file}': {e}")
488+
version_info = yaml.safe_load(f) or {}
489+
490+
if not isinstance(version_info, dict):
491+
logger.error(f"Failed to read model step from '{version_file}': yaml content is not a mapping")
492+
return None
493+
494+
step = version_info.get("step")
495+
if step is None:
496+
logger.error(f"Failed to read model step from '{version_file}': missing 'step' field")
497+
return None
498+
499+
return str(step)
500+
except (FileNotFoundError, OSError, IOError, yaml.YAMLError) as e:
501+
logger.error(f"Failed to read model step from '{version_file}': {e}")
486502
return None
487503

488504
@staticmethod

0 commit comments

Comments
 (0)