1616
1717import gc
1818import glob
19- import io
2019import os
2120import re
2221import time
3130from fastdeploy .inter_communicator import KVCacheStatus , ModelWeightsStatus
3231
3332
34- def sync_weights_by_rdma (config , step , rank ):
35- from checkpoint_transfer .core import RDMAWeightsDownloader
36-
37- downloader = RDMAWeightsDownloader (config )
38- downloader .initialize ()
39- logger .info (f"Fetching weights for step:{ step } , rank:{ rank } ..." )
40- data = downloader .get_weights (step , rank )
41- if data is None :
42- logger .error ("Failed to get weights!" )
43- raise Exception ("Failed to rsync weights through checkpoint_transfer" )
44- logger .info (f"Successfully retrieved data. Type: { type (data )} " )
45- if isinstance (data , np .ndarray ):
46- data_bytes = data .tobytes ()
47- elif isinstance (data , (bytes , bytearray )):
48- data_bytes = data
49- else :
50- data_bytes = bytes (data )
51- logger .info (f"Data size: { len (data_bytes )} bytes" )
52-
53- buffer = io .BytesIO (data_bytes )
54- new_state_dict = paddle .load (buffer )
55- return new_state_dict
56-
57-
5833class DynamicWeightManager :
5934 """Manages model weights loading, updating and shared state across processes."""
6035
@@ -75,6 +50,7 @@ def __init__(self, fd_config: FDConfig, models, local_rank: int):
7550 else :
7651 self .model_list = models
7752 self ._capture_model_state ()
53+ self .rdma_handle = None
7854 if self .load_config .load_strategy == "rsync" :
7955 self .update_weights_by_rdma ()
8056 else :
@@ -91,10 +67,12 @@ def _capture_model_state(self):
9167 """Capture and store initial model parameters state."""
9268 for model in self .model_list :
9369 for name , param in model .state_dict ().items ():
94- logger .info (f"Model param: { name } , shape={ param .shape } , dtype={ param .dtype } " )
70+ if hasattr (param , "_is_initialized" ) and not param ._is_initialized ():
71+ param .initialize ()
72+ logger .info (f"Model param: { name } , shape={ param .shape } , dtype={ param .dtype } , place={ param .place } " )
9573 self .state_dict [name ] = param
9674
97- def update_weights_by_rdma (self , version : str = None , rsync_config : Dict [ str , Any ] = None ):
75+ def update_weights_by_rdma (self , version : str = None , verify_checksum : bool = False ):
9876 def valid_parameters (old_state_dict , new_state_dict ):
9977 is_valid = True
10078 for key in old_state_dict :
@@ -110,17 +88,11 @@ def valid_parameters(old_state_dict, new_state_dict):
11088 )
11189 elif old_state_dict [key ].dtype != new_state_dict [key ].dtype :
11290 is_valid = False
113- logger .error (f"Invalid parameter: { key } dtype mismatch" )
91+ logger .error (
92+ f"Invalid parameter: { key } dtype mismatch, old:{ old_state_dict [key ].dtype } , new:{ new_state_dict [key ].dtype } "
93+ )
11494 return is_valid
11595
116- if rsync_config is None :
117- rsync_config = self .fd_config .load_config .rsync_config
118- if rsync_config is None or len (rsync_config ) == 0 :
119- raise Exception (
120- "rsync config not set, please set it in 1) launch arguments '--rsync-config' "
121- "or 2) interface arguments 'rsync_config'"
122- )
123-
12496 if version is None or version == "" :
12597 version = self .read_model_version_from_file ()
12698 if version is None or version == "" :
@@ -129,11 +101,23 @@ def valid_parameters(old_state_dict, new_state_dict):
129101 "or 2) interface arguments 'version'"
130102 )
131103
132- logger .info (f"START update_weights_by_rdma, version:{ version } , rsync_config:{ rsync_config } " )
133- rank = self .local_rank
104+ logger .info (
105+ f"START rank:{ self .local_rank } /{ self .nranks } update_weights_by_rdma, "
106+ f"version:{ version } , verify_checksum:{ verify_checksum } "
107+ )
108+
109+ if self .rdma_handle is None :
110+ from checkpoint_transfer import CheckpointTransfer
111+
112+ config = self .fd_config .load_config .rsync_config
113+ logger .info (f"CheckpointTransfer rsync config:{ config } " )
114+ self .rdma_handle = CheckpointTransfer (** config , local_rank = self .local_rank , group_size = self .nranks )
115+ self .rdma_handle .initialize ()
134116
135117 sync_start = time .perf_counter ()
136- new_state_dict = sync_weights_by_rdma (rsync_config , version , rank )
118+ new_state_dict = dict ()
119+ for key , param in self .rdma_handle .receive_stream (step_id = version , verify_checksum = verify_checksum ):
120+ new_state_dict [key ] = param
137121 sync_cost = time .perf_counter () - sync_start
138122 logger .info (f"weights sync cost { sync_cost :.2f} seconds" )
139123
@@ -148,18 +132,17 @@ def valid_parameters(old_state_dict, new_state_dict):
148132 param .set_value (new_state_dict [name ])
149133 update_cost = time .perf_counter () - update_start
150134 logger .info (f"params set value cost { update_cost :.2f} seconds" )
151-
152135 total_cost = time .perf_counter () - sync_start
153136 logger .info (
154137 f"END update_weights_by_rdma, cost { total_cost :.2f} seconds"
155- f" version:{ version } , rsync_config : { rsync_config } " ,
138+ f" version:{ version } , verify_checksum : { verify_checksum } , local_rank: { self . local_rank } " ,
156139 )
157140 return {
158141 "sync_cost" : sync_cost ,
159142 "update_cost" : update_cost ,
160143 "total_cost" : total_cost ,
161144 "version" : version ,
162- "rank" : rank ,
145+ "rank" : self . local_rank ,
163146 }
164147
165148 def update_parameters (self , pid : int = 0 , restart_process_group = False ) -> None :
0 commit comments