diff --git a/paddleformers/transformers/qwen2/modeling.py b/paddleformers/transformers/qwen2/modeling.py index 856cc9b9bc..dfc3c5e09d 100644 --- a/paddleformers/transformers/qwen2/modeling.py +++ b/paddleformers/transformers/qwen2/modeling.py @@ -321,9 +321,18 @@ def _get_tensor_parallel_mappings(cls, config: Qwen2Config, is_split=True): "self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight", + ] + FUSE_LAYER_COLWISE = [ + "self_attn.qkv_proj.weight", + ] + + FFN_LAYER_COLWISE = [ "mlp.up_proj.weight", "mlp.gate_proj.weight", ] + FUSE_FFN_LAYER_COLWISE = [ + "mlp.up_gate_proj.weight", + ] LAYER_ROWWISE = ["self_attn.o_proj.weight", "mlp.down_proj.weight"] @@ -332,35 +341,136 @@ def _get_tensor_parallel_mappings(cls, config: Qwen2Config, is_split=True): "self_attn.k_proj.bias", "self_attn.v_proj.bias", ] + FUSE_BIAS_KEYS = [ + "self_attn.qkv_proj.bias", + ] def make_base_actions(): actions = { "lm_head.weight": partial(fn, is_column=False), - "embed_tokens.weight": partial(fn, is_column=False), + f"{cls.base_model_prefix}.embed_tokens.weight": partial(fn, is_column=False), } for layer_idx in range(config.num_hidden_layers): - actions.update( - { - f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True) - for k in LAYER_COLWISE - } - ) + # colwise + if not config.fuse_attention_qkv: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True) + for k in LAYER_COLWISE + } + ) + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True) + for b in BIAS_KEYS + } + ) + else: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True) + for k in FUSE_LAYER_COLWISE + } + ) + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True) + for b in FUSE_BIAS_KEYS + } + ) + if not config.fuse_attention_ffn: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True) + for k in FFN_LAYER_COLWISE + } + ) + else: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial( + fn, is_column=True, is_naive_2fuse=True + ) + for k in FUSE_FFN_LAYER_COLWISE + } + ) + # rowwise actions.update( { f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=False) for k in LAYER_ROWWISE } ) - # bias - actions.update( - {f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True) for b in BIAS_KEYS} - ) return actions mappings = make_base_actions() return mappings + @classmethod + def _get_fuse_or_split_param_mappings(cls, config: Qwen2Config, is_fuse=False): + # return parameter fuse utils + from ..conversion_utils import split_or_fuse_func + + fn = split_or_fuse_func(is_fuse=is_fuse) + + # last key is fused key, other keys are to be fused. 
+ fuse_qkv_keys = [ + ( + "layers.0.self_attn.q_proj.weight", + "layers.0.self_attn.k_proj.weight", + "layers.0.self_attn.v_proj.weight", + "layers.0.self_attn.qkv_proj.weight", + ), + ( + "layers.0.self_attn.q_proj.bias", + "layers.0.self_attn.k_proj.bias", + "layers.0.self_attn.v_proj.bias", + "layers.0.self_attn.qkv_proj.bias", + ), + ] + + fuse_gate_up_keys = ( + "layers.0.mlp.gate_proj.weight", + "layers.0.mlp.up_proj.weight", + "layers.0.mlp.up_gate_proj.weight", + ) + num_heads = config.num_attention_heads + num_key_value_heads = getattr(config, "num_key_value_heads", num_heads) + fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False) + fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False) + + final_actions = {} + if is_fuse: + if fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for fuse_keys in fuse_qkv_keys: + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys]) + final_actions[keys] = partial( + fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + if fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = fn + else: + if not fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for fuse_keys in fuse_qkv_keys: + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys]) + final_actions[keys] = partial( + fn, + split_nums=3, + is_qkv=True, + num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + ) + if not fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = partial(fn, split_nums=2) + return final_actions + @classmethod def _gen_aoa_config(cls, config: Qwen2Config): model_prefix = "" if cls == cls.base_model_class else "model." 
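Reviewer note on the mapping shape added above: `_get_fuse_or_split_param_mappings` keys every action on a tuple of parameter names in which the last name is the fused parameter (`qkv_proj`, `up_gate_proj`) and the preceding names are the per-projection parameters to fuse, or the split targets when `is_fuse=False`. The snippet below is a simplified, standalone illustration of the gate/up case only, assuming Paddle's `[in_features, out_features]` Linear weight layout; it is not the actual `split_or_fuse_func` from `conversion_utils`, which additionally handles the head-aware QKV case via `is_qkv`, `num_heads`, and `num_key_value_heads`.

    # Illustrative sketch only -- not the paddleformers implementation.
    # Naive gate/up fuse and split, assuming [in_features, out_features] weights.
    import numpy as np

    def naive_fuse_gate_up(gate_w, up_w):
        # Concatenate along the output dim: [in, inter] x 2 -> [in, 2 * inter].
        return np.concatenate([gate_w, up_w], axis=-1)

    def naive_split_gate_up(up_gate_w, split_nums=2):
        # Inverse of the fuse above; mirrors the split_nums=2 action in the diff.
        return np.split(up_gate_w, split_nums, axis=-1)

    hidden, inter = 8, 16
    gate = np.random.rand(hidden, inter).astype("float32")
    up = np.random.rand(hidden, inter).astype("float32")
    fused = naive_fuse_gate_up(gate, up)              # shape (8, 32)
    gate_back, up_back = naive_split_gate_up(fused)
    assert np.allclose(gate, gate_back) and np.allclose(up, up_back)

The `is_naive_2fuse=True` flag on the tensor-parallel actions for `up_gate_proj` presumably exists so that each half of the naively concatenated weight is partitioned separately when the column-parallel split is applied.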
@@ -1025,6 +1135,7 @@ class Qwen2ForCausalLMPipe(GeneralModelForCausalLMPipe): config_class = Qwen2Config _decoder_layer_cls = Qwen2DecoderLayer _get_tensor_parallel_mappings = Qwen2Model._get_tensor_parallel_mappings + _get_fuse_or_split_param_mappings = Qwen2Model._get_fuse_or_split_param_mappings _init_weights = Qwen2Model._init_weights _keep_in_fp32_modules = Qwen2Model._keep_in_fp32_modules _rotary_emb_cls = Qwen2RotaryEmbedding diff --git a/paddleformers/transformers/qwen2_moe/modeling.py b/paddleformers/transformers/qwen2_moe/modeling.py index 5ee6805a73..11e5b7a35a 100644 --- a/paddleformers/transformers/qwen2_moe/modeling.py +++ b/paddleformers/transformers/qwen2_moe/modeling.py @@ -531,6 +531,9 @@ def _get_tensor_parallel_mappings(cls, config: Qwen2MoeConfig, is_split=True): "self_attn.k_proj.weight", "self_attn.v_proj.weight", ] + FUSE_LAYER_COLWISE = [ + "self_attn.qkv_proj.weight", + ] LAYER_ROWWISE = ["self_attn.o_proj.weight"] @@ -538,6 +541,9 @@ def _get_tensor_parallel_mappings(cls, config: Qwen2MoeConfig, is_split=True): "up_proj.weight", "gate_proj.weight", ] + FUSE_EXPERT_LAYER_COLWISE = [ + "up_gate_proj.weight", + ] EXPERT_LAYER_ROWWISE = ["down_proj.weight"] @@ -545,6 +551,9 @@ def _get_tensor_parallel_mappings(cls, config: Qwen2MoeConfig, is_split=True): "up_proj.weight", "gate_proj.weight", ] + FUSE_SHARED_EXPERT_LAYER_COLWISE = [ + "up_gate_proj.weight", + ] SHARED_EXPERT_LAYER_ROWWISE = ["down_proj.weight"] @@ -553,32 +562,89 @@ def _get_tensor_parallel_mappings(cls, config: Qwen2MoeConfig, is_split=True): "self_attn.k_proj.bias", "self_attn.v_proj.bias", ] + FUSE_BIAS_KEYS = [ + "self_attn.qkv_proj.bias", + ] def make_base_actions(): actions = { "lm_head.weight": partial(fn, is_column=False), - "embed_tokens.weight": partial(fn, is_column=False), + f"{cls.base_model_prefix}.embed_tokens.weight": partial(fn, is_column=False), } for layer_idx in range(config.num_hidden_layers): - actions.update( - { - f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True) - for k in LAYER_COLWISE - } - ) + # colwise + if not config.fuse_attention_qkv: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True) + for k in LAYER_COLWISE + } + ) + if config.qkv_bias: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True) + for b in BIAS_KEYS + } + ) + else: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True) + for k in FUSE_LAYER_COLWISE + } + ) + if config.qkv_bias: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True) + for b in FUSE_BIAS_KEYS + } + ) + + if not config.fuse_attention_ffn: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial( + fn, is_column=True + ) + for e in range(config.num_experts) + for k in EXPERT_LAYER_COLWISE + } + ) + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.shared_expert.{k}": partial( + fn, is_column=True + ) + for k in SHARED_EXPERT_LAYER_COLWISE + } + ) + else: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial( + fn, is_column=True, is_naive_2fuse=True + ) + for e in range(config.num_experts) + for k in FUSE_EXPERT_LAYER_COLWISE + } + ) + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.shared_expert.{k}": partial( + fn, is_column=True, is_naive_2fuse=True + ) + for k in 
FUSE_SHARED_EXPERT_LAYER_COLWISE + } + ) + # rowwise actions.update( { f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=False) for k in LAYER_ROWWISE } ) - actions.update( - { - f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial(fn, is_column=True) - for e in range(config.num_experts) - for k in EXPERT_LAYER_COLWISE - } - ) actions.update( { f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial(fn, is_column=False) @@ -586,14 +652,6 @@ def make_base_actions(): for k in EXPERT_LAYER_ROWWISE } ) - actions.update( - { - f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.shared_expert.{k}": partial( - fn, is_column=True - ) - for k in SHARED_EXPERT_LAYER_COLWISE - } - ) actions.update( { f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.shared_expert.{k}": partial( @@ -602,20 +660,78 @@ def make_base_actions(): for k in SHARED_EXPERT_LAYER_ROWWISE } ) - # bias - if config.qkv_bias: - actions.update( - { - f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True) - for b in BIAS_KEYS - } - ) return actions mappings = make_base_actions() return mappings + @classmethod + def _get_fuse_or_split_param_mappings(cls, config: Qwen2MoeConfig, is_fuse=False): + # return parameter fuse utils + from ..conversion_utils import split_or_fuse_func + + fn = split_or_fuse_func(is_fuse=is_fuse) + + # last key is fused key, other keys are to be fused. + fuse_qkv_keys = [ + ( + "layers.0.self_attn.q_proj.weight", + "layers.0.self_attn.k_proj.weight", + "layers.0.self_attn.v_proj.weight", + "layers.0.self_attn.qkv_proj.weight", + ), + ( + "layers.0.self_attn.q_proj.bias", + "layers.0.self_attn.k_proj.bias", + "layers.0.self_attn.v_proj.bias", + "layers.0.self_attn.qkv_proj.bias", + ), + ] + + fuse_gate_up_keys = ( + "layers.0.mlp.experts.0.gate_proj.weight", + "layers.0.mlp.experts.0.up_proj.weight", + "layers.0.mlp.experts.0.up_gate_proj.weight", + ) + num_heads = config.num_attention_heads + num_key_value_heads = getattr(config, "num_key_value_heads", num_heads) + fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False) + fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False) + num_experts = getattr(config, "num_experts", 128) + + final_actions = {} + if is_fuse: + if fuse_attention_qkv: + for i in range(config.num_hidden_layers): + keys = [key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys] + for j in range(num_experts): + experts_keys = tuple([key.replace("experts.0.", f"experts.{j}.") for key in keys]) + final_actions[experts_keys] = fn + if fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = fn + else: + if not fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for fuse_keys in fuse_qkv_keys: + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys]) + final_actions[keys] = partial( + fn, + split_nums=3, + is_qkv=True, + num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + ) + if not fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = [key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys] + for j in range(num_experts): + experts_keys = tuple([key.replace("experts.0.", f"experts.{j}.") for key in keys]) + final_actions[experts_keys] = partial(fn, split_nums=2) + return final_actions + @classmethod def _gen_aoa_config(cls, config: Qwen2MoeConfig): model_prefix = "" if cls == 
cls.base_model_class else "model." @@ -1099,6 +1215,7 @@ class Qwen2MoeForCausalLMPipe(GeneralModelForCausalLMPipe): config_class = Qwen2MoeConfig _decoder_layer_cls = Qwen2MoeDecoderLayer _get_tensor_parallel_mappings = Qwen2MoeModel._get_tensor_parallel_mappings + _get_fuse_or_split_param_mappings = Qwen2MoeModel._get_fuse_or_split_param_mappings _init_weights = Qwen2MoeModel._init_weights _keep_in_fp32_modules = Qwen2MoeModel._keep_in_fp32_modules _rotary_emb_cls = Qwen2MoeRotaryEmbedding diff --git a/paddleformers/transformers/qwen3/modeling.py b/paddleformers/transformers/qwen3/modeling.py index cefddb34c8..987ae0deb3 100644 --- a/paddleformers/transformers/qwen3/modeling.py +++ b/paddleformers/transformers/qwen3/modeling.py @@ -337,9 +337,18 @@ def _get_tensor_parallel_mappings(cls, config: Qwen3Config, is_split=True): "self_attn.q_proj.weight", "self_attn.k_proj.weight", "self_attn.v_proj.weight", + ] + FUSE_LAYER_COLWISE = [ + "self_attn.qkv_proj.weight", + ] + + FFN_LAYER_COLWISE = [ "mlp.up_proj.weight", "mlp.gate_proj.weight", ] + FUSE_FFN_LAYER_COLWISE = [ + "mlp.up_gate_proj.weight", + ] LAYER_ROWWISE = ["self_attn.o_proj.weight", "mlp.down_proj.weight"] @@ -348,39 +357,130 @@ def _get_tensor_parallel_mappings(cls, config: Qwen3Config, is_split=True): "self_attn.k_proj.bias", "self_attn.v_proj.bias", ] + FUSE_BIAS_KEYS = [ + "self_attn.qkv_proj.bias", + ] def make_base_actions(): actions = { "lm_head.weight": partial(fn, is_column=False), - "embed_tokens.weight": partial(fn, is_column=False), + f"{cls.base_model_prefix}.embed_tokens.weight": partial(fn, is_column=False), } for layer_idx in range(config.num_hidden_layers): - actions.update( - { - f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True) - for k in LAYER_COLWISE - } - ) + # colwise + if not config.fuse_attention_qkv: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True) + for k in LAYER_COLWISE + } + ) + # bias + if config.attention_bias: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True) + for b in BIAS_KEYS + } + ) + else: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True) + for k in FUSE_LAYER_COLWISE + } + ) + # bias + if config.attention_bias: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True) + for b in FUSE_BIAS_KEYS + } + ) + if not config.fuse_attention_ffn: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True) + for k in FFN_LAYER_COLWISE + } + ) + else: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial( + fn, is_column=True, is_naive_2fuse=True + ) + for k in FUSE_FFN_LAYER_COLWISE + } + ) + # rowwise actions.update( { f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=False) for k in LAYER_ROWWISE } ) - # bias - if config.attention_bias: - actions.update( - { - f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True) - for b in BIAS_KEYS - } - ) return actions mappings = make_base_actions() return mappings + @classmethod + def _get_fuse_or_split_param_mappings(cls, config: Qwen3Config, is_fuse=False): + # return parameter fuse utils + from ..conversion_utils import split_or_fuse_func + + fn = split_or_fuse_func(is_fuse=is_fuse) + + # last key is fused key, other keys are to be fused. 
+ fuse_qkv_keys = [ + ( + "layers.0.self_attn.q_proj.weight", + "layers.0.self_attn.k_proj.weight", + "layers.0.self_attn.v_proj.weight", + "layers.0.self_attn.qkv_proj.weight", + ) + ] + + fuse_gate_up_keys = ( + "layers.0.mlp.gate_proj.weight", + "layers.0.mlp.up_proj.weight", + "layers.0.mlp.up_gate_proj.weight", + ) + num_heads = config.num_attention_heads + num_key_value_heads = getattr(config, "num_key_value_heads", num_heads) + fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False) + fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False) + + final_actions = {} + if is_fuse: + if fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for fuse_keys in fuse_qkv_keys: + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys]) + final_actions[keys] = partial( + fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + if fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = fn + else: + if not fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for fuse_keys in fuse_qkv_keys: + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys]) + final_actions[keys] = partial( + fn, split_nums=3, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + if not fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys]) + final_actions[keys] = partial(fn, split_nums=2) + return final_actions + @classmethod def _gen_aoa_config(cls, config: Qwen3Config): model_prefix = "" if cls == cls.base_model_class else "model." @@ -1051,6 +1151,7 @@ class Qwen3ForCausalLMPipe(GeneralModelForCausalLMPipe): config_class = Qwen3Config _decoder_layer_cls = Qwen3DecoderLayer _get_tensor_parallel_mappings = Qwen3Model._get_tensor_parallel_mappings + _get_fuse_or_split_param_mappings = Qwen3Model._get_fuse_or_split_param_mappings _init_weights = Qwen3Model._init_weights _keep_in_fp32_modules = Qwen3Model._keep_in_fp32_modules _rotary_emb_cls = Qwen3RotaryEmbedding diff --git a/paddleformers/transformers/qwen3_moe/modeling.py b/paddleformers/transformers/qwen3_moe/modeling.py index ca834848c7..4c407a00c0 100644 --- a/paddleformers/transformers/qwen3_moe/modeling.py +++ b/paddleformers/transformers/qwen3_moe/modeling.py @@ -562,6 +562,9 @@ def _get_tensor_parallel_mappings(cls, config: Qwen3MoeConfig, is_split=True): "self_attn.k_proj.weight", "self_attn.v_proj.weight", ] + FUSE_LAYER_COLWISE = [ + "self_attn.qkv_proj.weight", + ] LAYER_ROWWISE = ["self_attn.o_proj.weight"] @@ -569,6 +572,9 @@ def _get_tensor_parallel_mappings(cls, config: Qwen3MoeConfig, is_split=True): "up_proj.weight", "gate_proj.weight", ] + FUSE_EXPERT_LAYER_COLWISE = [ + "up_gate_proj.weight", + ] EXPERT_LAYER_ROWWISE = ["down_proj.weight"] @@ -577,25 +583,52 @@ def _get_tensor_parallel_mappings(cls, config: Qwen3MoeConfig, is_split=True): "self_attn.k_proj.bias", "self_attn.v_proj.bias", ] + FUSE_BIAS_KEYS = [ + "self_attn.qkv_proj.bias", + ] def make_base_actions(): actions = { "lm_head.weight": partial(fn, is_column=False), - "embed_tokens.weight": partial(fn, is_column=False), + f"{cls.base_model_prefix}.embed_tokens.weight": partial(fn, is_column=False), } for layer_idx in range(config.num_hidden_layers): - actions.update( - { - f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, 
is_column=True) - for k in LAYER_COLWISE - } - ) + if not config.fuse_attention_qkv: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True) + for k in LAYER_COLWISE + } + ) + if config.attention_bias: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True) + for b in BIAS_KEYS + } + ) + else: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True) + for k in FUSE_LAYER_COLWISE + } + ) + if config.attention_bias: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True) + for b in FUSE_BIAS_KEYS + } + ) + actions.update( { f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=False) for k in LAYER_ROWWISE } ) + try: moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group() except Exception: @@ -605,15 +638,26 @@ def make_base_actions(): if expert_parallel_degree <= 1: # # if disable_ffn_model_parallel is True, disable expert layer tp plan # if not config.disable_ffn_model_parallel: - actions.update( - { - f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial( - fn, is_column=True - ) - for e in range(config.num_experts) - for k in EXPERT_LAYER_COLWISE - } - ) + if not config.fuse_attention_ffn: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial( + fn, is_column=True + ) + for e in range(config.num_experts) + for k in EXPERT_LAYER_COLWISE + } + ) + else: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial( + fn, is_column=True, is_naive_2fuse=True + ) + for e in range(config.num_experts) + for k in FUSE_EXPERT_LAYER_COLWISE + } + ) actions.update( { f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial( @@ -623,32 +667,75 @@ def make_base_actions(): for k in EXPERT_LAYER_ROWWISE } ) - actions.update( - { - f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.{k}": partial(fn, is_column=False) - for k in EXPERT_LAYER_ROWWISE - } - ) - actions.update( - { - f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.{k}": partial(fn, is_column=True) - for k in EXPERT_LAYER_COLWISE - } - ) - # bias - if config.attention_bias: - actions.update( - { - f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True) - for b in BIAS_KEYS - } - ) return actions mappings = make_base_actions() return mappings + @classmethod + def _get_fuse_or_split_param_mappings(cls, config: Qwen3MoeConfig, is_fuse=False): + # return parameter fuse utils + from ..conversion_utils import split_or_fuse_func + + fn = split_or_fuse_func(is_fuse=is_fuse) + + # last key is fused key, other keys are to be fused. 
+ fuse_qkv_keys = [ + ( + "layers.0.self_attn.q_proj.weight", + "layers.0.self_attn.k_proj.weight", + "layers.0.self_attn.v_proj.weight", + "layers.0.self_attn.qkv_proj.weight", + ), + ] + + fuse_gate_up_keys = ( + "layers.0.mlp.experts.0.gate_proj.weight", + "layers.0.mlp.experts.0.up_proj.weight", + "layers.0.mlp.experts.0.up_gate_proj.weight", + ) + num_heads = config.num_attention_heads + num_key_value_heads = getattr(config, "num_key_value_heads", num_heads) + fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False) + fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False) + num_experts = getattr(config, "num_experts", 128) + + final_actions = {} + if is_fuse: + if fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for fuse_keys in fuse_qkv_keys: + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys]) + final_actions[keys] = partial( + fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads + ) + if fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = [key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys] + for j in range(num_experts): + experts_keys = tuple([key.replace("experts.0.", f"experts.{j}.") for key in keys]) + final_actions[experts_keys] = fn + else: + if not fuse_attention_qkv: + for i in range(config.num_hidden_layers): + for fuse_keys in fuse_qkv_keys: + keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys]) + final_actions[keys] = partial( + fn, + split_nums=3, + is_qkv=True, + num_heads=num_heads, + num_key_value_heads=num_key_value_heads, + ) + if not fuse_attention_ffn: + for i in range(config.num_hidden_layers): + keys = [key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys] + for j in range(num_experts): + experts_keys = tuple([key.replace("experts.0.", f"experts.{j}.") for key in keys]) + final_actions[experts_keys] = partial(fn, split_nums=2) + return final_actions + @classmethod def _gen_aoa_config(cls, config: Qwen3MoeConfig): model_prefix = "" if cls == cls.base_model_class else "model." @@ -1120,6 +1207,7 @@ class Qwen3MoeForCausalLMPipe(GeneralModelForCausalLMPipe): config_class = Qwen3MoeConfig _decoder_layer_cls = Qwen3MoeDecoderLayer _get_tensor_parallel_mappings = Qwen3MoeModel._get_tensor_parallel_mappings + _get_fuse_or_split_param_mappings = Qwen3MoeModel._get_fuse_or_split_param_mappings _init_weights = Qwen3MoeModel._init_weights _keep_in_fp32_modules = Qwen3MoeModel._keep_in_fp32_modules _rotary_emb_cls = Qwen3MoeRotaryEmbedding
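Taken together, each of the four model families now exposes `_get_fuse_or_split_param_mappings` on both the base model and its `*ForCausalLMPipe` class, mirroring the existing `_get_tensor_parallel_mappings` wiring, so checkpoint conversion can fuse separate q/k/v (and gate/up, per expert for the MoE variants) weights when `fuse_attention_qkv` / `fuse_attention_ffn` are set, or split a fused checkpoint back out. A hypothetical consumer of the returned action dict might look like the sketch below; the `state_dict` handling, the `prefix` argument, and the assumption that each action takes a list of source tensors and returns the fused tensor are illustrative only, not the actual `conversion_utils` API.

    # Hypothetical consumer loop -- illustration only, not the conversion_utils API.
    def apply_fuse_actions(state_dict, final_actions, prefix=""):
        # `prefix` would normally be f"{cls.base_model_prefix}." since the mapping
        # keys ("layers.{i}...") are relative to the base model.
        for keys, action in final_actions.items():
            *src_keys, fused_key = keys  # by convention the last key names the fused parameter
            full_src = [prefix + k for k in src_keys]
            if all(k in state_dict for k in full_src):
                tensors = [state_dict.pop(k) for k in full_src]
                # Assumed signature: action(list_of_tensors) -> fused tensor.
                state_dict[prefix + fused_key] = action(tensors)
        return state_dict

The split direction would be symmetric: pop the fused key and scatter the pieces returned by the `split_nums` action back onto the per-projection keys.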