
Commit c2e9fb5

Merge branch 'trintamaki/internvit' into 'main'
InternViT support for NVLM. See merge request ADLR/megatron-lm!2295
2 parents 32fc18a + 95ea6e5

11 files changed, +672 -16 lines

examples/multimodal/config.py

Lines changed: 45 additions & 0 deletions
@@ -60,6 +60,21 @@ def get_language_model_config(config):
         config.apply_rope_fusion = False
         config.attention_softmax_in_fp32 = True
         config.ffn_hidden_size = 14336
+    elif config.language_model_type == "yi-34b":
+        config.activation_func = torch.nn.functional.silu
+        config.add_bias_linear = False
+        config.bias_activation_fusion = False
+        config.gated_linear_unit = True
+        config.apply_query_key_layer_scaling = False
+        config.layernorm_zero_centered_gamma = (
+            False  # Zero centered gamma not supported for RMSNorm
+        )
+        config.bias_dropout_fusion = False
+        config.apply_rope_fusion = False
+        config.attention_softmax_in_fp32 = True
+        config.ffn_hidden_size = 20480
+    else:
+        raise ValueError(f"unknown language model type {config.language_model_type}")

     return config
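The new "yi-34b" branch pairs torch.nn.functional.silu with gated_linear_unit = True, which is the SwiGLU-style feed-forward used by Yi-34B. A minimal stand-alone sketch of what that combination computes; the tiny sizes and the fc1/fc2 names below are illustrative assumptions, not Megatron internals:

import torch
import torch.nn.functional as F

hidden_size, ffn_hidden_size = 8, 16  # toy sizes; the config above sets ffn_hidden_size = 20480
x = torch.randn(2, hidden_size)

# With gated_linear_unit=True, the first MLP projection produces two halves.
fc1 = torch.nn.Linear(hidden_size, 2 * ffn_hidden_size, bias=False)  # add_bias_linear = False
fc2 = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=False)

gate, up = fc1(x).chunk(2, dim=-1)
y = fc2(F.silu(gate) * up)  # SiLU-gated linear unit
print(y.shape)              # torch.Size([2, 8])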

@@ -107,6 +122,30 @@ def get_vision_model_config(config, apply_query_key_layer_scaling):
         config.apply_rope_fusion = False
         config.qk_layernorm = False
         config.layernorm_epsilon = 1e-6
+    elif config.vision_model_type == "internvit":
+        config.num_layers = 45
+        config.num_attention_heads = 32  # Padded for TP=8.
+        config.num_query_groups = 32  # Padded for TP=8.
+        config.kv_channels = 128
+        config.add_bias_linear = True
+        config.add_qkv_bias = False
+        config.hidden_size = 3200
+        config.hidden_dropout = 0.0
+        config.attention_dropout = 0.0
+        config.ffn_hidden_size = 12800
+        config.gated_linear_unit = False
+        config.activation_func = torch.nn.functional.gelu
+        config.layernorm_zero_centered_gamma = False
+        config.apply_query_key_layer_scaling = apply_query_key_layer_scaling
+        config.bias_activation_fusion = False
+        config.bias_dropout_fusion = False
+        config.attention_softmax_in_fp32 = True
+        config.normalization = 'RMSNorm'
+        config.layernorm_epsilon = 1e-6
+        config.apply_rope_fusion = False
+    else:
+        raise ValueError(f"unknown vision model type {config.vision_model_type}")
+

     return config
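The two "Padded for TP=8." comments are needed because InternViT-6B actually has 25 attention heads, which does not divide evenly across 8 tensor-parallel ranks; the checkpoint converter added in this commit pads to 32 heads with dummy all-zero weights. A small sketch of the rounding rule, where pad_num_heads is an illustrative helper rather than code from this commit:

def pad_num_heads(num_heads: int, tensor_parallel_size: int) -> int:
    """Round num_heads up to the next multiple of tensor_parallel_size."""
    return ((num_heads + tensor_parallel_size - 1) // tensor_parallel_size) * tensor_parallel_size

assert pad_num_heads(25, 8) == 32  # matches num_attention_heads above
assert pad_num_heads(25, 1) == 25  # TP=1 needs no padding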

@@ -128,6 +167,12 @@ def get_vision_projection_config(config, hidden_size):
     elif config.language_model_type == "mistral_7b":
         config.ffn_hidden_size = 14336
         config.activation_func = torch.nn.functional.gelu
+    elif config.language_model_type == "yi-34b":
+        config.ffn_hidden_size = 20480
+        config.normalization = 'LayerNorm'
+        config.activation_func = torch.nn.functional.gelu
+    else:
+        raise ValueError(f"unknown language model type {config.language_model_type}")

     return config

examples/multimodal/image_processing.py

Lines changed: 17 additions & 10 deletions
@@ -7,13 +7,18 @@
 from torchvision.transforms import Compose, RandAugment, RandomResizedCrop, Resize, ToPILImage


-# Imagenet's mean and std.
-pixel_mean = [123.675, 116.28, 103.53]
-pixel_std = [58.395, 57.12, 57.375]
-
 # Reshape for broadcasting.
-pixel_mean = torch.Tensor(pixel_mean).view(-1, 1, 1)
-pixel_std = torch.Tensor(pixel_std).view(-1, 1, 1)
+pixel_mean_clip = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
+pixel_std_clip = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
+
+pixel_mean_siglip = torch.Tensor([127.5, 127.5, 127.5]).view(-1, 1, 1)
+pixel_std_siglip = torch.Tensor([127.5, 127.5, 127.5]).view(-1, 1, 1)
+
+pixel_statistics = {
+    "clip": (pixel_mean_clip, pixel_std_clip),
+    "siglip": (pixel_mean_siglip, pixel_std_siglip),
+    "internvit": (pixel_mean_clip, pixel_std_clip),
+}


 def convert_to_rgb(image):
@@ -36,12 +41,14 @@ def _transform_test(img_h, img_w):
     ])


-def standardize_image(img):
+def standardize_image(img, mean, std):
     """Standardize image pixel values."""
-    return (torch.Tensor(np.array(img)).permute(2, 0, 1) - pixel_mean) / pixel_std
+    return (torch.Tensor(np.array(img)).permute(2, 0, 1) - mean) / std
+

+def get_visual_transform(img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False, vision_model_type="clip"):
+    pixel_mean, pixel_std = pixel_statistics[vision_model_type]

-def get_visual_transform(img, img_h, img_w, use_tiling=False, max_num_tiles=1, use_thumbnail=False, augment=False):
     if use_tiling:
         assert img_h == img_w, "dynamic tiling expects equal tile height and width"
         imgs = dynamic_preprocess(img, min_num=1, max_num=max_num_tiles, image_size=img_h, use_thumbnail=use_thumbnail)
@@ -60,7 +67,7 @@ def get_visual_transform(img, img_h, img_w, use_tiling=False, max_num_tiles=1, u
     img = visual_transform(img)

     # Standardize pixel values.
-    img = standardize_image(img)
+    img = standardize_image(img, pixel_mean, pixel_std)

     # Pad to target image size.
     delta_h, delta_w = img_h - scaled_h, img_w - scaled_w
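With this change each vision backbone looks up its own pixel statistics, and "internvit" simply reuses the CLIP/ImageNet values. A minimal sketch of the resulting normalization, using a dummy constant image; the arithmetic mirrors standardize_image:

import torch

pixel_mean_clip = torch.Tensor([123.675, 116.28, 103.53]).view(-1, 1, 1)
pixel_std_clip = torch.Tensor([58.395, 57.12, 57.375]).view(-1, 1, 1)
pixel_statistics = {
    "clip": (pixel_mean_clip, pixel_std_clip),
    "internvit": (pixel_mean_clip, pixel_std_clip),  # InternViT reuses the CLIP statistics.
}

mean, std = pixel_statistics["internvit"]
img = torch.full((3, 448, 448), 127.5)  # dummy CHW image with 0-255 pixel values
standardized = (img - mean) / std       # same math as standardize_image(img, mean, std)
print(standardized[:, 0, 0])            # approximately tensor([0.0655, 0.1964, 0.4178])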

examples/multimodal/model.py

Lines changed: 4 additions & 1 deletion
@@ -37,7 +37,7 @@ def model_provider(
 
     num_image_embeddings = get_num_image_embeddings(
         args.img_h, args.img_w, args.patch_dim, args.vision_model_type,
-        args.disable_vision_class_token, 1
+        args.disable_vision_class_token, 1, args.pixel_shuffle,
     )
     old_seq_length = args.seq_length
     args.seq_length = args.encoder_seq_length = num_image_embeddings
@@ -92,6 +92,9 @@ def model_provider(
         vision_transformer_layer_spec = get_layer_spec(
             is_vit=True, normalization=vision_config.normalization
         )
+    elif vision_model_type == "internvit":
+        from nvlm.internvit import get_internvit_layer_spec
+        vision_transformer_layer_spec = get_internvit_layer_spec(use_te=use_te)
     else:
         raise RuntimeError("unsupported vision model type", vision_model_type)
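model_provider now also forwards args.pixel_shuffle to get_num_image_embeddings, because pixel shuffle trades spatial positions for channels and therefore shrinks the number of image tokens the language model sees. The helper below only illustrates that accounting; its name, the downscale factor of 2, and the argument list are assumptions, not the real get_num_image_embeddings:

def approx_num_image_embeddings(img_h, img_w, patch_dim, class_token, pixel_shuffle, factor=2):
    """Rough token count for one image tile under an assumed pixel-shuffle factor."""
    tokens = (img_h // patch_dim) * (img_w // patch_dim)
    if pixel_shuffle:
        tokens //= factor * factor  # a factor-2 shuffle folds each 2x2 patch group into one token
    return tokens + (1 if class_token else 0)

print(approx_num_image_embeddings(448, 448, 14, class_token=False, pixel_shuffle=False))  # 1024
print(approx_num_image_embeddings(448, 448, 14, class_token=False, pixel_shuffle=True))   # 256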

Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
import argparse
import os

import torch
from transformers import AutoModel


def convert(model_name, output_path, tensor_parallel_size, use_te):
    """Convert InternViT HF checkpoint to mcore."""
    hf_model = AutoModel.from_pretrained(
        model_name,
        trust_remote_code=True
    )

    hf_state_dict = hf_model.state_dict()
    new_state_dicts = [{"model": dict()} for _ in range(tensor_parallel_size)]

    hidden_size = 3200
    num_heads = 25
    dim = 128

    order = torch.ones(3 * hidden_size).long()

    for j in range(num_heads):
        for i in range(dim):
            order[i + dim*3*j] = j*dim + i
            order[dim + i + dim*3*j] = j*dim + i + num_heads*dim
            order[dim*2 + i + dim*3*j] = j*dim + i + num_heads*dim*2

    for name, tensor in hf_state_dict.items():
        # Map parameter names to ones used in megatron.
        new_name = ""
        new_tensor = tensor

        # This is used for chunking some tensors to target tensor parallel size.
        chunk_dim = None

        if "embeddings.class_embedding" in name:
            new_name = "class_token"
        elif "embeddings.patch_embedding.weight" in name:
            new_name = "conv1.weight"
        elif "embeddings.patch_embedding.bias" in name:
            new_name = "conv1.bias"
        elif "embeddings.position_embedding" in name:
            new_name = "position_embeddings.weight"
            new_tensor = new_tensor.squeeze(0)
        elif "encoder.layers" in name:
            layer_idx = name.split(".")[2]

            base = f"decoder.layers.{layer_idx}"

            head_dim = 128

            if tensor_parallel_size == 1:
                num_padded_heads = 25
            elif tensor_parallel_size == 8:
                # Note: 25 is not divisible by 8 and we don't currently support uneven heads split with tensor parallelism.
                # So we pad with dummy all-zero heads. Please use a nice even number of attention heads in your model.
                num_padded_heads = 32
            else:
                raise NotImplementedError("invalid tensor parallel size value:", tensor_parallel_size)

            if "ls1" in name:
                new_name = f"{base}.ls1"
            elif "ls2" in name:
                new_name = f"{base}.ls2"
            elif "attn.qkv.weight" in name:
                new_name = f"{base}.self_attention.linear_qkv.weight"
                num_tensors = 3
                padded_dim = head_dim * num_padded_heads * num_tensors
                padded_tensor = torch.zeros((padded_dim, new_tensor.shape[-1]), dtype=new_tensor.dtype, device=new_tensor.device)
                padded_tensor[:new_tensor.shape[0], :] = new_tensor[order]
                new_tensor = padded_tensor
                chunk_dim = 0
            elif "attn.q_norm.weight" in name:
                new_name = f"{base}.self_attention.q_layernorm.weight"
                num_tensors = 1
                padded_dim = head_dim * num_padded_heads * num_tensors
                padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device)
                padded_tensor[:new_tensor.shape[0]] = new_tensor
                new_tensor = padded_tensor
                chunk_dim = 0
            elif "attn.k_norm.weight" in name:
                new_name = f"{base}.self_attention.k_layernorm.weight"
                num_tensors = 1
                padded_dim = head_dim * num_padded_heads * num_tensors
                padded_tensor = torch.zeros(padded_dim, dtype=new_tensor.dtype, device=new_tensor.device)
                padded_tensor[:new_tensor.shape[0]] = new_tensor
                new_tensor = padded_tensor
                chunk_dim = 0
            elif "attn.proj.weight" in name:
                new_name = f"{base}.self_attention.linear_proj.weight"
                num_tensors = 1
                padded_dim = head_dim * num_padded_heads * num_tensors
                padded_tensor = torch.zeros((new_tensor.shape[0], padded_dim), dtype=new_tensor.dtype, device=new_tensor.device)
                padded_tensor[:, :new_tensor.shape[-1]] = new_tensor
                new_tensor = padded_tensor
                chunk_dim = 1
            elif "attn.proj.bias" in name:
                new_name = f"{base}.self_attention.linear_proj.bias"
            elif "mlp.fc1.weight" in name:
                new_name = f"{base}.mlp.linear_fc1.weight"
                chunk_dim = 0
            elif "mlp.fc1.bias" in name:
                new_name = f"{base}.mlp.linear_fc1.bias"
                chunk_dim = 0
            elif "mlp.fc2.weight" in name:
                new_name = f"{base}.mlp.linear_fc2.weight"
                chunk_dim = 1
            elif "mlp.fc2.bias" in name:
                new_name = f"{base}.mlp.linear_fc2.bias"
            elif "norm1" in name:
                new_name = f"{base}.input_layernorm.weight"
            elif "norm2" in name:
                new_name = f"{base}.pre_mlp_layernorm.weight"
            else:
                raise RuntimeError("unexpected transformer layer name", name)
        else:
            raise RuntimeError("unexpected layer name", name)

        assert new_name != "", f"unexpected layer name {name}"

        # TE sets _extra_state (for FP8 purposes), so set an empty one here for compatibility.
        extra_state_layers = ("linear_qkv", "linear_proj", "linear_fc1", "linear_fc2")
        is_extra_state_layer = any([l in new_name for l in extra_state_layers])
        if use_te and is_extra_state_layer:
            layer = new_name.split(".")[-2]
            if layer in extra_state_layers:
                extra_state_name = (
                    new_name[: new_name.rfind(".") + 1] + "_extra_state"
                )  # Replace the weight name.
                for i in range(tensor_parallel_size):
                    new_state_dicts[i]["model"][extra_state_name] = None

        if chunk_dim is None:
            new_tensors = [new_tensor for _ in range(tensor_parallel_size)]
        else:
            new_tensors = torch.chunk(new_tensor, tensor_parallel_size, dim=chunk_dim)

        for i in range(tensor_parallel_size):
            new_state_dicts[i]["model"][new_name] = new_tensors[i].clone()

    for i in range(tensor_parallel_size):
        output_dir_tp = os.path.join(output_path, f"iter_0000001/mp_rank_0{i}")
        os.makedirs(output_dir_tp, exist_ok=True)
        output_path_tp = os.path.join(output_dir_tp, "model_optim_rng.pt")
        torch.save(new_state_dicts[i], output_path_tp)
        print("saved file", output_path_tp)

    print("done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="InternVIT HuggingFace to Mcore converter")
    parser.add_argument("--model-name", type=str, default="OpenGVLab/InternViT-6B-448px-V1-5", help="Model name in HuggingFace")
    parser.add_argument("--output-dir", type=str, required=True, help="Output directory for the mcore model.")
    parser.add_argument("--use-te", action="store_true", default=True)
    parser.add_argument("--tensor-parallel-size", type=int, required=True)

    args = parser.parse_args()

    convert(args.model_name, args.output_dir, args.tensor_parallel_size, args.use_te)
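The least obvious part of the converter is the order permutation built near the top: the HuggingFace checkpoint stores the fused QKV weight rows as [all Q heads | all K heads | all V heads], while Megatron's linear_qkv expects the rows interleaved per head as [q_0, k_0, v_0, q_1, k_1, v_1, ...]. A downsized check of that permutation, with toy sizes standing in for 25 heads of dimension 128:

import torch

num_heads, dim = 2, 3
hidden = num_heads * dim  # plays the role of hidden_size = num_heads * dim above

order = torch.ones(3 * hidden).long()
for j in range(num_heads):
    for i in range(dim):
        order[i + dim * 3 * j] = j * dim + i                         # q rows of head j
        order[dim + i + dim * 3 * j] = j * dim + i + hidden          # k rows of head j
        order[dim * 2 + i + dim * 3 * j] = j * dim + i + 2 * hidden  # v rows of head j

# Label each input row: 0-5 are Q rows, 6-11 are K rows, 12-17 are V rows.
rows = torch.arange(3 * hidden)
print(rows[order].tolist())
# [0, 1, 2, 6, 7, 8, 12, 13, 14, 3, 4, 5, 9, 10, 11, 15, 16, 17]
#  q_0      k_0      v_0         q_1      k_1        v_1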
