1616
1717from __future__ import annotations
1818
19+ import re
1920from functools import partial
2021from typing import Dict , Union
2122
@@ -250,7 +251,7 @@ def __init__(
250251 self .embed_tokens = fd_config .speculative_config .sharing_model .ernie .embed_tokens
251252 self .norm = fd_config .speculative_config .sharing_model .ernie .norm
252253
253- self .layers = nn .LayerList (
254+ self .mtp_block = nn .LayerList (
254255 [
255256 Ernie4_5_DecoderLayer (
256257 fd_config = fd_config ,
@@ -296,7 +297,7 @@ def load_state_dict(self, state_dict):
296297 self .eh_proj .load_state_dict (state_dict )
297298 for i in range (self .num_layers ):
298299 logger .info (f"Start load layer { i } " )
299- self .layers [i ].load_state_dict (state_dict )
300+ self .mtp_block [i ].load_state_dict (state_dict )
300301
301302 def forward (
302303 self ,
@@ -315,7 +316,7 @@ def forward(
315316 hidden_states = self .eh_proj (inputs_embedding )
316317 residual = None
317318 for i in range (self .num_layers ):
318- hidden_states , residual = self .layers [i ](forward_meta , hidden_states , residual )
319+ hidden_states , residual = self .mtp_block [i ](forward_meta , hidden_states , residual )
319320
320321 hidden_states = hidden_states + residual
321322
@@ -374,17 +375,23 @@ def load_weights(self, weights_iterator) -> None:
374375 weights_iterator (Iterator): An iterator yielding (name, weight) pairs.
375376 """
376377
377- from fastdeploy .model_executor .utils import default_weight_loader
378+ from fastdeploy .model_executor .utils import (
379+ default_weight_loader ,
380+ process_weights_after_loading ,
381+ )
378382
379383 all_param_mapping = [
380384 # (param_name, weight_name, expert_id, shard_id)
381385 ("embed_tokens.embeddings" , "embed_tokens" , None , None ),
382386 ("lm_head.linear" , "lm_head" , None , None ),
387+ ("enorm" , "mtp_emb_norm.0" , None , None ),
388+ ("hnorm" , "mtp_hidden_norm.0" , None , None ),
389+ ("eh_proj.linear" , "mtp_linear_proj.0" , None , None ),
383390 ]
384391
385392 params_dict = dict (self .named_parameters ())
386393 shard_id = None
387-
394+ process_weights_after_loading_fn = process_weights_after_loading ( dict ( self . named_sublayers ()))
388395 for loaded_weight_name , loaded_weight in weights_iterator :
389396 for param_name , weight_name , exp_id , shard_id in all_param_mapping :
390397 if weight_name not in loaded_weight_name :
@@ -396,11 +403,16 @@ def load_weights(self, weights_iterator) -> None:
396403 else :
397404 if loaded_weight_name not in params_dict .keys ():
398405 continue
406+ model_param_name = loaded_weight_name
399407 param = params_dict [loaded_weight_name ]
400408
401409 # Get weight loader from parameter and set weight
402410 weight_loader = getattr (param , "weight_loader" , default_weight_loader (self .fd_config ))
403411 weight_loader (param , loaded_weight )
412+ model_sublayer_name = re .sub (
413+ r"\.(up_gate_proj_weight|down_proj_weight|weight|cache_k_scale|cache_v_scale)$" , "" , model_param_name
414+ )
415+ process_weights_after_loading_fn (model_sublayer_name , param )
404416
405417 def compute_logits (self , hidden_states : paddle .Tensor ):
406418 """
0 commit comments