@@ -270,6 +270,11 @@ def init_health_status(self) -> None:
             create=False,
         )

+    def _broadcast_model_weights_signal(self, src: int, group) -> int:
+        model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
+        paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
+        return model_weights_signal_tensor.item()
+
     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distributed Workers.
         TODO(gongshaotian): support remote calling of functions that control worker.
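The new helper exists because `paddle.distributed.broadcast` operates on `paddle.Tensor` objects, while the signal buffer becomes a plain numpy array in the next hunk. A minimal sketch of the round-trip pattern, assuming an already-initialized process group; the `broadcast_signal` name and free-standing form are illustrative, not part of the patch:

    import numpy as np
    import paddle
    import paddle.distributed as dist

    def broadcast_signal(signal: np.ndarray, src: int, group=None) -> int:
        # Stage the scalar in a one-element int32 tensor, since the
        # collective works on paddle.Tensor rather than numpy arrays.
        tensor = paddle.full(shape=[1], fill_value=int(signal[0]), dtype="int32")
        dist.broadcast(tensor, src=src, group=group)
        # Return a plain Python int so the caller can write it back
        # into its numpy-backed signal buffer.
        return tensor.item()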
@@ -279,15 +284,19 @@ def event_loop_normal(self) -> None:
         req_ids = []
         num_running_requests = 0
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
+        self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                 if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
             if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
-            if self.fd_config.load_config.dynamic_load_weight:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.tp_group)
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.ep_group
+                )
+            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.tensor_parallel_size > 1:
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.tp_group
+                )

             self.insert_step = False
             req_dicts = None
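Two changes interact here: the signal buffer moves from a paddle tensor to numpy, so only the broadcast itself touches a device tensor (via the helper above), and the TP broadcast gains a `tensor_parallel_size > 1` guard so single-GPU runs never enter the collective. A condensed sketch of the resulting flow; `worker` and its attributes are stand-ins for the real FastDeploy worker object:

    # Sketch under the assumption that `worker` exposes the same config
    # and helper shown in this diff.
    def sync_weights_signal(worker) -> None:
        cfg = worker.parallel_config
        dynamic = worker.fd_config.load_config.dynamic_load_weight
        if dynamic and cfg.enable_expert_parallel:
            worker.model_weights_signal[0] = worker._broadcast_model_weights_signal(
                src=0, group=cfg.ep_group
            )
        # New guard: a broadcast over a one-rank TP group is wasted work,
        # so single-GPU deployments skip the collective entirely.
        if dynamic and cfg.tensor_parallel_size > 1:
            worker.model_weights_signal[0] = worker._broadcast_model_weights_signal(
                src=0, group=cfg.tp_group
            )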
@@ -315,7 +324,9 @@ def event_loop_normal(self) -> None:
             else:
                 paddle.distributed.barrier(self.parallel_config.tp_group)
             if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
-                logger.info(f"Rank: {self.local_rank} has updated parameters.")
+                logger.info(
+                    f"Rank: {self.local_rank} will update or clear parameters, signal is {self.model_weights_signal[0]} [-1: clear, 1: update]"
+                )
                 from fastdeploy.rl.dynamic_weight_manager import (
                     DynamicWeightManager,
                 )
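The rewritten log line documents the signal's encoding: -1 asks the worker to clear its weights, 1 asks it to update them, and `ModelWeightsStatus.NORMAL` (presumably 0, given the comparisons in this loop) means no action is pending. A hypothetical reconstruction of the values implied by the message; the real enum lives elsewhere in FastDeploy and may differ:

    class ModelWeightsStatus:
        # Values inferred from the log message "[-1: clear, 1: update]"
        # and the NORMAL comparisons above; illustrative only.
        NORMAL = 0    # weights in place, no action pending
        UPDATING = 1  # a weight update was requested
        CLEARED = -1  # weights should be cleared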
@@ -327,6 +338,7 @@ def event_loop_normal(self) -> None:
                     self.parallel_config.engine_worker_queue_port,
                 )
                 self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
+                logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")

             if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")