133 changes: 122 additions & 11 deletions paddleformers/transformers/qwen2/modeling.py
@@ -321,9 +321,18 @@ def _get_tensor_parallel_mappings(cls, config: Qwen2Config, is_split=True):
"self_attn.q_proj.weight",
"self_attn.k_proj.weight",
"self_attn.v_proj.weight",
]
FUSE_LAYER_COLWISE = [
Member

These are all missing is_naive_2fuse (a sketch of the suggested change follows _get_tensor_parallel_mappings below).
"self_attn.qkv_proj.weight",
]

FFN_LAYER_COLWISE = [
"mlp.up_proj.weight",
"mlp.gate_proj.weight",
]
FUSE_FFN_LAYER_COLWISE = [
"mlp.up_gate_proj.weight",
]

LAYER_ROWWISE = ["self_attn.o_proj.weight", "mlp.down_proj.weight"]

@@ -332,35 +341,136 @@ def _get_tensor_parallel_mappings(cls, config: Qwen2Config, is_split=True):
"self_attn.k_proj.bias",
"self_attn.v_proj.bias",
]
FUSE_BIAS_KEYS = [
"self_attn.qkv_proj.bias",
]

def make_base_actions():
actions = {
"lm_head.weight": partial(fn, is_column=False),
"embed_tokens.weight": partial(fn, is_column=False),
f"{cls.base_model_prefix}.embed_tokens.weight": partial(fn, is_column=False),
}
for layer_idx in range(config.num_hidden_layers):
actions.update(
{
f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True)
for k in LAYER_COLWISE
}
)
# colwise
if not config.fuse_attention_qkv:
actions.update(
{
f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True)
for k in LAYER_COLWISE
}
)
actions.update(
{
f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True)
for b in BIAS_KEYS
}
)
else:
actions.update(
{
f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True)
for k in FUSE_LAYER_COLWISE
}
)
actions.update(
{
f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True)
for b in FUSE_BIAS_KEYS
}
)
if not config.fuse_attention_ffn:
actions.update(
{
f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True)
for k in FFN_LAYER_COLWISE
}
)
else:
actions.update(
{
f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(
fn, is_column=True, is_naive_2fuse=True
)
for k in FUSE_FFN_LAYER_COLWISE
}
)
# rowwise
actions.update(
{
f"{cls.base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=False)
for k in LAYER_ROWWISE
}
)
# bias
actions.update(
{f"{cls.base_model_prefix}.layers.{layer_idx}.{b}": partial(fn, is_column=True) for b in BIAS_KEYS}
)

return actions

mappings = make_base_actions()
return mappings
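
The review comment above refers to the fused-QKV branch: unlike the fused-FFN case, which passes is_naive_2fuse=True for FUSE_FFN_LAYER_COLWISE, the FUSE_LAYER_COLWISE and FUSE_BIAS_KEYS entries are registered with plain partial(fn, is_column=True). Below is a minimal, self-contained sketch of the suggested shape: the helper name fused_qkv_actions is hypothetical, fn and the key names are taken from the diff, and whether is_naive_2fuse is in fact right for the fused QKV layout is the reviewer's suggestion rather than something confirmed here.

from functools import partial


def fused_qkv_actions(fn, base_model_prefix, layer_idx):
    # Mirror the fused-FFN handling above: mark the fused qkv_proj weight and bias as
    # column-parallel and flag them as naively fused, per the review comment.
    fused_keys = ["self_attn.qkv_proj.weight", "self_attn.qkv_proj.bias"]
    return {
        f"{base_model_prefix}.layers.{layer_idx}.{k}": partial(fn, is_column=True, is_naive_2fuse=True)
        for k in fused_keys
    }

Inside make_base_actions this would replace the two actions.update(...) calls in the else branch of the fuse_attention_qkv check with a single actions.update(fused_qkv_actions(fn, cls.base_model_prefix, layer_idx)).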

@classmethod
def _get_fuse_or_split_param_mappings(cls, config: Qwen2Config, is_fuse=False):
# return parameter fuse utils
from ..conversion_utils import split_or_fuse_func

fn = split_or_fuse_func(is_fuse=is_fuse)

# The last key in each tuple is the fused key; the other keys are to be fused.
fuse_qkv_keys = [
(
"layers.0.self_attn.q_proj.weight",
"layers.0.self_attn.k_proj.weight",
"layers.0.self_attn.v_proj.weight",
"layers.0.self_attn.qkv_proj.weight",
),
(
"layers.0.self_attn.q_proj.bias",
"layers.0.self_attn.k_proj.bias",
"layers.0.self_attn.v_proj.bias",
"layers.0.self_attn.qkv_proj.bias",
),
]

fuse_gate_up_keys = (
"layers.0.mlp.gate_proj.weight",
"layers.0.mlp.up_proj.weight",
"layers.0.mlp.up_gate_proj.weight",
)
num_heads = config.num_attention_heads
num_key_value_heads = getattr(config, "num_key_value_heads", num_heads)
fuse_attention_qkv = getattr(config, "fuse_attention_qkv", False)
fuse_attention_ffn = getattr(config, "fuse_attention_ffn", False)

final_actions = {}
if is_fuse:
if fuse_attention_qkv:
for i in range(config.num_hidden_layers):
for fuse_keys in fuse_qkv_keys:
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys])
final_actions[keys] = partial(
fn, is_qkv=True, num_heads=num_heads, num_key_value_heads=num_key_value_heads
)
if fuse_attention_ffn:
for i in range(config.num_hidden_layers):
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys])
final_actions[keys] = fn
else:
if not fuse_attention_qkv:
for i in range(config.num_hidden_layers):
for fuse_keys in fuse_qkv_keys:
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_keys])
final_actions[keys] = partial(
fn,
split_nums=3,
is_qkv=True,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
)
if not fuse_attention_ffn:
for i in range(config.num_hidden_layers):
keys = tuple([key.replace("layers.0.", f"layers.{i}.") for key in fuse_gate_up_keys])
final_actions[keys] = partial(fn, split_nums=2)
return final_actions
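
As the comment in this method notes, each key tuple lists the per-projection keys first and the fused key last. The toy consumer below shows how that convention might be applied to a state dict under the simplifying assumption that fusing is a plain concatenation along the last (output) axis; the real split_or_fuse_func also takes num_heads / num_key_value_heads (and split_nums when splitting) to handle the head layout, and apply_fuse_actions is a hypothetical name, not part of the library.

import numpy as np


def apply_fuse_actions(state_dict, final_actions):
    # Toy version: ignore the stored action and concatenate the pieces directly,
    # relying on the "last key is the fused key" convention of each tuple.
    for keys in final_actions:
        *piece_keys, fused_key = keys
        pieces = [state_dict.pop(k) for k in piece_keys if k in state_dict]
        if pieces:
            state_dict[fused_key] = np.concatenate(pieces, axis=-1)
    return state_dict

For layer 0, for example, this would collapse layers.0.self_attn.q_proj.weight, layers.0.self_attn.k_proj.weight and layers.0.self_attn.v_proj.weight into a single layers.0.self_attn.qkv_proj.weight entry.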

@classmethod
def _gen_aoa_config(cls, config: Qwen2Config):
model_prefix = "" if cls == cls.base_model_class else "model."
@@ -1025,6 +1135,7 @@ class Qwen2ForCausalLMPipe(GeneralModelForCausalLMPipe):
config_class = Qwen2Config
_decoder_layer_cls = Qwen2DecoderLayer
_get_tensor_parallel_mappings = Qwen2Model._get_tensor_parallel_mappings
_get_fuse_or_split_param_mappings = Qwen2Model._get_fuse_or_split_param_mappings
_init_weights = Qwen2Model._init_weights
_keep_in_fp32_modules = Qwen2Model._keep_in_fp32_modules
_rotary_emb_cls = Qwen2RotaryEmbedding