Skip to content

Commit f12b7a7

Browse files
authored
support_lastnorm_gather_split_r2.4 (#5925)
* support_lastnorm_gather_split_r2.4 * support_lastnorm_gather_split_r2.4v1 * support_lastnorm_gather_split_r2.4v2
1 parent 741a015 commit f12b7a7

9 files changed

Lines changed: 31 additions & 8 deletions (across all 9 files)

File tree

fastdeploy/model_executor/layers/normalization.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,14 @@ def __init__(
105105
self.tp_rank = self.fd_config.parallel_config.tensor_parallel_rank
106106
self.tp_group = self.fd_config.parallel_config.tp_group
107107
is_input_norm = prefix.endswith(".input_layernorm")
108-
is_last_norm = prefix.endswith(".norm")
108+
self.is_last_norm = prefix.endswith(".norm")
109109
self.split_x = (
110110
self.fd_config.parallel_config.use_sequence_parallel_moe
111111
and self.layer_id == self.fd_config.model_config.moe_layer_start_index
112112
and is_input_norm
113113
)
114114
self.allgather_out = self.fd_config.parallel_config.use_sequence_parallel_moe and (
115-
(self.layer_id > self.fd_config.model_config.moe_layer_start_index and is_input_norm) or is_last_norm
115+
(self.layer_id > self.fd_config.model_config.moe_layer_start_index and is_input_norm)
116116
)
117117

118118
self.init_weight()

fastdeploy/model_executor/models/deepseek_v3.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,9 @@ def forward(
594594
)
595595
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
596596

597+
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
598+
out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0])
599+
597600
return out
598601

599602

fastdeploy/model_executor/models/ernie4_5_moe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,9 @@ def forward(
459459

460460
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
461461

462+
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
463+
out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0])
464+
462465
if current_platform.is_iluvatar() and forward_meta.attn_backend.mixed:
463466
out = forward_meta.attn_backend.reverse_transpose(out)
464467

fastdeploy/model_executor/models/ernie4_5_mtp.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,10 @@ def forward(
325325
for i in range(self.num_layers):
326326
hidden_states, residual = self.mtp_block[i](forward_meta, hidden_states, residual)
327327

328-
hidden_states = self.norm(hidden_states, residual)[0]
328+
hidden_states = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
329+
330+
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
331+
hidden_states = self.norm.allgather(hidden_states, forward_meta.ids_remove_padding.shape[0])
329332

330333
return hidden_states
331334

@@ -396,7 +399,7 @@ def load_weights(self, weights_iterator) -> None:
396399
),
397400
)
398401

399-
def compute_logits(self, hidden_states: paddle.Tensor):
402+
def compute_logits(self, hidden_states: paddle.Tensor, forward_meta: ForwardMeta):
400403
"""
401404
compute logits
402405
"""

fastdeploy/model_executor/models/ernie4_5_vl/ernie4_5_vl_moe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -548,6 +548,10 @@ def forward(
548548
)
549549

550550
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
551+
552+
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
553+
out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0])
554+
551555
return out
552556

553557

fastdeploy/model_executor/models/glm4_moe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,9 @@ def forward(
370370

371371
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
372372

373+
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
374+
out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0])
375+
373376
return out
374377

375378

fastdeploy/model_executor/models/gpt_oss.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,12 @@ def forward(self, ids_remove_padding: paddle.Tensor, forward_meta: ForwardMeta):
214214
for i in range(self.num_layers):
215215
hidden_states, residual = self.layers[i](forward_meta, hidden_states, residual)
216216

217-
hidden_states = self.norm(hidden_states, residual)[0]
218-
return hidden_states
217+
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
218+
219+
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
220+
out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0])
221+
222+
return out
219223

220224

221225
@ModelRegistry.register_model_class(

fastdeploy/model_executor/models/qwen3moe.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,9 @@ def forward(
282282

283283
out = self.norm(hidden_states, residual, forward_meta=forward_meta)[0]
284284

285+
if self.norm.is_last_norm and self.norm.fd_config.parallel_config.use_sequence_parallel_moe:
286+
out = self.norm.allgather(out, forward_meta.ids_remove_padding.shape[0])
287+
285288
return out
286289

287290

fastdeploy/spec_decode/mtp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,7 +1012,7 @@ def _propose_cuda(self, step_use_cudagraph: bool = False, is_dummy_run: bool = F
10121012
)
10131013

10141014
# 4. Compute logits, Sample
1015-
logits = self.model.compute_logits(hidden_states)
1015+
logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta)
10161016
if self.enable_logprob and self.enable_draft_logprob and substep == 0:
10171017
first_token_logits = self.model.compute_logits(self.model_inputs["first_token_hidden_states"])
10181018

@@ -1125,7 +1125,7 @@ def _propose_xpu(self, step_use_cudagraph: bool = False, is_dummy_run: bool = Fa
11251125
model_output, self.model_inputs["cum_offsets"], self.forward_meta, self.model_inputs
11261126
)
11271127
# 4. Compute logits, Sample
1128-
logits = self.model.compute_logits(hidden_states)
1128+
logits = self.model.compute_logits(hidden_states, forward_meta=self.forward_meta)
11291129
sampled_token_ids, sampler_output = self.sampler(
11301130
logits,
11311131
self.sampling_metadata,

0 commit comments

Comments (0)