Skip to content

Commit

Permalink
Merge branch 'qwen25_conversion' into 'main'
Browse files Browse the repository at this point in the history
qwen2.5 conversion

See merge request ADLR/megatron-lm!2227
  • Loading branch information
jon-barker committed Oct 19, 2024
2 parents 739177e + d28e26e commit db7d37b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 7 deletions.
28 changes: 23 additions & 5 deletions tools/checkpoint/loader_llama_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def add_arguments(parser):

# TODO(jbarker): Need assertion to make sure *exactly* one of these is used
parser.add_argument('--model-size', type=str, required=True,
choices=['llama2-7B', 'llama2-13B', 'llama2-70B', 'llama2-7Bf', 'llama2-13Bf', 'llama2-70Bf', 'llama3-8B', 'llama3-70B', 'llama3-8Bf', 'llama3-70Bf', 'mistral-7B', 'mistral-7Bf', 'yi-34B'],
help='Model size can be `llama2-7B`, `llama2-13B`, `llama2-70B`, `llama3-8B`, `llama3-70B`, `mistral-7B` (for pretrained models), '
'and `llama2-7Bf`, `llama2-13Bf`, `llama2-70Bf`, `llama3-8Bf`, `llama3-70bf` and `mistral-7Bf` (for chat-finetuned models).')
choices=['llama2-7B', 'llama2-13B', 'llama2-70B', 'llama2-7Bf', 'llama2-13Bf', 'llama2-70Bf', 'llama3-8B', 'llama3-70B', 'llama3-8Bf', 'llama3-70Bf', 'mistral-7B', 'mistral-7Bf', 'yi-34B', 'qwen2.5-7B', 'qwen2.5-72B', 'qwen2.5-7Bf', 'qwen2.5-72Bf'],
help='Model size can be `llama2-7B`, `llama2-13B`, `llama2-70B`, `llama3-8B`, `llama3-70B`, `mistral-7B`, `qwen2.5-7B`, `qwen2.5-72B` (for pretrained models), '
'and `llama2-7Bf`, `llama2-13Bf`, `llama2-70Bf`, `llama3-8Bf`, `llama3-70bf`, `mistral-7Bf`, `qwen2.5-7Bf`, and `qwen2.5-72Bf` (for chat-finetuned models).')
parser.add_argument('--checkpoint-type', type=str, required=True,
help='Type of checkpoint to convert, options are "meta" or "hf"')
parser.add_argument('--bf16', action='store_true', help='Whether to load weights in bf16.')
Expand Down Expand Up @@ -59,6 +59,10 @@ def verify_transformers_version():
"mistral-7B": 1,
"mistral-7Bf": 1,
"yi-34B": 8,
"qwen2.5-7B": 1,
"qwen2.5-7Bf": 1,
"qwen2.5-72B": 8,
"qwen2.5-72Bf": 8,
}


Expand Down Expand Up @@ -353,6 +357,13 @@ def set_attn_state(args, layer, hf_layer):
hf_attn.k_proj.weight.reshape((ng, dim, -1)),
hf_attn.v_proj.weight.reshape((ng, dim, -1)),
], dim=1).reshape((-1, args.hidden_size)))
if args.add_qkv_bias:
attn.query_key_value.bias.data.copy_(torch.cat([
hf_attn.q_proj.bias.reshape((ng, dim*nh//ng)),
hf_attn.k_proj.bias.reshape((ng, dim)),
hf_attn.v_proj.bias.reshape((ng, dim)),
], dim=1).reshape(-1))

attn.dense.weight.data.copy_(hf_attn.o_proj.weight)


Expand Down Expand Up @@ -458,6 +469,9 @@ def _load_checkpoint(queue, args):
margs.tokenizer_type = "HuggingFaceTokenizer"
elif "mistral" in args.model_size:
margs.tokenizer_type = "HuggingFaceTokenizer"
elif "qwen2.5" in args.model_size:
margs.tokenizer_type = "HuggingFaceTokenizer"
margs.add_qkv_bias = True

# Arguments do sanity checks on the world size, but we don't care,
# so trick it into thinking we are plenty of processes.
Expand Down Expand Up @@ -530,6 +544,7 @@ def check_for_arg(arg_name, default=None):
md.output_layer = margs.untie_embeddings_and_output_weights
md.position_embedding_type = margs.position_embedding_type
md.linear_bias = margs.add_bias_linear
md.qkv_bias = margs.add_qkv_bias
md.norm_has_bias = False
md.swiglu = margs.swiglu
md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
Expand Down Expand Up @@ -591,8 +606,10 @@ def queue_put(name, msg):
dense_weight.append(layer.self_attention.dense.weight.data)
mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
if md.linear_bias:

if md.qkv_bias:
qkv_bias.append(layer.self_attention.query_key_value.bias.data)
if md.linear_bias:
mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)

# Handle gated linear units.
Expand All @@ -609,8 +626,9 @@ def queue_put(name, msg):
message["qkv weight"] = torch.cat(qkv_weight, dim=0)
message["dense weight"] = torch.cat(dense_weight, dim=1)
message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
if md.linear_bias:
if md.qkv_bias:
message["qkv bias"] = torch.cat(qkv_bias, dim=0)
if md.linear_bias:
if md.swiglu:
for tp_rank in range(tp_size):
mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0)
Expand Down
8 changes: 6 additions & 2 deletions tools/checkpoint/saver_mcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,10 +628,11 @@ def chunk_bias(bias, parallel_mode, tp_size=1, ep_size=1):
else:
mlp_l0_weight = chunk_weight(msg.pop("mlp l0 weight"), "column", args.target_tensor_parallel_size, args.target_expert_parallel_size)

if md.qkv_bias:
qkv_bias = chunk_bias(msg.pop("qkv bias"), 'column', args.target_tensor_parallel_size)
if md.linear_bias:
dense_bias = msg.pop("dense bias")
mlp_l1_bias = chunk_bias(msg.pop("mlp l1 bias"), 'row', args.target_tensor_parallel_size, args.target_expert_parallel_size)
qkv_bias = chunk_bias(msg.pop("qkv bias"), 'column', args.target_tensor_parallel_size)
if md.swiglu:
mlp_l0_bias_W = chunk_bias(msg.pop("mlp l0 bias W"), 'column', args.target_tensor_parallel_size, args.target_expert_parallel_size)
mlp_l0_bias_V = chunk_bias(msg.pop("mlp l0 bias V"), 'column', args.target_tensor_parallel_size, args.target_expert_parallel_size)
Expand Down Expand Up @@ -662,9 +663,12 @@ def chunk_bias(bias, parallel_mode, tp_size=1, ep_size=1):
"self_attn_norm_bias" : input_norm_bias if md.norm_has_bias else None,
"mlp_norm_bias" : post_norm_bias if md.norm_has_bias else None,
})
if md.qkv_bias:
params_dict.update({
"self_attn_qkv_bias" : qkv_bias[tp_rank]
})
if md.linear_bias:
params_dict.update({
"self_attn_qkv_bias" : qkv_bias[tp_rank],
"self_attn_proj_bias" : dense_bias
})
if margs.num_experts:
Expand Down

0 comments on commit db7d37b

Please sign in to comment.