From d28e26ed06b728d433dfef1ad602279ca510519a Mon Sep 17 00:00:00 2001
From: Tyler Poon
Date: Sat, 19 Oct 2024 09:08:14 -0700
Subject: [PATCH] ADLR/megatron-lm!2227 - qwen2.5 conversion

---
 tools/checkpoint/loader_llama_mistral.py | 28 +++++++++++++++++++-----
 tools/checkpoint/saver_mcore.py          |  8 +++++--
 2 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/tools/checkpoint/loader_llama_mistral.py b/tools/checkpoint/loader_llama_mistral.py
index ea803c5543..0667fad522 100644
--- a/tools/checkpoint/loader_llama_mistral.py
+++ b/tools/checkpoint/loader_llama_mistral.py
@@ -19,9 +19,9 @@ def add_arguments(parser):
 
     # TODO(jbarker): Need assertion to make sure *exactly* one of these is used
     parser.add_argument('--model-size', type=str, required=True,
-                        choices=['llama2-7B', 'llama2-13B', 'llama2-70B', 'llama2-7Bf', 'llama2-13Bf', 'llama2-70Bf', 'llama3-8B', 'llama3-70B', 'llama3-8Bf', 'llama3-70Bf', 'mistral-7B', 'mistral-7Bf', 'yi-34B'],
-                        help='Model size can be `llama2-7B`, `llama2-13B`, `llama2-70B`, `llama3-8B`, `llama3-70B`, `mistral-7B` (for pretrained models), '
-                        'and `llama2-7Bf`, `llama2-13Bf`, `llama2-70Bf`, `llama3-8Bf`, `llama3-70bf` and `mistral-7Bf` (for chat-finetuned models).')
+                        choices=['llama2-7B', 'llama2-13B', 'llama2-70B', 'llama2-7Bf', 'llama2-13Bf', 'llama2-70Bf', 'llama3-8B', 'llama3-70B', 'llama3-8Bf', 'llama3-70Bf', 'mistral-7B', 'mistral-7Bf', 'yi-34B', 'qwen2.5-7B', 'qwen2.5-72B', 'qwen2.5-7Bf', 'qwen2.5-72Bf'],
+                        help='Model size can be `llama2-7B`, `llama2-13B`, `llama2-70B`, `llama3-8B`, `llama3-70B`, `mistral-7B`, `qwen2.5-7B`, `qwen2.5-72B` (for pretrained models), '
+                        'and `llama2-7Bf`, `llama2-13Bf`, `llama2-70Bf`, `llama3-8Bf`, `llama3-70Bf`, `mistral-7Bf`, `qwen2.5-7Bf`, and `qwen2.5-72Bf` (for chat-finetuned models).')
     parser.add_argument('--checkpoint-type', type=str, required=True,
                         help='Type of checkpoint to convert, options are "meta" or "hf"')
     parser.add_argument('--bf16', action='store_true', help='Whether to load weights in bf16.')
@@ -59,6 +59,10 @@ def verify_transformers_version():
     "mistral-7B": 1,
     "mistral-7Bf": 1,
     "yi-34B": 8,
+    "qwen2.5-7B": 1,
+    "qwen2.5-7Bf": 1,
+    "qwen2.5-72B": 8,
+    "qwen2.5-72Bf": 8,
 }
 
 
@@ -353,6 +357,13 @@ def set_attn_state(args, layer, hf_layer):
         hf_attn.k_proj.weight.reshape((ng, dim, -1)),
         hf_attn.v_proj.weight.reshape((ng, dim, -1)),
     ], dim=1).reshape((-1, args.hidden_size)))
+    if args.add_qkv_bias:
+        attn.query_key_value.bias.data.copy_(torch.cat([
+            hf_attn.q_proj.bias.reshape((ng, dim*nh//ng)),
+            hf_attn.k_proj.bias.reshape((ng, dim)),
+            hf_attn.v_proj.bias.reshape((ng, dim)),
+        ], dim=1).reshape(-1))
+
     attn.dense.weight.data.copy_(hf_attn.o_proj.weight)
 
 
@@ -458,6 +469,9 @@ def _load_checkpoint(queue, args):
         margs.tokenizer_type = "HuggingFaceTokenizer"
     elif "mistral" in args.model_size:
         margs.tokenizer_type = "HuggingFaceTokenizer"
+    elif "qwen2.5" in args.model_size:
+        margs.tokenizer_type = "HuggingFaceTokenizer"
+        margs.add_qkv_bias = True
 
     # Arguments do sanity checks on the world size, but we don't care,
     # so trick it into thinking we are plenty of processes.
@@ -530,6 +544,7 @@ def check_for_arg(arg_name, default=None):
     md.output_layer = margs.untie_embeddings_and_output_weights
     md.position_embedding_type = margs.position_embedding_type
     md.linear_bias = margs.add_bias_linear
+    md.qkv_bias = margs.add_qkv_bias
     md.norm_has_bias = False
     md.swiglu = margs.swiglu
     md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
@@ -591,8 +606,10 @@ def queue_put(name, msg):
             dense_weight.append(layer.self_attention.dense.weight.data)
             mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
             mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)
-            if md.linear_bias:
+
+            if md.qkv_bias:
                 qkv_bias.append(layer.self_attention.query_key_value.bias.data)
+            if md.linear_bias:
                 mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
 
         # Handle gated linear units.
@@ -609,8 +626,9 @@ def queue_put(name, msg):
         message["qkv weight"] = torch.cat(qkv_weight, dim=0)
         message["dense weight"] = torch.cat(dense_weight, dim=1)
         message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)
-        if md.linear_bias:
+        if md.qkv_bias:
             message["qkv bias"] = torch.cat(qkv_bias, dim=0)
+        if md.linear_bias:
             if md.swiglu:
                 for tp_rank in range(tp_size):
                     mlp_l0_bias[tp_rank] = torch.chunk(mlp_l0_bias[tp_rank], 2, dim=0)
diff --git a/tools/checkpoint/saver_mcore.py b/tools/checkpoint/saver_mcore.py
index 6aec90e41b..e1779b8969 100644
--- a/tools/checkpoint/saver_mcore.py
+++ b/tools/checkpoint/saver_mcore.py
@@ -628,10 +628,11 @@ def chunk_bias(bias, parallel_mode, tp_size=1, ep_size=1):
         else:
             mlp_l0_weight = chunk_weight(msg.pop("mlp l0 weight"), "column", args.target_tensor_parallel_size, args.target_expert_parallel_size)
 
+        if md.qkv_bias:
+            qkv_bias = chunk_bias(msg.pop("qkv bias"), 'column', args.target_tensor_parallel_size)
         if md.linear_bias:
             dense_bias = msg.pop("dense bias")
             mlp_l1_bias = chunk_bias(msg.pop("mlp l1 bias"), 'row', args.target_tensor_parallel_size, args.target_expert_parallel_size)
-            qkv_bias = chunk_bias(msg.pop("qkv bias"), 'column', args.target_tensor_parallel_size)
             if md.swiglu:
                 mlp_l0_bias_W = chunk_bias(msg.pop("mlp l0 bias W"), 'column', args.target_tensor_parallel_size, args.target_expert_parallel_size)
                 mlp_l0_bias_V = chunk_bias(msg.pop("mlp l0 bias V"), 'column', args.target_tensor_parallel_size, args.target_expert_parallel_size)
@@ -662,9 +663,12 @@ def chunk_bias(bias, parallel_mode, tp_size=1, ep_size=1):
                 "self_attn_norm_bias" : input_norm_bias if md.norm_has_bias else None,
                 "mlp_norm_bias" : post_norm_bias if md.norm_has_bias else None,
             })
+            if md.qkv_bias:
+                params_dict.update({
+                    "self_attn_qkv_bias" : qkv_bias[tp_rank]
+                })
             if md.linear_bias:
                 params_dict.update({
-                    "self_attn_qkv_bias" : qkv_bias[tp_rank],
                     "self_attn_proj_bias" : dense_bias
                 })
             if margs.num_experts:
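
Note (not part of the patch): a minimal standalone sketch of the bias fusion that the new `set_attn_state` branch performs for Qwen2.5's grouped-query attention. The separate HF biases (`q_proj.bias`, `k_proj.bias`, `v_proj.bias`) are regrouped so that each KV group contributes its query-head biases followed by its key and value biases, matching Megatron's fused QKV layout. The sizes `nh`, `ng`, and `dim` below are toy values chosen for illustration, not values taken from the patch.

import torch

nh, ng, dim = 8, 2, 4            # toy sizes: query heads, KV groups, head dim
q_bias = torch.randn(nh * dim)   # stands in for hf_attn.q_proj.bias
k_bias = torch.randn(ng * dim)   # stands in for hf_attn.k_proj.bias
v_bias = torch.randn(ng * dim)   # stands in for hf_attn.v_proj.bias

# Same reshape/cat as the patch: per group, nh//ng query-head biases,
# then that group's key bias, then its value bias, concatenated and flattened.
fused = torch.cat([
    q_bias.reshape((ng, dim * nh // ng)),
    k_bias.reshape((ng, dim)),
    v_bias.reshape((ng, dim)),
], dim=1).reshape(-1)

# Fused length = ng groups * (query block + one key head + one value head).
assert fused.shape[0] == ng * (dim * nh // ng + dim + dim)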