Commit

Merge branch 'trintamaki/small-model-example' into 'main'
Small VLM example

See merge request ADLR/megatron-lm!2432
trintamaki committed Dec 11, 2024
2 parents 215a2eb + 2aa3522 commit 371feef
Showing 15 changed files with 324 additions and 91 deletions.
50 changes: 17 additions & 33 deletions examples/multimodal/config.py
@@ -7,34 +7,20 @@


def get_language_model_config(config):
if config.language_model_type == "2b":
if config.language_model_type == "llama3_8b":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.bias_activation_fusion = False
config.gated_linear_unit = True
config.apply_query_key_layer_scaling = True
config.layernorm_zero_centered_gamma = True
config.bias_dropout_fusion = False
config.rotary_percent = 0.5
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
elif config.language_model_type == "8b":
config.add_bias_linear = False
config.bias_activation_fusion = False
config.gated_linear_unit = False
config.apply_query_key_layer_scaling = True
config.layernorm_zero_centered_gamma = True
config.apply_query_key_layer_scaling = False
config.layernorm_zero_centered_gamma = (
False # Zero centered gamma not supported for RMSNorm
)
config.bias_dropout_fusion = False
config.rotary_percent = 0.5
config.attention_dropout = 0.0
config.apply_rope_fusion = False
config.activation_func = squared_relu
config.ffn_hidden_size = 16384
config.masked_softmax_fusion = True
config.attention_softmax_in_fp32 = True
config.num_query_groups = 32
config.kv_channels = 128
config.rotary_interleaved = False
elif config.language_model_type == "llama3_8b":
config.ffn_hidden_size = 14336
elif config.language_model_type == "mistral_7b":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.bias_activation_fusion = False
@@ -47,7 +33,7 @@ def get_language_model_config(config):
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
config.ffn_hidden_size = 14336
elif config.language_model_type == "mistral_7b":
elif config.language_model_type == "yi-34b":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.bias_activation_fusion = False
@@ -59,10 +45,11 @@
config.bias_dropout_fusion = False
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
config.ffn_hidden_size = 14336
elif config.language_model_type == "yi-34b":
config.ffn_hidden_size = 20480
elif config.language_model_type == "qwen2.5_7B":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
config.add_qkv_bias = True
config.bias_activation_fusion = False
config.gated_linear_unit = True
config.apply_query_key_layer_scaling = False
@@ -72,7 +59,7 @@ def get_language_model_config(config):
config.bias_dropout_fusion = False
config.apply_rope_fusion = False
config.attention_softmax_in_fp32 = True
config.ffn_hidden_size = 20480
config.ffn_hidden_size = 18944
elif config.language_model_type == "qwen2.0_72B":
config.activation_func = torch.nn.functional.silu
config.add_bias_linear = False
@@ -168,13 +155,7 @@ def get_vision_projection_config(config, hidden_size):
config.bias_activation_fusion = False
config.add_bias_linear = False
config.hidden_size = hidden_size # Used as the vision projection output size, i.e., the input to the language model.
if config.language_model_type == "2b":
config.ffn_hidden_size = 5440
config.activation_func = torch.nn.functional.gelu
if config.language_model_type == "8b":
config.ffn_hidden_size = 16384
config.activation_func = squared_relu
elif config.language_model_type == "llama3_8b":
if config.language_model_type == "llama3_8b":
config.ffn_hidden_size = 14336
config.activation_func = torch.nn.functional.gelu
elif config.language_model_type == "mistral_7b":
@@ -185,6 +166,9 @@ def get_vision_projection_config(config, hidden_size):
config.ffn_hidden_size = 20480
config.normalization = "LayerNorm"
config.activation_func = torch.nn.functional.gelu
elif config.language_model_type == "qwen2.5_7B":
config.ffn_hidden_size = 3584
config.activation_func = torch.nn.functional.gelu
elif config.language_model_type == "qwen2.0_72B":
config.ffn_hidden_size = 29568
config.normalization = "LayerNorm"
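
A minimal sketch of how the new qwen2.5_7B branches could be exercised, assuming get_language_model_config() and get_vision_projection_config() simply branch on language_model_type and mutate the config in place, as the hunks above suggest. The SimpleNamespace and the import path are illustrative stand-ins for the real Megatron TransformerConfig setup:

from types import SimpleNamespace

from examples.multimodal.config import (
    get_language_model_config,
    get_vision_projection_config,
)

# Language model side: the new branch sets Qwen2.5-7B sizes and flags.
lm_config = SimpleNamespace(language_model_type="qwen2.5_7B")
get_language_model_config(lm_config)
assert lm_config.ffn_hidden_size == 18944
assert lm_config.add_qkv_bias is True

# Vision projection side: hidden_size is used as the projection output size,
# i.e., the input size of the language model.
proj_config = SimpleNamespace(language_model_type="qwen2.5_7B")
get_vision_projection_config(proj_config, hidden_size=3584)
assert proj_config.hidden_size == 3584
assert proj_config.ffn_hidden_size == 3584
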
22 changes: 14 additions & 8 deletions examples/multimodal/evaluate_ai2d.py
@@ -9,19 +9,25 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="AI2D")

results = []
results = dict()

for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
results.append(
{
"question_id": res["sample_id"],
"answer": res["answer"],
"gt_answer": res["gt_answer"],
}
)
sample_id = res["sample_id"]

# Ignore possible duplicates.
if sample_id in results:
continue

results[sample_id] = {
"question_id": sample_id,
"answer": res["answer"],
"gt_answer": res["gt_answer"],
}

results = list(results.values())

with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
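
Each evaluator touched in this commit applies the same first-seen-wins merge when combining per-rank result files, as shown above for AI2D. A standalone sketch of that pattern follows; the shard filename pattern is assumed for illustration:

import glob
import json

def merge_jsonl_shards(shard_pattern, output_path):
    """Merge .jsonl result shards, keeping the first record seen per sample_id."""
    results = dict()
    for shard_path in sorted(glob.glob(shard_pattern)):
        with open(shard_path, "r") as shard:
            for line in shard:
                res = json.loads(line)
                sample_id = res["sample_id"]
                # Ignore possible duplicates across shards.
                if sample_id in results:
                    continue
                results[sample_id] = res
    with open(output_path, "w") as output_file:
        json.dump(list(results.values()), output_file)

# Example usage (hypothetical paths):
# merge_jsonl_shards("ai2d-results-rank-*.jsonl", "ai2d-merged.json")
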
13 changes: 10 additions & 3 deletions examples/multimodal/evaluate_chartqa.py
@@ -9,15 +9,22 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="ChartQA")

results = []
results = dict()

for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
res["question_id"] = res["sample_id"]
sample_id = res["sample_id"]

results.append(res)
# Ignore possible duplicates.
if sample_id in results:
continue

res["question_id"] = sample_id
results[sample_id] = res

results = list(results.values())

with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
18 changes: 13 additions & 5 deletions examples/multimodal/evaluate_coco.py
@@ -11,20 +11,28 @@ def convert_to_coco_format(input_path):
"""Convert input files to COCO compatible format."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="captioning")

captions = []
results = dict()

for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
sample_id = res["sample_id"]

question_id = res['sample_id']
caption = res['caption'].rstrip('.').lower()
# Ignore possible duplicates.
if sample_id in results:
continue

captions.append({"image_id": question_id, "caption": caption})
caption = res["caption"].rstrip(".").lower()
results[sample_id] = {
"image_id": sample_id,
"caption": caption,
}

results = list(results.values())

with open(output_file_path, "w") as output_file:
json.dump(captions, output_file, indent=4)
json.dump(results, output_file, indent=4)

return output_file_path

12 changes: 10 additions & 2 deletions examples/multimodal/evaluate_mathvista.py
@@ -11,13 +11,21 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="MathVista")

results = []
results = dict()

for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
results.append(res)
sample_id = res["sample_id"]

# Remove possible duplicates.
if sample_id in results:
continue

results[sample_id] = res

results = list(results.values())

with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
4 changes: 4 additions & 0 deletions examples/multimodal/evaluate_mmmu.py
@@ -48,6 +48,10 @@ def convert_to_mmmu_format(input_path):
)

# MMMU eval script expects just a sample_id to prediction mapping.
# Skip possible duplicates.
if sample_id in output:
continue

output[sample_id] = prediction

with open(output_file_path, "w") as output_file:
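
For reference, the file written by convert_to_mmmu_format() ends up as a plain sample_id-to-prediction mapping. The ids and answers below are made up for illustration:

import json

# Hypothetical contents mirroring the mapping the MMMU eval script expects.
output = {
    "validation_Art_1": "B",     # multiple-choice prediction
    "validation_Math_12": "42",  # open-ended prediction
}

with open("mmmu-predictions.json", "w") as output_file:
    json.dump(output, output_file)
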
12 changes: 10 additions & 2 deletions examples/multimodal/evaluate_ocrbench.py
@@ -8,13 +8,21 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="OCRBench")

results = []
results = dict()

for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
results.append(res)
sample_id = res["sample_id"]

# Remove possible duplicates.
if sample_id in results:
continue

results[sample_id] = res

results = list(results.values())

with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
25 changes: 14 additions & 11 deletions examples/multimodal/evaluate_textvqa.py
@@ -9,22 +9,25 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="TextVQA")

results = []
results = dict()

for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
results.append(
{
"question_id": res["sample_id"],
"answer": res["answer"],
"gt_answer": res["gt_answer"],
}
)

# Make order deterministic.
# results = sorted(results, key=lambda d: d["question_id"])
sample_id = res["sample_id"]

# Remove possible duplicates.
if sample_id in results:
continue

results[sample_id] = {
"question_id": sample_id,
"answer": res["answer"],
"gt_answer": res["gt_answer"],
}

results = list(results.values())

with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
16 changes: 13 additions & 3 deletions examples/multimodal/evaluate_vqav2.py
@@ -9,15 +9,22 @@ def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
input_file_paths, output_file_path = get_input_output_paths(input_path, task="VQAv2")

results = []
results = dict()

for input_file_path in input_file_paths:
with open(input_file_path, "r") as input_file:
for line in input_file:
res = json.loads(line)
res["question_id"] = res["sample_id"]
sample_id = res["sample_id"]

results.append(res)
# Skip possible duplicates.
if sample_id in results:
continue

res["question_id"] = sample_id
results[sample_id] = res

results = list(results.values())

with open(output_file_path, "w") as output_file:
json.dump(results, output_file)
@@ -57,6 +64,9 @@ def compute_vqa_accuracy(result_file, task):
assert len(gt) == 1, "expected exactly one groundtruth answer."
gt = gt[0]

pred = pred.rstrip("%")
gt = gt.rstrip("%")

if is_number(pred) and is_number(gt):
pred = float(pred)
gt = float(gt)
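
The rstrip("%") lines added above let a prediction such as "50%" match a numeric ground truth of "50". A small illustration, with is_number re-implemented here as a stand-in for the evaluator's helper:

def is_number(s):
    # Stand-in for the helper used by compute_vqa_accuracy().
    try:
        float(s)
        return True
    except ValueError:
        return False

pred, gt = "50%", "50"
pred, gt = pred.rstrip("%"), gt.rstrip("%")  # "50%" -> "50"

if is_number(pred) and is_number(gt):
    print(float(pred) == float(gt))  # True; without stripping, "50%" would fail is_number()
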