Commit 0fa28b1

[fix] fix ep group all-reduce (#4140)
* [fix] fix ep group all-reduce
* [fix] fix clear/update lock not working when workers > 1
* [chore] add preemption triggered info log
* [fix] fix code style
* fix model_weights_signal (#4092)

Co-authored-by: Yuanle Liu <yuanlehome@163.com>
1 parent cffde70 commit 0fa28b1

6 files changed: 41 additions & 26 deletions


fastdeploy/config.py

Lines changed: 7 additions & 3 deletions
@@ -352,8 +352,12 @@ def set_tp_group(self):
         )
         dist.collective._set_custom_gid(None)
         # same ep group id
-        dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
-        self.ep_group = dist.new_group(range(self.expert_parallel_size))
+        # dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
+        # self.ep_group = dist.new_group(range(self.expert_parallel_size))
+        if self.enable_expert_parallel:
+            dist.collective._set_custom_gid(self.data_parallel_size + tp_gid_offset)
+            self.ep_group = dist.new_group(range(self.expert_parallel_size))
+            dist.collective._set_custom_gid(None)
         logger.info(
             f"data_parallel_size: {self.data_parallel_size}, tensor_parallel_size: {self.tensor_parallel_size}, expert_parallel_size: {self.expert_parallel_size}, data_parallel_rank: {self.data_parallel_rank}, tensor_parallel_rank: {self.tensor_parallel_rank}, expert_parallel_rank: {self.expert_parallel_rank}, tp_group: {self.tp_group}."
         )
@@ -1339,7 +1343,7 @@ def check(self):
         )
         if self.scheduler_config is not None:
             self.scheduler_config.check()
-
+
         if int(envs.ENABLE_V1_KVCACHE_SCHEDULER) == 1:
             assert (
                 int(envs.FD_DISABLED_RECOVER) == 0
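The guard above is the heart of the ep group fix: the EP communication group is now created only when expert parallelism is actually enabled, and the custom group id is reset to None afterwards so later new_group calls cannot inherit it. A minimal pure-Python sketch of that bookkeeping (the helpers below only mimic what dist.collective._set_custom_gid / dist.new_group appear to do here; they are not the Paddle API):

# Pure-Python stand-ins for the group-id bookkeeping; illustrative only.
_custom_gid = None
_groups = {}

def _set_custom_gid(gid):
    global _custom_gid
    _custom_gid = gid

def new_group(ranks):
    # With a custom gid pinned, the next group gets exactly that id; a stale
    # gid would make an unrelated later group collide with the EP group.
    gid = _custom_gid if _custom_gid is not None else 1000 + len(_groups)
    _groups[gid] = list(ranks)
    return gid

def set_ep_group(dp_size, tp_gid_offset, ep_size, enable_expert_parallel):
    ep_group = None
    if enable_expert_parallel:            # guard added by this commit
        _set_custom_gid(dp_size + tp_gid_offset)
        ep_group = new_group(range(ep_size))
        _set_custom_gid(None)             # reset so later ids stay fresh
    return ep_group

print(set_ep_group(2, 100, 8, True))      # -> 102
print(set_ep_group(2, 100, 8, False))     # -> None: no EP group when EP is off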

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ def _trigger_preempt(self, request, num_new_blocks, preempted_reqs, scheduled_re
                 self._free_blocks(preempted_req)
                 preempted_req.cached_block_num = 0
                 self.to_be_rescheduled_request_id_set.add(preempted_req.request_id)
+                llm_logger.info(f"Preemption is triggered! Preempted request id: {preempted_req.request_id}")
                 preempted_reqs.append(preempted_req)
                 scheduled_reqs.append(self._prepare_preempt_task(preempted_req))
                 main_process_metrics.num_requests_waiting.inc(1)

fastdeploy/entrypoints/engine_client.py

Lines changed: 8 additions & 4 deletions
@@ -16,12 +16,12 @@
 
 import inspect
 import os
-import threading
 import time
 import traceback
 import uuid
 
 import numpy as np
+from filelock import FileLock
 
 from fastdeploy import envs
 from fastdeploy.config import ModelConfig
@@ -132,7 +132,7 @@ def __init__(
             pid, max_connections=int(os.getenv("FD_DEALER_CONNECTIONS", 50))
         )
         self.connection_initialized = False
-        self.clear_update_lock = threading.Lock()
+        self.clear_update_lock = FileLock(f"/tmp/fd_weight_clear_update_lock__pid{pid}_port{port}.lock")
 
     def create_zmq_client(self, model, mode):
         """
@@ -351,7 +351,9 @@ def update_model_weight(self, timeout=300):
         if self.model_weights_status_signal.value[0] == ModelWeightsStatus.NORMAL:
             return True, ""
         if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
-            return False, "updating model weight already"
+            return False, "worker is updating model weight already"
+        if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
+            return False, "worker is clearing model weight, cannot update now"
 
         self.model_weights_status_signal.value[0] = ModelWeightsStatus.UPDATING
         if self.enable_prefix_caching or self.enable_splitwise:
@@ -395,7 +397,9 @@ def clear_load_weight(self, timeout=300):
         if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARED:
             return True, ""
         if self.model_weights_status_signal.value[0] == ModelWeightsStatus.CLEARING:
-            return False, "clearing model weight already"
+            return False, "worker is clearing model weight already"
+        if self.model_weights_status_signal.value[0] == ModelWeightsStatus.UPDATING:
+            return False, "worker is updating model weight, cannot clear now"
 
         self.model_weights_status_signal.value[0] = ModelWeightsStatus.CLEARING
         if self.enable_prefix_caching or self.enable_splitwise:
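The import swap above is what makes the clear/update lock work with workers > 1: a threading.Lock only excludes threads inside one API-server process, while each worker runs as its own process, so two workers could clear and update weights concurrently. FileLock is backed by an OS-level lock on a shared path and therefore excludes across processes. A minimal sketch of the difference (the lock path here is illustrative, not the one used in the diff):

import multiprocessing as mp
from filelock import FileLock  # pip install filelock

def critical_section(worker_id: int) -> None:
    # Every process opens the same path, so acquisition is serialized
    # machine-wide; a threading.Lock recreated per process would not be.
    with FileLock("/tmp/demo_clear_update.lock"):
        print(f"worker {worker_id} holds the cross-process lock")

if __name__ == "__main__":
    workers = [mp.Process(target=critical_section, args=(i,)) for i in range(2)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()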

fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py

Lines changed: 1 addition & 1 deletion
@@ -297,7 +297,7 @@ def apply_tp(
         )
 
         if layer.reduce_results and layer.tp_size > 1:
-            tensor_model_parallel_all_reduce(fused_moe_out)
+            tensor_model_parallel_all_reduce(fused_moe_out, layer.fd_config.parallel_config.tp_group)
 
         return fused_moe_out
 
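Passing layer.fd_config.parallel_config.tp_group explicitly is the all-reduce fix named in the commit title: once an EP group also exists, an unqualified collective can run on whatever group the framework currently treats as the default, summing partial results across the wrong set of ranks. A toy model of that failure mode (plain Python; groups are just rank lists, not Paddle communicators):

# Each group is the list of ranks whose partial results should be summed.
TP_GROUP = [0, 1, 2, 3]   # ranks sharding one replica's weights
EP_GROUP = [0, 4]         # ranks sharding experts (illustrative layout)
DEFAULT_GROUP = EP_GROUP  # what the ambient default can silently become

def all_reduce(partials, group=None):
    # An unqualified call falls back to the ambient default group.
    group = DEFAULT_GROUP if group is None else group
    return sum(partials[r] for r in group)

partials = {rank: float(rank) for rank in range(8)}
print(all_reduce(partials))                  # 4.0: summed over EP ranks (bug)
print(all_reduce(partials, group=TP_GROUP))  # 6.0: summed over TP ranks (fix)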

fastdeploy/rl/dynamic_weight_manager.py

Lines changed: 7 additions & 13 deletions
@@ -220,23 +220,17 @@ def check_model_weights_status(model_weights_status, model_runner, pid):
         check model weights status
         """
         logger.info(f"dynamic weight manager is check model weights status! {model_weights_status.value[0]}")
-        is_stop = 0
         while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
             if model_weights_status.value[0] == ModelWeightsStatus.UPDATING:
                 logger.info("infer engine stopped! start to load new checkpoint...")
                 model_runner.update_parameters(pid)
+                while model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
+                    time.sleep(0.01)
+                logger.info("finished loading new checkpoint")
             elif model_weights_status.value[0] == ModelWeightsStatus.CLEARING:
                 logger.info("infer engine stopped! start to clear checkpoint...")
                 model_runner.clear_parameters(pid)
-                while True:
-                    if model_weights_status.value[0] == ModelWeightsStatus.NORMAL:
-                        logger.info("finished loading new checkpoint")
-                        break
-                    elif is_stop == 1 or (model_weights_status.value[0] == ModelWeightsStatus.CLEARED and is_stop == 0):
-                        if is_stop == 0:
-                            logger.info("finished clearing checkpoint")
-                            is_stop = 1
-                        time.sleep(0.001)
-                        break
-                    else:
-                        time.sleep(0.001)
+                while model_weights_status.value[0] != ModelWeightsStatus.CLEARED:
+                    time.sleep(0.01)
+                logger.info("finished clearing checkpoint")
+            time.sleep(0.01)
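The rewrite drops the is_stop flag machine in favor of two straight-line waits: after kicking off an update or clear, block until the shared status leaves the transitional state, then log exactly once. A self-contained mock of the polling pattern (a plain list stands in for the shared-memory signal, and the status codes are illustrative stand-ins for ModelWeightsStatus):

import threading
import time

NORMAL, UPDATING = 0, 1          # illustrative status codes
status = [UPDATING]              # stand-in for model_weights_status.value

def wait_until(target) -> None:
    # The simplified wait loop from the diff: poll until the signal flips.
    while status[0] != target:
        time.sleep(0.01)

def fake_update() -> None:
    time.sleep(0.1)              # pretend another component reloads weights
    status[0] = NORMAL           # then flips the shared signal back

threading.Thread(target=fake_update).start()
wait_until(NORMAL)
print("finished loading new checkpoint")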

fastdeploy/worker/worker_process.py

Lines changed: 17 additions & 5 deletions
@@ -270,6 +270,11 @@ def init_health_status(self) -> None:
             create=False,
         )
 
+    def _broadcast_model_weights_signal(self, src: int, group) -> int:
+        model_weights_signal_tensor = paddle.full(shape=[1], fill_value=self.model_weights_signal[0], dtype="int32")
+        paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
+        return model_weights_signal_tensor.item()
+
     def event_loop_normal(self) -> None:
         """Main event loop for Paddle Distrubuted Workers.
         TODO(gongshaotian): support remote calling of functions that control worker.
@@ -279,15 +284,19 @@ def event_loop_normal(self) -> None:
         req_ids = []
         num_running_requests = 0
         local_rank = self.local_rank % self.parallel_config.tensor_parallel_size
-        self.model_weights_signal = paddle.zeros([1], dtype=paddle.int32)
+        self.model_weights_signal = np.zeros([1], dtype=np.int32)
         while True:
             if self.local_rank % self.parallel_config.tensor_parallel_size == 0:
                 if self.model_weights_status.value[0] != ModelWeightsStatus.NORMAL:
                     self.model_weights_signal[0] = int(self.model_weights_status.value[0])
             if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.enable_expert_parallel:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.ep_group)
-            if self.fd_config.load_config.dynamic_load_weight:
-                paddle.distributed.broadcast(self.model_weights_signal, src=0, group=self.parallel_config.tp_group)
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.ep_group
+                )
+            if self.fd_config.load_config.dynamic_load_weight and self.parallel_config.tensor_parallel_size > 1:
+                self.model_weights_signal[0] = self._broadcast_model_weights_signal(
+                    src=0, group=self.parallel_config.tp_group
+                )
 
             self.insert_step = False
             req_dicts = None
@@ -315,7 +324,9 @@ def event_loop_normal(self) -> None:
             else:
                 paddle.distributed.barrier(self.parallel_config.tp_group)
             if self.model_weights_signal[0] != ModelWeightsStatus.NORMAL:
-                logger.info(f"Rank: {self.local_rank} has updated parameters.")
+                logger.info(
+                    f"Rank: {self.local_rank} to update or clear parameters, signal is {self.model_weights_signal[0]}, [-1:clear, 1:update]"
+                )
                 from fastdeploy.rl.dynamic_weight_manager import (
                     DynamicWeightManager,
                 )
@@ -327,6 +338,7 @@ def event_loop_normal(self) -> None:
                     self.parallel_config.engine_worker_queue_port,
                 )
                 self.model_weights_signal[0] = ModelWeightsStatus.NORMAL
+                logger.info(f"Rank: {self.local_rank} has updated or cleared parameters.")
 
             if self.exist_task_signal.value[0] == ExistTaskStatus.EXIST or self.task_queue.read_finish_flag.get() == 1:
                 logger.info(f"Rank: {self.local_rank} Detected new requests.")
