Commit 63be779: ADLR/megatron-lm!2033 - Online eval

trintamaki authored and ericharper committed Sep 19, 2024
1 parent 6b35ca8
Showing 14 changed files with 489 additions and 241 deletions.
25 changes: 25 additions & 0 deletions examples/multimodal/config.py
@@ -1,4 +1,6 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
from dataclasses import dataclass

import torch

from megatron.training.activations import quick_gelu, squared_relu
@@ -107,3 +109,26 @@ def get_vision_projection_config(config, hidden_size):
config.activation_func = torch.nn.functional.gelu

return config


@dataclass
class EvaluationConfig:
"""Evaluation related configuration."""
task: str

temperature: float = 1.0
top_p: float = 0.0
top_k: int = 0

out_seq_length: int = 32

output_path: str = ""

input_image_path: str = ""
gt_path: str = ""

num_partitions: int = 1
partition_id: int = 0
num_samples_per_partition: int = 0

prompt_format: str = "mistral"
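
For reference, the new EvaluationConfig is a plain dataclass, so callers can build it directly in code or fill it from a config file. A minimal sketch of constructing one for a TextVQA run (not part of this commit; all values are illustrative):

# Illustrative only: build an EvaluationConfig for a TextVQA run.
# Field names come from the dataclass above; the values are made up.
from config import EvaluationConfig

eval_config = EvaluationConfig(
    task="TextVQA",                              # the only required field
    temperature=1.0,
    top_k=1,
    out_seq_length=32,
    output_path="/results/textvqa",              # hypothetical path
    input_image_path="/data/textvqa/images",     # hypothetical path
    gt_path="/data/textvqa/val_annotations.json",  # hypothetical path
    num_partitions=1,
    partition_id=0,
    prompt_format="mistral",
)
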
20 changes: 15 additions & 5 deletions examples/multimodal/evaluate_textvqa.py
@@ -1,16 +1,23 @@
import argparse
import glob
import json
import os

from evaluate_vqav2 import compute_vqa_accuracy


def merge_input_files(input_path):
"""Merge input files to a format compatible with the evaluator."""
output_file_path = input_path + "-TextVQA-merged.json"
# Single input file.
if os.path.exists(input_path):
input_file_paths = [input_path]
output_file_path = input_path.replace(".jsonl", "-merged.json")
# Directory of partitioned input files.
else:
pattern = input_path + "-TextVQA-[0-9].*jsonl"
input_file_paths = glob.glob(pattern)

pattern = input_path + "-TextVQA-[0-9].*jsonl"
input_file_paths = glob.glob(pattern)
output_file_path = input_path + "-TextVQA-merged.json"

results = []

@@ -35,12 +42,15 @@ def merge_input_files(input_path):
def textvqa_eval(input_path):
"""Run TextVQA evaluation."""
result_file_path = merge_input_files(input_path)
compute_vqa_accuracy(result_file_path)
avg_acc = compute_vqa_accuracy(result_file_path)
return avg_acc


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input-path', type=str, help="Path to input file(s)")
args = parser.parse_args()

textvqa_eval(args.input_path)
avg_acc = textvqa_eval(args.input_path)

print(f"===== TextVQA Accuracy {avg_acc:.2f}% =====")
12 changes: 8 additions & 4 deletions examples/multimodal/evaluate_vqav2.py
@@ -55,7 +55,7 @@ def compute_vqa_accuracy(result_file, use_chartqa_metric=False):
# "We consider an answer to be correct if it is within 5% of the gold answer.
# For non-numeric answers, we still need an exact match to consider an answer to be correct."
if use_chartqa_metric:
acc = 0.
acc = 0.0
assert len(gt) == 1, "expected exactly one groundtruth answer."
gt = gt[0]

@@ -74,18 +74,22 @@ def compute_vqa_accuracy(result_file, use_chartqa_metric=False):
all_acc.append(acc)

acc_avg = sum(all_acc) / len(all_acc) * 100
print(f"===== Accuracy {acc_avg:.2f}% =====")

return acc_avg


def vqav2_eval(input_path):
"""Run VQAv2 evaluation."""
result_file = merge_input_files(input_path)
compute_vqa_accuracy(result_file)
avg_acc = compute_vqa_accuracy(result_file)
return avg_acc


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--input-path', type=str, help="Path to input file(s)")
args = parser.parse_args()

vqav2_eval(args.input_path)
avg_acc = vqav2_eval(args.input_path)

print(f"===== VQAv2 Accuracy {avg_acc:.2f}% =====")
149 changes: 149 additions & 0 deletions examples/multimodal/model.py
@@ -0,0 +1,149 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
import warnings
from copy import deepcopy

import torch
from config import get_language_model_config, get_vision_model_config, get_vision_projection_config
from layer_specs import get_layer_spec, get_layer_spec_te, get_mlp_module_spec

from megatron.core.models.multimodal.llava_model import LLaVAModel
from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
from megatron.training import get_args, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args


def model_provider(
pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True
) -> LLaVAModel:
"""Builds the model.
Args:
pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism). Defaults to True.
post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline parallelism). Defaults to True.
add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the encoder
will live on only a subset of the pipeline stages (specifically, only the first stage).
add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the decoder
will live on only a subset of the pipeline stages (specifically, every stage after the first one).
parallel_output (bool): Enable parallel model output.
Returns:
model: A multimodal model.
"""
args = get_args()

use_te = args.use_te

print_rank_0('building a multimodal model ...')

num_image_embeddings = get_num_image_embeddings(
args.img_h, args.img_w, args.patch_dim, args.disable_vision_class_token, 1
)
old_seq_length = args.seq_length
args.seq_length = args.encoder_seq_length = num_image_embeddings
if torch.distributed.get_rank() == 0 and old_seq_length != args.seq_length:
warnings.warn(
f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})"
)

max_num_image_embeddings = (args.max_num_tiles + int(args.use_thumbnail)) * num_image_embeddings

assert (
args.decoder_seq_length is not None
), "Please provide --decoder-seq-length to set the language model sequence length"
assert (
args.decoder_seq_length > max_num_image_embeddings
), "Language model sequence length must be greater than the maximum number of image embeddings"
if args.decoder_seq_length > args.max_position_embeddings:
args.max_position_embeddings = args.decoder_seq_length
warnings.warn(
f"Expanded max_position_embeddings to {args.max_position_embeddings} to accommodate the maximum language model sequence length"
)

base_config = core_transformer_config_from_args(get_args())
base_config.language_model_type = args.language_model_type
base_config.vision_model_type = args.vision_model_type
base_config.calculate_per_token_loss = True

language_config = deepcopy(base_config)
language_config = get_language_model_config(language_config)

if use_te:
language_transformer_layer_spec = get_layer_spec_te(
is_vit=False
) # TENorm detects LayerNorm/RMS automatically.
else:
language_transformer_layer_spec = get_layer_spec(
is_vit=False, normalization=language_config.normalization
)

vision_config = deepcopy(base_config)
vision_config = get_vision_model_config(
vision_config, apply_query_key_layer_scaling=args.apply_query_key_layer_scaling
)

vision_model_type = args.vision_model_type
if vision_model_type == "clip":
if use_te:
vision_transformer_layer_spec = get_layer_spec_te(
is_vit=True
) # TENorm detects LayerNorm/RMS automatically.
else:
vision_transformer_layer_spec = get_layer_spec(
is_vit=True, normalization=vision_config.normalization
)
else:
raise RuntimeError("unsupported vision model type", vision_model_type)

vision_projection_config = deepcopy(base_config)
vision_projection_config = get_vision_projection_config(
vision_projection_config, language_config.hidden_size
)

if args.encoder_pipeline_model_parallel_size > 0:
assert (
args.encoder_pipeline_model_parallel_size == 1
), "vision model and projection can only live on 1 pipeline stage."
vision_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size
vision_projection_config.pipeline_model_parallel_size = (
args.encoder_pipeline_model_parallel_size
)
if args.encoder_tensor_model_parallel_size > 0:
vision_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size
vision_projection_config.tensor_model_parallel_size = (
args.encoder_tensor_model_parallel_size
)

vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules

model = LLaVAModel(
language_transformer_config=language_config,
language_transformer_layer_spec=language_transformer_layer_spec,
language_vocab_size=args.padded_vocab_size,
language_max_sequence_length=args.decoder_seq_length,
vision_transformer_config=vision_config,
vision_transformer_layer_spec=vision_transformer_layer_spec,
drop_vision_class_token=args.disable_vision_class_token,
vision_projection_config=vision_projection_config,
vision_projection_layer_spec=vision_projection_layer_spec,
vision_projection_type="mlp",
allow_missing_vision_projection_checkpoint=args.allow_missing_vision_projection_checkpoint,
parallel_output=parallel_output,
language_position_embedding_type=args.position_embedding_type,
language_rotary_percent=args.rotary_percent,
pre_process=pre_process,
post_process=post_process,
add_encoder=add_encoder,
add_decoder=add_decoder,
img_h=args.img_h,
img_w=args.img_w,
patch_dim=args.patch_dim,
language_rotary_base=args.rotary_base,
)

model.freeze(
freeze_language_model=args.freeze_LM,
freeze_vision_model=args.freeze_ViT,
freeze_vision_projection=False,
)

return model
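
model_provider ties the vision-side sequence length to the number of image embeddings the ViT produces, which for a CLIP-style patch encoder is (img_h / patch_dim) * (img_w / patch_dim) plus one class token when the class token is kept. A back-of-the-envelope check, assuming CLIP-ViT-L/14 style settings of img_h = img_w = 336 and patch_dim = 14 (these values are assumptions, not taken from this diff):

# Illustrative arithmetic for the seq_length adjustment above.
img_h, img_w, patch_dim = 336, 336, 14   # assumed CLIP-ViT-L/14 style settings
class_token = 1                          # dropped when --disable-vision-class-token is set

num_image_embeddings = (img_h // patch_dim) * (img_w // patch_dim) + class_token
print(num_image_embeddings)              # 24 * 24 + 1 = 577

# With tiling, the language model must fit every tile plus the optional thumbnail,
# which is why the assert requires --decoder-seq-length to exceed this value.
max_num_tiles, use_thumbnail = 4, True   # assumed tiling settings
max_num_image_embeddings = (max_num_tiles + int(use_thumbnail)) * num_image_embeddings
print(max_num_image_embeddings)          # 5 * 577 = 2885
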
43 changes: 43 additions & 0 deletions examples/multimodal/multimodal_args.py
@@ -0,0 +1,43 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.


def add_multimodal_extra_args(parser):
"""Extra arguments."""
group = parser.add_argument_group(title='multimodal arguments')
group.add_argument('--dataset-config', type=str, default=None)
group.add_argument("--prompt-path", type=str, default=None)
group.add_argument('--freeze-LM', action='store_true', default=False)
group.add_argument('--freeze-ViT', action='store_true', default=False)
group.add_argument('--language-model-type', type=str, required=True)
group.add_argument('--vision-model-type', type=str, default="clip")
group.add_argument("--disable-vision-class-token", action="store_true", default=False)
group.add_argument(
"--allow-missing-vision-projection-checkpoint", action="store_true", default=False
)
group.add_argument("--use-te", action="store_true", default=False)
group.add_argument(
"--dataloader-save", type=str, default=None, help="Energon dataloader state save path"
)
group.add_argument(
"--use-tiling", action="store_true", default=False, help="Use input image tiling"
)
group.add_argument("--max-num-tiles", type=int, default=1, help="Maximum number of image tiles")
group.add_argument(
"--use-thumbnail", action="store_true", default=False, help="Add image thumbnail as a tile"
)
group.add_argument(
"--dataloader-seq-length",
type=int,
help="Make dataloader to produce sequences of specific length.",
)
group.add_argument(
"--num-frames",
type=int,
default=1,
help="Number of frames to regularly sample from the video as input to the model.",
)
group.add_argument(
"--online-evaluation-config", type=str, help="Config file for online evaluation."
)

return parser
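
add_multimodal_extra_args only registers an extra argument group and returns the parser, so the new flags can be exercised in isolation with a bare argparse parser; during training they are expected to reach Megatron's main parser through its extra-args hook. A self-contained sketch (flag values and the config path are illustrative):

# Standalone sketch: attach the multimodal flags to a bare argparse parser
# and parse a sample command line. Values are assumptions, not repo defaults.
import argparse

from multimodal_args import add_multimodal_extra_args

parser = argparse.ArgumentParser("multimodal-args-demo")
parser = add_multimodal_extra_args(parser)

args = parser.parse_args(
    [
        "--language-model-type", "mistral_7b",          # required flag; value is an assumption
        "--use-te",
        "--use-tiling",
        "--max-num-tiles", "4",
        "--use-thumbnail",
        "--online-evaluation-config", "eval_config.yaml",  # hypothetical config path
    ]
)
print(args.language_model_type, args.max_num_tiles, args.online_evaluation_config)
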
2 changes: 0 additions & 2 deletions examples/multimodal/pretrain_mistral_clip.sh
@@ -32,7 +32,6 @@ fi
CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}/checkpoints"

DATA_TRAIN="${SOURCE}/examples/multimodal/pretrain_dataset.yaml"
DATA_VALID="${SOURCE}/examples/multimodal/pretrain_dataset.yaml"

DEBUG=0
if [[ $DEBUG -eq 1 ]]; then
@@ -96,7 +95,6 @@ OPTIONS=" \
--tokenizer-type HuggingFaceTokenizer \
--tokenizer-model ${WORKSPACE}/${TOKENIZER_MODEL} \
--data-path ${DATA_TRAIN} \
--valid-path ${DATA_VALID} \
--prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \
--save-interval 1000 \
--save ${FINETUNE_DIR} \
(diffs for the remaining 8 changed files not shown here)