Skip to content

Commit ec5e8aa

Browse files
authored
Merge branch 'InternLM:main' into turbomind-linear-gdn-prefix-caching
2 parents 6c38cf9 + 90245a3 commit ec5e8aa

3 files changed

Lines changed: 75 additions & 2 deletions

File tree

lmdeploy/pytorch/engine/executor/ray_executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ async def forward_async(self, inputs):
489489
finally:
490490
# free ray.put inputs
491491
try:
492-
ray._private.internal_api.free(self._prev_inputs)
492+
ray.internal.free(self._prev_inputs, local_only=False)
493493
except Exception as e:
494494
logger.warning(f'Free input ref failed: {e}')
495495

lmdeploy/pytorch/engine/model_agent/agent.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,7 @@ async def sleep(self, level: int = 1):
12081208
if self.dist_config.dp > 1:
12091209
await self.state.to_sleep.wait()
12101210
self.cache_engine = None
1211+
self.state_cache_engine = None
12111212
self.reset_graph_runner()
12121213
device = 'cpu' if level == 1 else 'meta'
12131214
self.patched_model.get_model().to(device=device, non_blocking=True)
@@ -1245,4 +1246,5 @@ def release(self):
12451246
self.reset_graph_runner()
12461247
self.patched_model = None
12471248
self.cache_engine = None
1249+
self.state_cache_engine = None
12481250
torch.cuda.empty_cache()

lmdeploy/pytorch/models/qwen3_vl_moe.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from lmdeploy.pytorch.model_inputs import StepContextManager
1111
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
1212

13-
from .patch import add_prefix
13+
from .patch import add_prefix, get_build_model_context
1414
from .qwen3_moe import Qwen3MoeModel
1515
from .qwen3_vl import Qwen3VLForConditionalGeneration
1616
from .qwen3_vl import Qwen3VLTextRotaryEmbedding as Qwen3VLMoeTextRotaryEmbedding
@@ -44,6 +44,7 @@ def forward(
4444
# args for deepstack
4545
visual_pos_masks: torch.Tensor | None = None,
4646
deepstack_visual_embeds: list[torch.Tensor] | None = None,
47+
all_routed_experts: torch.Tensor | None = None,
4748
):
4849
"""Rewrite of LlamaModel.forward."""
4950

@@ -73,6 +74,7 @@ def forward(
7374
past_key_value=past_key_value,
7475
residual=residual,
7576
attn_metadata=attn_metadata,
77+
all_routed_experts=all_routed_experts,
7678
)
7779

7880
# add visual features to the hidden states of first several layers
@@ -129,6 +131,75 @@ def __init__(
129131
dtype=dtype,
130132
device=device,
131133
prefix=add_prefix('language_model', prefix))
134+
# for router replay
135+
bm_ctx = get_build_model_context()
136+
self.enable_return_routed_experts = bm_ctx.enable_return_routed_experts
137+
138+
def forward(
139+
self,
140+
input_ids: torch.Tensor,
141+
position_ids: torch.Tensor,
142+
past_key_values: list[list[torch.Tensor]],
143+
attn_metadata: Any = None,
144+
inputs_embeds: torch.Tensor = None,
145+
mrope_position_ids: torch.Tensor = None,
146+
pixel_values: torch.Tensor = None,
147+
vis_cu_seqlens: torch.Tensor = None,
148+
vis_pos_emb: torch.Tensor = None,
149+
image_mask: torch.Tensor = None,
150+
pos_embeds: torch.Tensor = None,
151+
grid_thw: torch.Tensor = None,
152+
**kwargs,
153+
):
154+
"""Model forward, return logits."""
155+
156+
visual_pos_masks = None
157+
deepstack_visual_embeds = None
158+
if inputs_embeds is None:
159+
inputs_embeds = self.get_input_embeddings()(input_ids)
160+
161+
if pixel_values is not None:
162+
dtype = inputs_embeds.dtype
163+
pixel_values = pixel_values.to(dtype)
164+
vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype))
165+
166+
# get image embeds and deepstack visual embeds
167+
image_embeds, deepstack_visual_embeds = self.visual(pixel_values,
168+
cu_seqlens=vis_cu_seqlens,
169+
rotary_pos_emb=vis_pos_emb,
170+
pos_embeds=pos_embeds)
171+
172+
# split image embeds per sample
173+
split_sizes = (grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
174+
image_embeds = torch.split(image_embeds, split_sizes)
175+
image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype)
176+
177+
# mask and scatter to create final input embeddings
178+
expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds)
179+
inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds)
180+
181+
visual_pos_masks = expanded_image_mask
182+
183+
# router replay
184+
all_routed_experts = None
185+
if self.enable_return_routed_experts:
186+
all_routed_experts = input_ids.new_empty((input_ids.size(1), self.config.text_config.num_hidden_layers,
187+
self.config.text_config.num_experts_per_tok),
188+
dtype=torch.uint16)
189+
hidden_states = self.language_model(
190+
input_ids=input_ids,
191+
position_ids=position_ids,
192+
past_key_values=past_key_values,
193+
attn_metadata=attn_metadata,
194+
inputs_embeds=inputs_embeds,
195+
mrope_position_ids=mrope_position_ids,
196+
# args for deepstack
197+
visual_pos_masks=visual_pos_masks,
198+
deepstack_visual_embeds=deepstack_visual_embeds,
199+
all_routed_experts=all_routed_experts)
200+
if all_routed_experts is None:
201+
return hidden_states
202+
return dict(hidden_states=hidden_states, all_routed_experts=all_routed_experts)
132203

133204
def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: dict[str, nn.Parameter],
134205
expert_params_mapping: list):

0 commit comments

Comments (0)