|
| 1 | +import argparse |
| 2 | +import os |
| 3 | + |
| 4 | +import torch |
| 5 | +from transformers import AutoModel |
| 6 | + |
| 7 | + |
| 8 | +def convert(model_name, output_path, tensor_parallel_size, use_te): |
| 9 | + """Convert InternViT HF checkpoint to mcore.""" |
| 10 | + hf_model = AutoModel.from_pretrained( |
| 11 | + model_name, |
| 12 | + trust_remote_code=True |
| 13 | + ) |
| 14 | + |
| 15 | + hf_state_dict = hf_model.state_dict() |
| 16 | + new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)] |
| 17 | + |
| 18 | + hidden_size = 3200 |
| 19 | + num_heads = 25 |
| 20 | + dim = 128 |
| 21 | + |
| 22 | + order = torch.ones(3 * hidden_size).long() |
| 23 | + |
| 24 | + for j in range(num_heads): |
| 25 | + for i in range(dim): |
| 26 | + order[i + dim*3*j] = j*dim+i |
| 27 | + order[dim + i + dim*3*j] = j*dim+i+num_heads*dim |
| 28 | + order[dim*2 + i + dim*3*j] = j*dim+i+num_heads*dim*2 |
| 29 | + |
| 30 | + for name, tensor in hf_state_dict.items(): |
| 31 | + # Map parameter names to ones used in megatron. |
| 32 | + new_name = "" |
| 33 | + new_tensor = tensor |
| 34 | + |
| 35 | + # This is used for chunking some tensors to target tensor parallel size. |
| 36 | + chunk_dim = None |
| 37 | + |
| 38 | + if "embeddings.class_embedding" in name: |
| 39 | + new_name = "class_token" |
| 40 | + elif "embeddings.patch_embedding.weight" in name: |
| 41 | + new_name = "conv1.weight" |
| 42 | + elif "embeddings.patch_embedding.bias" in name: |
| 43 | + new_name = "conv1.bias" |
| 44 | + elif "embeddings.position_embedding" in name: |
| 45 | + new_name = "position_embeddings.weight" |
| 46 | + new_tensor = new_tensor.squeeze(0) |
| 47 | + elif "encoder.layers" in name: |
| 48 | + layer_idx = name.split(".")[2] |
| 49 | + |
| 50 | + base = f"decoder.layers.{layer_idx}" |
| 51 | + |
| 52 | + head_dim = 128 |
| 53 | + |
| 54 | + if tensor_parallel_size == 1: |
| 55 | + num_padded_heads = 25 |
| 56 | + elif tensor_parallel_size == 8: |
| 57 | + # Note: 25 is not divisible by 8 and we don't currently support uneven heads split with tensor parallelism. |
| 58 | + # So we pad with dummy all-zero heads. Please use a nice even number of attention heads in your model. |
| 59 | + num_padded_heads = 32 |
| 60 | + else: |
| 61 | + raise NotImplementedError("invalid tensor parallel size value:", tensor_parallel_size) |
| 62 | + |
| 63 | + if "ls1" in name: |
| 64 | + new_name = f"{base}.ls1" |
| 65 | + elif "ls2" in name: |
| 66 | + new_name = f"{base}.ls2" |
| 67 | + elif "attn.qkv.weight" in name: |
| 68 | + new_name = f"{base}.self_attention.linear_qkv.weight" |
| 69 | + num_tensors = 3 |
| 70 | + padded_dim = head_dim * num_padded_heads * num_tensors |
| 71 | + padded_tensor = torch.zeros((padded_dim, new_tensor.shape[-1]), dtype=new_tensor.dtype, device=new_tensor.device) |
| 72 | + padded_tensor[:new_tensor.shape[0], :] = new_tensor[order] |
| 73 | + new_tensor = padded_tensor |
| 74 | + chunk_dim = 0 |
| 75 | + elif "attn.q_norm.weight" in name: |
| 76 | + new_name = f"{base}.self_attention.q_layernorm.weight" |
| 77 | + num_tensors = 1 |
| 78 | + padded_dim = head_dim * num_padded_heads * num_tensors |
| 79 | + padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device) |
| 80 | + padded_tensor[:new_tensor.shape[0]] = new_tensor |
| 81 | + new_tensor = padded_tensor |
| 82 | + chunk_dim = 0 |
| 83 | + elif "attn.k_norm.weight" in name: |
| 84 | + new_name = f"{base}.self_attention.k_layernorm.weight" |
| 85 | + num_tensors = 1 |
| 86 | + padded_dim = head_dim * num_padded_heads * num_tensors |
| 87 | + padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device) |
| 88 | + padded_tensor[:new_tensor.shape[0]] = new_tensor |
| 89 | + new_tensor = padded_tensor |
| 90 | + chunk_dim = 0 |
| 91 | + elif "attn.proj.weight" in name: |
| 92 | + new_name = f"{base}.self_attention.linear_proj.weight" |
| 93 | + num_tensors = 1 |
| 94 | + padded_dim = head_dim * num_padded_heads * num_tensors |
| 95 | + padded_tensor = torch.zeros((new_tensor.shape[0], padded_dim), dtype=new_tensor.dtype, device=new_tensor.device) |
| 96 | + padded_tensor[:, :new_tensor.shape[-1]] = new_tensor |
| 97 | + new_tensor = padded_tensor |
| 98 | + chunk_dim = 1 |
| 99 | + elif "attn.proj.bias" in name: |
| 100 | + new_name = f"{base}.self_attention.linear_proj.bias" |
| 101 | + elif "mlp.fc1.weight" in name: |
| 102 | + new_name = f"{base}.mlp.linear_fc1.weight" |
| 103 | + chunk_dim = 0 |
| 104 | + elif "mlp.fc1.bias" in name: |
| 105 | + new_name = f"{base}.mlp.linear_fc1.bias" |
| 106 | + chunk_dim = 0 |
| 107 | + elif "mlp.fc2.weight" in name: |
| 108 | + new_name = f"{base}.mlp.linear_fc2.weight" |
| 109 | + chunk_dim = 1 |
| 110 | + elif "mlp.fc2.bias" in name: |
| 111 | + new_name = f"{base}.mlp.linear_fc2.bias" |
| 112 | + elif "norm1" in name: |
| 113 | + new_name = f"{base}.input_layernorm.weight" |
| 114 | + elif "norm2" in name: |
| 115 | + new_name = f"{base}.pre_mlp_layernorm.weight" |
| 116 | + else: |
| 117 | + raise RuntimeError("unexpected transformer layer name", name) |
| 118 | + else: |
| 119 | + raise RuntimeError("unexpected layer name", name) |
| 120 | + |
| 121 | + assert new_name != "", f"unexpected layer name {name}" |
| 122 | + |
| 123 | + # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility. |
| 124 | + extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2") |
| 125 | + is_extra_state_layer = any([l in new_name for l in extra_state_layers]) |
| 126 | + if use_te and is_extra_state_layer: |
| 127 | + layer = new_name.split(".")[-2] |
| 128 | + if layer in extra_state_layers: |
| 129 | + extra_state_name = ( |
| 130 | + new_name[: new_name.rfind(".") + 1] + "_extra_state" |
| 131 | + ) # Replace the weight name. |
| 132 | + for i in range(tensor_parallel_size): |
| 133 | + new_state_dicts[i]["model"][extra_state_name] = None |
| 134 | + |
| 135 | + if chunk_dim is None: |
| 136 | + new_tensors = [new_tensor for _ in range(tensor_parallel_size)] |
| 137 | + else: |
| 138 | + new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim) |
| 139 | + |
| 140 | + for i in range(tensor_parallel_size): |
| 141 | + new_state_dicts[i]["model"][new_name] = new_tensors[i].clone() |
| 142 | + |
| 143 | + for i in range(tensor_parallel_size): |
| 144 | + output_dir_tp = os.path.join(output_path, f"iter_0000001/mp_rank_0{i}") |
| 145 | + os.makedirs(output_dir_tp, exist_ok=True) |
| 146 | + output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt") |
| 147 | + torch.save(new_state_dicts[i], output_path_tp) |
| 148 | + print("saved file", output_path_tp) |
| 149 | + |
| 150 | + print("done") |
| 151 | + |
| 152 | + |
| 153 | +if __name__ == "__main__": |
| 154 | + parser = argparse.ArgumentParser(description="InternVIT HuggingFace to Mcore converter") |
| 155 | + parser.add_argument("--model-name", type=str, default="OpenGVLab/InternViT-6B-448px-V1-5", help="Model name in HuggingFace") |
| 156 | + parser.add_argument("--output-dir", type=str, required=True, help="Output directory for the mcore model.") |
| 157 | + parser.add_argument("--use-te", action="store_true", default=True) |
| 158 | + parser.add_argument("--tensor-parallel-size", type=int, required=True) |
| 159 | + |
| 160 | + args = parser.parse_args() |
| 161 | + |
| 162 | + convert(args.model_name, args.output_dir, args.tensor_parallel_size, args.use_te) |
0 commit comments