Skip to content

Commit 25d9517

Browse files
authored
add R3 for qwen3-moe-vl models (#4457)
1 parent 9cbedf6 commit 25d9517

1 file changed

Lines changed: 72 additions & 1 deletion

File tree

lmdeploy/pytorch/models/qwen3_vl_moe.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from lmdeploy.pytorch.model_inputs import StepContextManager
1111
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
1212

13-
from .patch import add_prefix
13+
from .patch import add_prefix, get_build_model_context
1414
from .qwen3_moe import Qwen3MoeModel
1515
from .qwen3_vl import Qwen3VLForConditionalGeneration
1616
from .qwen3_vl import Qwen3VLTextRotaryEmbedding as Qwen3VLMoeTextRotaryEmbedding
@@ -44,6 +44,7 @@ def forward(
4444
# args for deepstack
4545
visual_pos_masks: torch.Tensor | None = None,
4646
deepstack_visual_embeds: list[torch.Tensor] | None = None,
47+
all_routed_experts: torch.Tensor | None = None,
4748
):
4849
"""Rewrite of LlamaModel.forward."""
4950

@@ -73,6 +74,7 @@ def forward(
7374
past_key_value=past_key_value,
7475
residual=residual,
7576
attn_metadata=attn_metadata,
77+
all_routed_experts=all_routed_experts,
7678
)
7779

7880
# add visual features to the hidden states of first several layers
@@ -129,6 +131,75 @@ def __init__(
129131
dtype=dtype,
130132
device=device,
131133
prefix=add_prefix('language_model', prefix))
134+
# for router replay
135+
bm_ctx = get_build_model_context()
136+
self.enable_return_routed_experts = bm_ctx.enable_return_routed_experts
137+
138+
    def forward(
        self,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor,
        past_key_values: list[list[torch.Tensor]],
        attn_metadata: Any = None,
        inputs_embeds: torch.Tensor = None,
        mrope_position_ids: torch.Tensor = None,
        pixel_values: torch.Tensor = None,
        vis_cu_seqlens: torch.Tensor = None,
        vis_pos_emb: torch.Tensor = None,
        image_mask: torch.Tensor = None,
        pos_embeds: torch.Tensor = None,
        grid_thw: torch.Tensor = None,
        **kwargs,
    ):
        """Model forward, return logits.

        Embeds `input_ids` (merging in vision features when `pixel_values` is
        given), runs the language model, and — when router replay is enabled —
        also captures the routed-expert indices chosen by each MoE layer.

        Returns:
            The language-model hidden states, or, when
            `self.enable_return_routed_experts` is set, a dict with keys
            `hidden_states` and `all_routed_experts`.
        """
        visual_pos_masks = None
        deepstack_visual_embeds = None
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

            if pixel_values is not None:
                # Run the vision tower in the text embeddings' dtype.
                dtype = inputs_embeds.dtype
                pixel_values = pixel_values.to(dtype)
                # vis_pos_emb is a (cos, sin) pair for the visual rotary embedding.
                vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype))

                # get image embeds and deepstack visual embeds
                image_embeds, deepstack_visual_embeds = self.visual(pixel_values,
                                                                    cu_seqlens=vis_cu_seqlens,
                                                                    rotary_pos_emb=vis_pos_emb,
                                                                    pos_embeds=pos_embeds)

                # split image embeds per sample; each image contributes
                # prod(t, h, w) / spatial_merge_size**2 merged tokens
                split_sizes = (grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
                image_embeds = torch.split(image_embeds, split_sizes)
                image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype)

                # mask and scatter to create final input embeddings: image
                # placeholder positions in inputs_embeds are overwritten with
                # the vision features
                expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds)
                inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds)

                # NOTE(review): this forwards the hidden-dim-expanded mask as
                # visual_pos_masks; downstream deepstack indexing presumably
                # expects a per-token mask (i.e. `image_mask`) — confirm
                # against the language model's deepstack consumer.
                visual_pos_masks = expanded_image_mask

        # router replay: pre-allocate a (num_tokens, num_layers, top_k) buffer
        # that MoE layers fill in-place with the chosen expert indices.
        all_routed_experts = None
        if self.enable_return_routed_experts:
            # uint16 suffices for expert ids (assumes num_experts < 65536);
            # input_ids is presumably (1, num_tokens), hence size(1) here.
            all_routed_experts = input_ids.new_empty((input_ids.size(1), self.config.text_config.num_hidden_layers,
                                                      self.config.text_config.num_experts_per_tok),
                                                     dtype=torch.uint16)
        hidden_states = self.language_model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=past_key_values,
            attn_metadata=attn_metadata,
            inputs_embeds=inputs_embeds,
            mrope_position_ids=mrope_position_ids,
            # args for deepstack
            visual_pos_masks=visual_pos_masks,
            deepstack_visual_embeds=deepstack_visual_embeds,
            all_routed_experts=all_routed_experts)
        # Keep the plain-tensor return when replay is off so existing callers
        # are unaffected; otherwise bundle the captured expert routing.
        if all_routed_experts is None:
            return hidden_states
        return dict(hidden_states=hidden_states, all_routed_experts=all_routed_experts)
132203

133204
def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: dict[str, nn.Parameter],
134205
expert_params_mapping: list):

0 commit comments

Comments
 (0)