2424
2525import numpy as np
2626import paddle
27+ import yaml
2728from paddleformers .utils .log import logger
2829
2930from fastdeploy .config import FDConfig
@@ -67,8 +68,6 @@ def _capture_model_state(self):
6768 """Capture and store initial model parameters state."""
6869 for model in self .model_list :
6970 for name , param in model .state_dict ().items ():
70- if hasattr (param , "_is_initialized" ) and not param ._is_initialized ():
71- param .initialize ()
7271 logger .info (f"Model param: { name } , shape={ param .shape } , dtype={ param .dtype } , place={ param .place } " )
7372 self .state_dict [name ] = param
7473
@@ -93,17 +92,18 @@ def valid_parameters(old_state_dict, new_state_dict):
9392 )
9493 return is_valid
9594
96- if version is None or version == "" :
95+ bootstrap_load = version is None or version == ""
96+ if bootstrap_load :
9797 version = self .read_model_version_from_file ()
9898 if version is None or version == "" :
9999 raise Exception (
100- "rsync model version not set, please set it in 1) {model_version}/version.txt "
100+ "rsync model version not set, please set it in 1) {model_version}/version.yaml "
101101 "or 2) interface arguments 'version'"
102102 )
103103
104104 logger .info (
105105 f"START rank:{ self .local_rank } /{ self .nranks } update_weights_by_rdma, "
106- f"version:{ version } , verify_checksum:{ verify_checksum } "
106+ f"version:{ version } , verify_checksum:{ verify_checksum } , bootstrap_load: { bootstrap_load } "
107107 )
108108
109109 if self .rdma_handle is None :
@@ -128,8 +128,14 @@ def valid_parameters(old_state_dict, new_state_dict):
128128 raise ValueError (error_msg )
129129
130130 update_start = time .perf_counter ()
131- for name , param in old_state_dict .items ():
132- param .set_value (new_state_dict [name ])
131+ for name , target_param in old_state_dict .items ():
132+ new_param = new_state_dict [name ]
133+ if bootstrap_load and not target_param ._is_initialized ():
134+ new_param = new_param .cuda ()
135+ new_param ._share_buffer_to (target_param )
136+ else :
137+ target_param .set_value (new_param )
138+
133139 update_cost = time .perf_counter () - update_start
134140 logger .info (f"params set value cost { update_cost :.2f} seconds" )
135141 total_cost = time .perf_counter () - sync_start
@@ -476,13 +482,23 @@ def _update_shared_status(self, pid: int, status: int) -> None:
476482
477483 def read_model_version_from_file (self ):
478484 model_dir = self .fd_config .model_config .model
479- version_file = os .path .join (model_dir , "version.txt " )
485+ version_file = os .path .join (model_dir , "version.yaml " )
480486 try :
481487 with open (version_file , "r" , encoding = "utf-8" ) as f :
482- version = f .read ().strip ()
483- return version
484- except (FileNotFoundError , OSError , IOError ) as e :
485- logger .error (f"Failed to read model version file '{ version_file } ': { e } " )
488+ version_info = yaml .safe_load (f ) or {}
489+
490+ if not isinstance (version_info , dict ):
491+ logger .error (f"Failed to read model step from '{ version_file } ': yaml content is not a mapping" )
492+ return None
493+
494+ step = version_info .get ("step" )
495+ if step is None :
496+ logger .error (f"Failed to read model step from '{ version_file } ': missing 'step' field" )
497+ return None
498+
499+ return str (step )
500+ except (FileNotFoundError , OSError , IOError , yaml .YAMLError ) as e :
501+ logger .error (f"Failed to read model step from '{ version_file } ': { e } " )
486502 return None
487503
488504 @staticmethod
0 commit comments