Skip to content

Commit 05a563e

Browse files
authored
Fix model config (#4170)
* fix config
* fix moe
1 parent dda27d0 commit 05a563e

3 files changed

Lines changed: 23 additions & 9 deletions

File tree

lmdeploy/pytorch/config.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -29,9 +29,9 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str):
29 29
config.dtype = torch.float16
30 30
return config
31 31

32-
torch_dtype = getattr(config.llm_config, 'dtype', None)
32+
torch_dtype = getattr(config.hf_config, 'dtype', None)
33 33
if torch_dtype is None:
34-
torch_dtype = getattr(config.llm_config, 'torch_dtype', None)
34+
torch_dtype = getattr(config.hf_config, 'torch_dtype', None)
35 35

36 36
# deal with case when torch_dtype is not string but torch.dtype
37 37
if isinstance(torch_dtype, torch.dtype):

lmdeploy/pytorch/configurations/default.py

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -14,13 +14,6 @@ def condition(cls, hf_config):
14 14
@classmethod
15 15
def build(cls, hf_config, model_path: str = None, **kwargs):
16 16
"""build."""
17-
18-
# for multi-modal models, get the language model config to build model config
19-
if hasattr(hf_config, 'text_config'):
20-
hf_config = hf_config.text_config
21-
elif hasattr(hf_config, 'llm_config'):
22-
hf_config = hf_config.llm_config
23-
24 17
head_dim = getattr(hf_config, 'head_dim', None)
25 18
head_dim = head_dim or hf_config.hidden_size // hf_config.num_attention_heads
26 19

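(New file under lmdeploy/pytorch/configurations/ — file name not captured on this page; presumably qwen3_vl.py)

Lines changed: 21 additions & 0 deletions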
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from .builder import AutoModelConfigBuilder
3+
from .default import DefaultModelConfigBuilder
4+
5+
6+
class Qwen3VLModelConfigBuilder(AutoModelConfigBuilder):
7+
8+
@classmethod
9+
def condition(cls, hf_config):
10+
"""config."""
11+
return hf_config.model_type == 'qwen3_vl' or hf_config.model_type == 'qwen3_vl_moe'
12+
13+
@classmethod
14+
def build(cls, hf_config, model_path: str = None, **kwargs):
15+
"""build."""
16+
if hasattr(hf_config, 'quantization_config') and not hasattr(hf_config.text_config, 'quantization_config'):
17+
setattr(hf_config.text_config, 'quantization_config', hf_config.quantization_config)
18+
cfg = DefaultModelConfigBuilder.build(hf_config.text_config, model_path, **kwargs)
19+
setattr(hf_config, 'dtype', hf_config.text_config.dtype)
20+
cfg.hf_config = hf_config
21+
return cfg

0 commit comments

Comments
 (0)