diff --git a/examples/multimodal/config.py b/examples/multimodal/config.py index f8c3714eb3..d4ee17db1b 100644 --- a/examples/multimodal/config.py +++ b/examples/multimodal/config.py @@ -1,4 +1,6 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from dataclasses import dataclass + import torch from megatron.training.activations import quick_gelu, squared_relu @@ -107,3 +109,26 @@ def get_vision_projection_config(config, hidden_size): config.activation_func = torch.nn.functional.gelu return config + + +@dataclass +class EvaluationConfig: + """Evaluation related configuration.""" + task: str + + temperature: float = 1.0 + top_p: float = 0.0 + top_k: int = 0 + + out_seq_length: int = 32 + + output_path: str = "" + + input_image_path: str = "" + gt_path: str = "" + + num_partitions: int = 1 + partition_id: int = 0 + num_samples_per_partition: int = 0 + + prompt_format: str = "mistral" diff --git a/examples/multimodal/evaluate_textvqa.py b/examples/multimodal/evaluate_textvqa.py index b80974a893..7d0a059f4d 100644 --- a/examples/multimodal/evaluate_textvqa.py +++ b/examples/multimodal/evaluate_textvqa.py @@ -1,16 +1,23 @@ import argparse import glob import json +import os from evaluate_vqav2 import compute_vqa_accuracy def merge_input_files(input_path): """Merge input files to a format compatible with the evaluator.""" - output_file_path = input_path + "-TextVQA-merged.json" + # Single input file. + if os.path.exists(input_path): + input_file_paths = [input_path] + output_file_path = input_path.replace(".jsonl", "-merged.json") + # Directory of partitioned input files. + else: + pattern = input_path + "-TextVQA-[0-9].*jsonl" + input_file_paths = glob.glob(pattern) - pattern = input_path + "-TextVQA-[0-9].*jsonl" - input_file_paths = glob.glob(pattern) + output_file_path = input_path + "-TextVQA-merged.json" results = [] @@ -35,7 +42,8 @@ def merge_input_files(input_path): def textvqa_eval(input_path): """Run TextVQA evaluation.""" result_file_path = merge_input_files(input_path) - compute_vqa_accuracy(result_file_path) + avg_acc = compute_vqa_accuracy(result_file_path) + return avg_acc if __name__ == "__main__": @@ -43,4 +51,6 @@ def textvqa_eval(input_path): parser.add_argument('--input-path', type=str, help="Path to input file(s)") args = parser.parse_args() - textvqa_eval(args.input_path) + avg_acc = textvqa_eval(args.input_path) + + print(f"===== TextVQA Accuracy {avg_acc:.2f}% =====") diff --git a/examples/multimodal/evaluate_vqav2.py b/examples/multimodal/evaluate_vqav2.py index 5d9dfe7844..cf10a0549d 100644 --- a/examples/multimodal/evaluate_vqav2.py +++ b/examples/multimodal/evaluate_vqav2.py @@ -55,7 +55,7 @@ def compute_vqa_accuracy(result_file, use_chartqa_metric=False): # "We consider an answer to be correct if it is within 5% of the gold answer. # For non-numeric answers, we still need an exact match to consider an answer to be correct." if use_chartqa_metric: - acc = 0. + acc = 0.0 assert len(gt) == 1, "expected exactly one groundtruth answer." 
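Usage note for the evaluation helpers above: since `textvqa_eval` and `compute_vqa_accuracy` now return the accuracy instead of printing it, they can be called from other scripts (e.g. the online evaluation added later in this diff). A minimal sketch, assuming it runs from `examples/multimodal` and using a placeholder output prefix:

```python
# Illustrative driver; the path prefix is a placeholder, not part of the repo.
from evaluate_textvqa import textvqa_eval

# Accepts either a single generated .jsonl file or a prefix of partitioned
# outputs such as <prefix>-TextVQA-0.jsonl, <prefix>-TextVQA-1.jsonl, ...
avg_acc = textvqa_eval("generated/mistral-7b")
print(f"TextVQA accuracy: {avg_acc:.2f}%")
```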
gt = gt[0] @@ -74,13 +74,15 @@ def compute_vqa_accuracy(result_file, use_chartqa_metric=False): all_acc.append(acc) acc_avg = sum(all_acc) / len(all_acc) * 100 - print(f"===== Accuracy {acc_avg:.2f}% =====") + + return acc_avg def vqav2_eval(input_path): """Run VQAv2 evaluation.""" result_file = merge_input_files(input_path) - compute_vqa_accuracy(result_file) + avg_acc = compute_vqa_accuracy(result_file) + return avg_acc if __name__ == "__main__": @@ -88,4 +90,6 @@ def vqav2_eval(input_path): parser.add_argument('--input-path', type=str, help="Path to input file(s)") args = parser.parse_args() - vqav2_eval(args.input_path) + avg_acc = vqav2_eval(args.input_path) + + print(f"===== VQAv2 Accuracy {avg_acc:.2f}% =====") diff --git a/examples/multimodal/model.py b/examples/multimodal/model.py new file mode 100644 index 0000000000..b21c687525 --- /dev/null +++ b/examples/multimodal/model.py @@ -0,0 +1,149 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +import warnings +from copy import deepcopy + +import torch +from config import get_language_model_config, get_vision_model_config, get_vision_projection_config +from layer_specs import get_layer_spec, get_layer_spec_te, get_mlp_module_spec + +from megatron.core.models.multimodal.llava_model import LLaVAModel +from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings +from megatron.training import get_args, print_rank_0 +from megatron.training.arguments import core_transformer_config_from_args + + +def model_provider( + pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True +) -> LLaVAModel: + """Builds the model. + + Args: + pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism). Defaults to True. + post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline parallelism). Defaults to True. + add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the encoder + will live on only a subset of the pipeline stages (specifically, only the first stage). + add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the decoder + will live on only a subset of the pipeline stages (specifically, every stage after the first one). + parallel_output (bool): Enable parallel model output. + + Returns: + model: A multimodal model. 
+ """ + args = get_args() + + use_te = args.use_te + + print_rank_0('building a multimodal model ...') + + num_image_embeddings = get_num_image_embeddings( + args.img_h, args.img_w, args.patch_dim, args.disable_vision_class_token, 1 + ) + old_seq_length = args.seq_length + args.seq_length = args.encoder_seq_length = num_image_embeddings + if torch.distributed.get_rank() == 0 and old_seq_length != args.seq_length: + warnings.warn( + f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})" + ) + + max_num_image_embeddings = (args.max_num_tiles + int(args.use_thumbnail)) * num_image_embeddings + + assert ( + args.decoder_seq_length is not None + ), "Please provide --decoder-seq-length to set the language model sequence length" + assert ( + args.decoder_seq_length > max_num_image_embeddings + ), "Language model sequence length must be greater than the maximum number of image embeddings" + if args.decoder_seq_length > args.max_position_embeddings: + args.max_position_embeddings = args.decoder_seq_length + warnings.warn( + f"Expanded max_position_embeddings to {args.max_position_embeddings} to accommodate the maximum language model sequence length" + ) + + base_config = core_transformer_config_from_args(get_args()) + base_config.language_model_type = args.language_model_type + base_config.vision_model_type = args.vision_model_type + base_config.calculate_per_token_loss = True + + language_config = deepcopy(base_config) + language_config = get_language_model_config(language_config) + + if use_te: + language_transformer_layer_spec = get_layer_spec_te( + is_vit=False + ) # TENorm detects LayerNorm/RMS automatically. + else: + language_transformer_layer_spec = get_layer_spec( + is_vit=False, normalization=language_config.normalization + ) + + vision_config = deepcopy(base_config) + vision_config = get_vision_model_config( + vision_config, apply_query_key_layer_scaling=args.apply_query_key_layer_scaling + ) + + vision_model_type = args.vision_model_type + if vision_model_type == "clip": + if use_te: + vision_transformer_layer_spec = get_layer_spec_te( + is_vit=True + ) # TENorm detects LayerNorm/RMS automatically. + else: + vision_transformer_layer_spec = get_layer_spec( + is_vit=True, normalization=vision_config.normalization + ) + else: + raise RuntimeError("unsupported vision model type", vision_model_type) + + vision_projection_config = deepcopy(base_config) + vision_projection_config = get_vision_projection_config( + vision_projection_config, language_config.hidden_size + ) + + if args.encoder_pipeline_model_parallel_size > 0: + assert ( + args.encoder_pipeline_model_parallel_size == 1 + ), "vision model and projection can only live on 1 pipeline stage." 
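A back-of-the-envelope check of the sequence-length constraints enforced above, with assumed numbers (336×336 CLIP input, patch size 14, class token dropped, 4 tiles plus a thumbnail; none of these values come from the diff):

```python
# Assumed example values; real runs take these from the command-line args.
img_h = img_w = 336
patch_dim = 14
max_num_tiles, use_thumbnail = 4, True

num_image_embeddings = (img_h // patch_dim) * (img_w // patch_dim)  # 576 per tile (class token dropped)
max_num_image_embeddings = (max_num_tiles + int(use_thumbnail)) * num_image_embeddings  # 2880

# The vision sequence length becomes num_image_embeddings, while the language
# model sequence length (--decoder-seq-length) must exceed the worst-case
# image-embedding count so that text tokens still fit.
decoder_seq_length = 4096  # example value
assert decoder_seq_length > max_num_image_embeddings
```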
+ vision_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size + vision_projection_config.pipeline_model_parallel_size = ( + args.encoder_pipeline_model_parallel_size + ) + if args.encoder_tensor_model_parallel_size > 0: + vision_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size + vision_projection_config.tensor_model_parallel_size = ( + args.encoder_tensor_model_parallel_size + ) + + vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules + + model = LLaVAModel( + language_transformer_config=language_config, + language_transformer_layer_spec=language_transformer_layer_spec, + language_vocab_size=args.padded_vocab_size, + language_max_sequence_length=args.decoder_seq_length, + vision_transformer_config=vision_config, + vision_transformer_layer_spec=vision_transformer_layer_spec, + drop_vision_class_token=args.disable_vision_class_token, + vision_projection_config=vision_projection_config, + vision_projection_layer_spec=vision_projection_layer_spec, + vision_projection_type="mlp", + allow_missing_vision_projection_checkpoint=args.allow_missing_vision_projection_checkpoint, + parallel_output=parallel_output, + language_position_embedding_type=args.position_embedding_type, + language_rotary_percent=args.rotary_percent, + pre_process=pre_process, + post_process=post_process, + add_encoder=add_encoder, + add_decoder=add_decoder, + img_h=args.img_h, + img_w=args.img_w, + patch_dim=args.patch_dim, + language_rotary_base=args.rotary_base, + ) + + model.freeze( + freeze_language_model=args.freeze_LM, + freeze_vision_model=args.freeze_ViT, + freeze_vision_projection=False, + ) + + return model diff --git a/examples/multimodal/multimodal_args.py b/examples/multimodal/multimodal_args.py new file mode 100644 index 0000000000..a7cb4235e3 --- /dev/null +++ b/examples/multimodal/multimodal_args.py @@ -0,0 +1,43 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+ + +def add_multimodal_extra_args(parser): + """Extra arguments.""" + group = parser.add_argument_group(title='multimodal arguments') + group.add_argument('--dataset-config', type=str, default=None) + group.add_argument("--prompt-path", type=str, default=None) + group.add_argument('--freeze-LM', action='store_true', default=False) + group.add_argument('--freeze-ViT', action='store_true', default=False) + group.add_argument('--language-model-type', type=str, required=True) + group.add_argument('--vision-model-type', type=str, default="clip") + group.add_argument("--disable-vision-class-token", action="store_true", default=False) + group.add_argument( + "--allow-missing-vision-projection-checkpoint", action="store_true", default=False + ) + group.add_argument("--use-te", action="store_true", default=False) + group.add_argument( + "--dataloader-save", type=str, default=None, help="Energon dataloader state save path" + ) + group.add_argument( + "--use-tiling", action="store_true", default=False, help="Use input image tiling" + ) + group.add_argument("--max-num-tiles", type=int, default=1, help="Maximum number of image tiles") + group.add_argument( + "--use-thumbnail", action="store_true", default=False, help="Add image thumbnail as a tile" + ) + group.add_argument( + "--dataloader-seq-length", + type=int, + help="Make dataloader to produce sequences of specific length.", + ) + group.add_argument( + "--num-frames", + type=int, + default=1, + help="Number of frames to regularly sample from the video as input to the model.", + ) + group.add_argument( + "--online-evaluation-config", type=str, help="Config file for online evaluation." + ) + + return parser diff --git a/examples/multimodal/pretrain_mistral_clip.sh b/examples/multimodal/pretrain_mistral_clip.sh index da72c335c0..b06dbfe53c 100755 --- a/examples/multimodal/pretrain_mistral_clip.sh +++ b/examples/multimodal/pretrain_mistral_clip.sh @@ -32,7 +32,6 @@ fi CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}/checkpoints" DATA_TRAIN="${SOURCE}/examples/multimodal/pretrain_dataset.yaml" -DATA_VALID="${SOURCE}/examples/multimodal/pretrain_dataset.yaml" DEBUG=0 if [[ $DEBUG -eq 1 ]]; then @@ -96,7 +95,6 @@ OPTIONS=" \ --tokenizer-type HuggingFaceTokenizer \ --tokenizer-model ${WORKSPACE}/${TOKENIZER_MODEL} \ --data-path ${DATA_TRAIN} \ - --valid-path ${DATA_VALID} \ --prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \ --save-interval 1000 \ --save ${FINETUNE_DIR} \ diff --git a/examples/multimodal/run_text_generation.py b/examples/multimodal/run_text_generation.py index 391f3071d0..bc406217b7 100644 --- a/examples/multimodal/run_text_generation.py +++ b/examples/multimodal/run_text_generation.py @@ -1,13 +1,13 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. """Generate text using a vision language model.""" import glob +import itertools import json import logging import os import sys from collections import defaultdict from functools import partial -import itertools # Add megatron to the path. 
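The shared argument group defined in `multimodal_args.py` above can be exercised on its own, which is a quick way to inspect the defaults; a sketch with placeholder values (`mistral_7b` is only an example), assuming it runs from `examples/multimodal`:

```python
# Sketch only: builds the multimodal argument group with placeholder values.
import argparse

from multimodal_args import add_multimodal_extra_args

parser = argparse.ArgumentParser()
parser = add_multimodal_extra_args(parser)
args = parser.parse_args(
    ["--language-model-type", "mistral_7b", "--use-te", "--use-tiling", "--max-num-tiles", "4"]
)
print(args.language_model_type, args.max_num_tiles, args.online_evaluation_config)
```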
sys.path.append( @@ -17,7 +17,8 @@ import datasets import numpy as np import torch -from torchvision.io import read_video +import yaml +from config import EvaluationConfig from dataset_helpers import tokenizer_image_token from image_processing import get_visual_transform from MMMU.mmmu.utils.data_utils import ( @@ -27,10 +28,13 @@ process_single_sample, ) from MMMU.mmmu.utils.eval_utils import parse_multi_choice_response +from model import model_provider +from multimodal_args import add_multimodal_extra_args from PIL import Image -from train import add_multimodal_extra_args, get_num_image_embeddings, model_provider +from torchvision.io import read_video from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN_INDEX +from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings from megatron.inference.text_generation.api import generate_and_post_process from megatron.inference.text_generation.forward_step import ForwardStep from megatron.training import get_args, get_model, get_tokenizer, print_rank_0 @@ -48,14 +52,12 @@ def add_text_generation_args(parser): group.add_argument( "--out-seq-length", type=int, default=1024, help='Length of the output generated text.' ) - group.add_argument("--output-path", type=str, required=True, help='Output file path') - group.add_argument('--input-image-path', type=str, required=True, help="Input image directory") - group.add_argument('--input-metadata-path', type=str, help="Input metadata path") + group.add_argument("--output-path", type=str, help='Output file path') + group.add_argument('--input-image-path', type=str, help="Input image directory") group.add_argument( '--num-partitions', type=int, default=0, help="Number of partitions for inputs." ) group.add_argument('--partition-id', type=int, default=0, help="Partition index") - group.add_argument("--drop-vision-class-token", action="store_true", default=False) group.add_argument("--gt-path", type=str, help="Optional ground truth file") group.add_argument( "--task", @@ -69,10 +71,11 @@ def add_text_generation_args(parser): group.add_argument( "--prompt-format", type=str, - required=True, + default="mistral", choices=["llama3", "mistral"], help="Prompting format to use", ) + group.add_argument("--config-path", type=str, help="Config file to use.") # Add common multimodal arguments needed for e.g. building the model. parser = add_multimodal_extra_args(parser) @@ -85,8 +88,9 @@ def _get_partition_bounds( ): if num_samples_per_partition == 0: samples_per_partition = [ - int(x) for x in np.linspace(0, total_num_samples, num_partitions+1)] - return samples_per_partition[partition_id], samples_per_partition[partition_id+1] + int(x) for x in np.linspace(0, total_num_samples, num_partitions + 1) + ] + return samples_per_partition[partition_id], samples_per_partition[partition_id + 1] return num_samples_per_partition * partition_id, num_samples_per_partition * (partition_id + 1) @@ -286,33 +290,34 @@ def get_evaluation_dataset( continue gt["video_path"] = video_path ground_truth.append(gt) - + ground_truth = sorted(ground_truth, key=lambda gt: gt["video_path"]) print_rank_0(f"Found {len(ground_truth)} videos to process.") if num_partitions > 0: start_idx, end_idx = _get_partition_bounds( - len(ground_truth), num_samples_per_partition, - num_partitions, partition_id + len(ground_truth), num_samples_per_partition, num_partitions, partition_id ) ground_truth = ground_truth[start_idx:end_idx] # Run image preprocessing. 
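For reference, `_get_partition_bounds` above splits the dataset evenly when `num_samples_per_partition` is 0; a standalone illustration with example numbers:

```python
# Standalone illustration of the even-split branch (example numbers only).
import numpy as np

def partition_bounds(total_num_samples, num_partitions, partition_id):
    edges = [int(x) for x in np.linspace(0, total_num_samples, num_partitions + 1)]
    return edges[partition_id], edges[partition_id + 1]

# 100 samples over 3 partitions -> (0, 33), (33, 66), (66, 100)
print([partition_bounds(100, 3, i) for i in range(3)])
```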
for idx, gt in enumerate(ground_truth): print_rank_0(f"Processing input video: {idx} / {len(ground_truth)}") - video, _, _ = read_video( - gt["video_path"], start_pts=0, end_pts=None, pts_unit='sec') + video, _, _ = read_video(gt["video_path"], start_pts=0, end_pts=None, pts_unit='sec') video = video.numpy() - selected_frames = torch.linspace( - 0, video.shape[0] - 1, num_frames).long() + selected_frames = torch.linspace(0, video.shape[0] - 1, num_frames).long() video_frames = video[selected_frames] if num_frames == 1: video_frames = video_frames[None] - imgs = list(itertools.chain.from_iterable( - get_visual_transform( - img, img_h, img_w, use_tiling, max_num_tiles, - use_thumbnail, augment=False) for img in video_frames)) + imgs = list( + itertools.chain.from_iterable( + get_visual_transform( + img, img_h, img_w, use_tiling, max_num_tiles, use_thumbnail, augment=False + ) + for img in video_frames + ) + ) for question in gt["questions"]: # Very hacky, but we essentially re-create gt holding only the @@ -324,7 +329,7 @@ def get_evaluation_dataset( "video_category": gt["video_category"], "video_subcategory": gt["video_subcategory"], "url": gt["url"], - "questions": [question] + "questions": [question], } images.append(imgs) tile_counts.append(torch.tensor([len(imgs)], dtype=torch.int)) @@ -336,26 +341,30 @@ def get_evaluation_dataset( return images, tile_counts, samples, sample_ids, questions, answers -def generate_samples(model): +def generate_samples(model, config: EvaluationConfig): """Text generation using a trained vision language model.""" args = get_args() images, tile_counts, samples, sample_ids, questions, answers = get_evaluation_dataset( - args.task, - args.input_image_path, - args.gt_path, + config.task, + config.input_image_path, + config.gt_path, args.img_h, args.img_w, args.use_tiling, args.max_num_tiles, args.use_thumbnail, - args.num_samples_per_partition, - args.num_partitions, - args.partition_id, - args.num_frames + config.num_samples_per_partition, + config.num_partitions, + config.partition_id, + args.num_frames, + ) + + num_image_embeddings_per_tile = get_num_image_embeddings( + args.img_h, args.img_w, args.patch_dim, args.disable_vision_class_token, 1 ) num_img_embeddings_per_tile = get_num_image_embeddings( - args.img_h, args.img_w, args.patch_dim, - args.disable_vision_class_token, 1) + args.img_h, args.img_w, args.patch_dim, args.disable_vision_class_token, 1 + ) num_samples = len(sample_ids) idx = 0 while idx < num_samples: @@ -363,21 +372,20 @@ def generate_samples(model): num_tiles = tile_counts[idx].cuda() sample_id = sample_ids[idx] - prompt = get_prompt(args.task, questions, idx, args.prompt_format) + prompt = get_prompt(config.task, questions, idx, config.prompt_format) - forward_step = partial( - VLMForwardStep, num_img_embeddings_per_tile, imgs, num_tiles) + forward_step = partial(VLMForwardStep, num_img_embeddings_per_tile, imgs, num_tiles) if torch.distributed.get_rank() == 0: resp_sentences, _, _, _ = generate_and_post_process( model, forward_step=forward_step, prompts=[prompt], - tokens_to_generate=args.out_seq_length, - top_k_sampling=args.top_k, - top_p_sampling=args.top_p, + tokens_to_generate=config.out_seq_length, + top_k_sampling=config.top_k, + top_p_sampling=config.top_p, add_BOS=False, - temperature=args.temperature, + temperature=config.temperature, random_seed=args.seed, detokenize_segments=False, ) @@ -386,29 +394,29 @@ def generate_samples(model): output = {"sample_id": sample_id, "prompt": prompt} output_name = "" - if args.task == 
"captioning": + if config.task == "captioning": output_name = "caption" - elif args.task in ("TextVQA", "VQAv2", "ChartQA"): + elif config.task in ("TextVQA", "VQAv2", "ChartQA"): output_name = "answer" - elif args.task in ("MMMU"): + elif config.task in ("MMMU"): output_name = "text" - elif args.task == "VideoMME": + elif config.task == "VideoMME": output_name = "response" output = questions[idx] - generated = get_generated(prompt, args.prompt_format, generation) - if args.task == "VideoMME": + generated = get_generated(prompt, config.prompt_format, generation) + if config.task == "VideoMME": output["questions"][0][output_name] = generated else: output[output_name] = generated - if args.task == "captioning": + if config.task == "captioning": output["ground_truth"] = answers[sample_id] - elif args.task in ("TextVQA", "VQAv2"): + elif config.task in ("TextVQA", "VQAv2"): output["gt_answer"] = [ans for ans in answers[idx]] - elif args.task == "ChartQA": + elif config.task == "ChartQA": output["gt_answer"] = [answers[idx]] - elif args.task == "MMMU": + elif config.task == "MMMU": sample = samples[idx] prediction = generated @@ -429,27 +437,63 @@ def generate_samples(model): idx += 1 -def generate_and_write_samples(model): - """Generate text and write to an output file.""" +def get_evaluation_config(): + """Get evaluation config from a config file or command-line arguments.""" args = get_args() + if args.config_path: + with open(args.config_path, "r") as f: + config_dict = yaml.safe_load(f) + + config = EvaluationConfig(**config_dict) + else: + config = EvaluationConfig( + task=args.task, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + out_seq_length=args.out_seq_length, + output_path=args.output_path, + input_image_path=args.input_image_path, + gt_path=args.gt_path, + num_partitions=args.num_partitions, + partition_id=args.partition_id, + num_samples_per_partition=args.num_samples_per_partition, + prompt_format=args.prompt_format, + ) + + # Default output path if not defined... 
+ if not config.output_path: + os.makedirs("generated", exist_ok=True) + config.output_path = "generated/" + args.language_model_type - for output in generate_samples(model): + return config + + +def generate_and_write_samples(model, config): + """Generate text and write to an output file.""" + for output in generate_samples(model, config): if torch.distributed.get_rank() == 0: - with open(args.output_path, 'a') as f: + with open(config.output_path, 'a') as f: f.write(json.dumps(output) + "\n") class VLMForwardStep(ForwardStep): """Inference forward step for a multimodal model.""" - def __init__(self, num_img_embeddings_per_tile, images, num_tiles, model, - max_batch_size, max_sequence_length): + def __init__( + self, + num_img_embeddings_per_tile, + images, + num_tiles, + model, + max_batch_size, + max_sequence_length, + ): """Create multimodal forward step.""" total_num_tiles = torch.sum(num_tiles).item() - num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles + num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles - super().__init__( - model, max_batch_size, max_sequence_length + num_img_embeddings) + super().__init__(model, max_batch_size, max_sequence_length + num_img_embeddings) self._images = images self._num_tiles = num_tiles @@ -461,6 +505,7 @@ def _forward(self, tokens, position_ids, attention_mask): attention_mask=None, inference_params=self.inference_params, num_image_tiles=self._num_tiles, + runtime_gather_output=True, ) def __call__(self, tokens, position_ids, attention_mask): @@ -532,20 +577,19 @@ def get_prompt(task, questions, idx, prompt_format): question = ( "Select the best answer to the following multiple-choice " "question based on the video. Respond with only the letter " - "(A, B, C, or D) of the correct option.\n") - question += (questions[idx]["questions"][0]["question"] + "\n") - question += (questions[idx]["questions"][0]["choices"][0] + "\n") - question += (questions[idx]["questions"][0]["choices"][1] + "\n") - question += (questions[idx]["questions"][0]["choices"][2] + "\n") - question += (questions[idx]["questions"][0]["choices"][3] + "\n") + "(A, B, C, or D) of the correct option.\n" + ) + question += questions[idx]["questions"][0]["question"] + "\n" + question += questions[idx]["questions"][0]["choices"][0] + "\n" + question += questions[idx]["questions"][0]["choices"][1] + "\n" + question += questions[idx]["questions"][0]["choices"][2] + "\n" + question += questions[idx]["questions"][0]["choices"][3] + "\n" if prompt_format == "llama3": prompt = "<|start_header_id|>system<|end_header_id|>\n\nAnswer the questions.<|eot_id|>{}<|start_header_id|>user<|end_header_id|>\n\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" prompt = prompt.format("", question) elif prompt_format == "mistral": - prompt = "\n{}".format( - question - ) + prompt = "\n{}".format(question) return prompt @@ -617,9 +661,12 @@ def wrapped_model_provider(pre_process, post_process): _ = load_checkpoint(model, None, None) model = model[0] + model.eval() - generate_and_write_samples(model) + config = get_evaluation_config() + + generate_and_write_samples(model, config) if __name__ == "__main__": diff --git a/examples/multimodal/sft_mistral_clip.sh b/examples/multimodal/sft_mistral_clip.sh index 93a0a91366..46fc996055 100755 --- a/examples/multimodal/sft_mistral_clip.sh +++ b/examples/multimodal/sft_mistral_clip.sh @@ -37,7 +37,6 @@ fi CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}/checkpoints" DATA_TRAIN="${SOURCE}/examples/multimodal/sft_dataset.yaml" 
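Stepping back to `VLMForwardStep` above: the inference context is extended by the total number of image embeddings so the KV cache can hold both text tokens and vision tokens. A rough check with assumed numbers (576 embeddings per tile, 5 tiles, 2048 text tokens; not values taken from the diff):

```python
# Assumed example values; the real code derives these from the model/image args.
num_img_embeddings_per_tile = 576
total_num_tiles = 5
max_sequence_length = 2048

num_img_embeddings = num_img_embeddings_per_tile * total_num_tiles  # 2880
inference_context_length = max_sequence_length + num_img_embeddings  # 4928
print(inference_context_length)
```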
-DATA_VALID="${SOURCE}/examples/multimodal/sft_dataset.yaml" DEBUG=0 if [[ $DEBUG -eq 1 ]]; then @@ -101,7 +100,6 @@ OPTIONS=" \ --tokenizer-type HuggingFaceTokenizer \ --tokenizer-model ${WORKSPACE}/${TOKENIZER_MODEL} \ --data-path ${DATA_TRAIN} \ - --valid-path ${DATA_VALID} \ --prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \ --save-interval 500 \ --save ${FINETUNE_DIR} \ diff --git a/examples/multimodal/text_generation_mistral_clip.sh b/examples/multimodal/text_generation_mistral_clip.sh index 30d1b06ab4..b78969ab59 100755 --- a/examples/multimodal/text_generation_mistral_clip.sh +++ b/examples/multimodal/text_generation_mistral_clip.sh @@ -4,7 +4,6 @@ export NCCL_IB_SL=1 export CUDA_DEVICE_MAX_CONNECTIONS=1 export NVTE_APPLY_QK_LAYER_SCALING=0 -INPUT_METADATA_PATH="placeholder" GROUNDTRUTH_PATH="placeholder" NUM_FRAMES=1 @@ -15,11 +14,6 @@ while [[ $# -gt 0 ]]; do shift shift ;; - --input-metadata-path) - INPUT_METADATA_PATH="$2" - shift - shift - ;; --num-frames) NUM_FRAMES="$2" shift @@ -112,7 +106,6 @@ do --no-load-rng \ --no-load-optim \ --input-image-path ${INPUT_IMAGE_PATH} \ - --input-metadata-path ${INPUT_METADATA_PATH} \ --num-partitions ${NUM_PARTITIONS} \ --partition-id ${PARTITION_ID} \ --output-path ${OUTPUT_PATH}-${TASK}-${PARTITION_ID}.jsonl \ diff --git a/examples/multimodal/train.py b/examples/multimodal/train.py index e1cad7814e..386cdc03d0 100644 --- a/examples/multimodal/train.py +++ b/examples/multimodal/train.py @@ -1,131 +1,29 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. """Pretrain or SFT multimodal.""" -from copy import deepcopy -from functools import partial +import json import os import sys -import warnings +from functools import partial import torch +import yaml sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))) -from megatron.training import get_args, get_timers, get_tokenizer, print_rank_0 -from megatron.training.arguments import core_transformer_config_from_args +from config import EvaluationConfig +from dataloader_provider import train_valid_test_dataloaders_provider +from evaluate_textvqa import textvqa_eval +from model import model_provider +from multimodal_args import add_multimodal_extra_args +from run_text_generation import generate_samples, patch_tokenizer + from megatron.core import mpu, tensor_parallel from megatron.core.enums import ModelType -from megatron.core.parallel_state import get_tensor_model_parallel_rank -from config import get_language_model_config, get_vision_model_config, get_vision_projection_config -from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings from megatron.core.models.multimodal.llava_model import LLaVAModel -from layer_specs import get_layer_spec, get_mlp_module_spec, get_layer_spec_te -from megatron.training import pretrain -from dataloader_provider import train_valid_test_dataloaders_provider - -def model_provider( - pre_process=True, post_process=True, add_encoder=True, add_decoder=True, - parallel_output=True) -> LLaVAModel: - """Builds the model. - - Args: - pre_process (bool): Include the embedding layer in the gpt decoder (used with pipeline parallelism). Defaults to True. - post_process (bool): Include an output layer and a layernorm in the gpt decoder (used with pipeline parallelism). Defaults to True. - add_encoder (bool): Construct the encoder module (used with pipeline parallelism). Defaults to True. 
When we use pipelining, the encoder - will live on only a subset of the pipeline stages (specifically, only the first stage). - add_decoder (bool): Construct the decoder module (used with pipeline parallelism). Defaults to True. When we use pipelining, the decoder - will live on only a subset of the pipeline stages (specifically, every stage after the first one). - parallel_output (bool): Enable parallel model output. - - Returns: - model: A multimodal model. - """ - args = get_args() - - use_te = args.use_te - - print_rank_0('building a multimodal model ...') - - num_image_embeddings = get_num_image_embeddings(args.img_h, args.img_w, args.patch_dim, args.disable_vision_class_token, 1) - old_seq_length = args.seq_length - args.seq_length = args.encoder_seq_length = num_image_embeddings - if torch.distributed.get_rank() == 0 and old_seq_length != args.seq_length: - warnings.warn(f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})") - - max_num_image_embeddings = (args.max_num_tiles + int(args.use_thumbnail)) * num_image_embeddings - - assert args.decoder_seq_length is not None, "Please provide --decoder-seq-length to set the language model sequence length" - assert args.decoder_seq_length > max_num_image_embeddings, "Language model sequence length must be greater than the maximum number of image embeddings" - if args.decoder_seq_length > args.max_position_embeddings: - args.max_position_embeddings = args.decoder_seq_length - warnings.warn(f"Expanded max_position_embeddings to {args.max_position_embeddings} to accommodate the maximum language model sequence length") - - base_config = core_transformer_config_from_args(get_args()) - base_config.language_model_type = args.language_model_type - base_config.vision_model_type = args.vision_model_type - base_config.calculate_per_token_loss = True - - language_config = deepcopy(base_config) - language_config = get_language_model_config(language_config) - - if use_te: - language_transformer_layer_spec = get_layer_spec_te(is_vit=False) # TENorm detects LayerNorm/RMS automatically. - else: - language_transformer_layer_spec = get_layer_spec(is_vit=False, normalization=language_config.normalization) - - vision_config = deepcopy(base_config) - vision_config = get_vision_model_config(vision_config, apply_query_key_layer_scaling=args.apply_query_key_layer_scaling) - - vision_model_type = args.vision_model_type - if vision_model_type == "clip": - if use_te: - vision_transformer_layer_spec = get_layer_spec_te(is_vit=True) # TENorm detects LayerNorm/RMS automatically. - else: - vision_transformer_layer_spec = get_layer_spec(is_vit=True, normalization=vision_config.normalization) - else: - raise RuntimeError("unsupported vision model type", vision_model_type) - - vision_projection_config = deepcopy(base_config) - vision_projection_config = get_vision_projection_config(vision_projection_config, language_config.hidden_size) - - if args.encoder_pipeline_model_parallel_size > 0: - assert args.encoder_pipeline_model_parallel_size == 1, "vision model and projection can only live on 1 pipeline stage." 
- vision_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size - vision_projection_config.pipeline_model_parallel_size = args.encoder_pipeline_model_parallel_size - if args.encoder_tensor_model_parallel_size > 0: - vision_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size - vision_projection_config.tensor_model_parallel_size = args.encoder_tensor_model_parallel_size - - vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules - - model = LLaVAModel( - language_transformer_config=language_config, - language_transformer_layer_spec=language_transformer_layer_spec, - language_vocab_size=args.padded_vocab_size, - language_max_sequence_length=args.decoder_seq_length, - vision_transformer_config=vision_config, - vision_transformer_layer_spec=vision_transformer_layer_spec, - drop_vision_class_token=args.disable_vision_class_token, - vision_projection_config=vision_projection_config, - vision_projection_layer_spec=vision_projection_layer_spec, - vision_projection_type="mlp", - allow_missing_vision_projection_checkpoint=args.allow_missing_vision_projection_checkpoint, - parallel_output=parallel_output, - language_position_embedding_type=args.position_embedding_type, - language_rotary_percent=args.rotary_percent, - pre_process=pre_process, - post_process=post_process, - add_encoder=add_encoder, - add_decoder=add_decoder, - img_h=args.img_h, - img_w=args.img_w, - patch_dim=args.patch_dim, - language_rotary_base=args.rotary_base, - ) - - model.freeze(freeze_language_model=args.freeze_LM, freeze_vision_model=args.freeze_ViT, freeze_vision_projection=False) - - return model +from megatron.core.parallel_state import get_tensor_model_parallel_rank +from megatron.training import get_args, get_timers, get_tokenizer, pretrain +from megatron.training.utils import is_last_rank def get_batch(data_iterator): @@ -314,32 +212,6 @@ def forward_step(data_iterator, model: LLaVAModel): return output_tensor, partial(loss_func, loss_mask) -def add_multimodal_extra_args(parser): - """Extra arguments.""" - group = parser.add_argument_group(title='multimodal arguments') - group.add_argument('--valid-path', nargs='*', default=None, - help='Path to the training dataset. 
Accepted format:' - '1) a single data path, 2) multiple datasets in the' - 'form: dataset1-weight dataset1-path dataset2-weight ' - 'dataset2-path ...') - group.add_argument('--dataset-config', type=str, default=None) - group.add_argument("--prompt-path", type=str, default=None) - group.add_argument('--freeze-LM', action='store_true', default=False) - group.add_argument('--freeze-ViT', action='store_true', default=False) - group.add_argument('--language-model-type', type=str, required=True) - group.add_argument('--vision-model-type', type=str, default="clip") - group.add_argument("--disable-vision-class-token", action="store_true", default=False) - group.add_argument("--allow-missing-vision-projection-checkpoint", action="store_true", default=False) - group.add_argument("--use-te", action="store_true", default=False) - group.add_argument("--dataloader-save", type=str, default=None, help="Energon dataloader state save path") - group.add_argument("--use-tiling", action="store_true", default=False, help="Use input image tiling") - group.add_argument("--max-num-tiles", type=int, default=1, help="Maximum number of image tiles") - group.add_argument("--use-thumbnail", action="store_true", default=False, help="Add image thumbnail as a tile") - group.add_argument("--dataloader-seq-length", type=int, help="Make dataloader to produce sequences of specific length.") - group.add_argument("--num-frames", type=int, default=1, help="Number of frames to regularly sample from the video as input to the model.") - - return parser - def llava_embedding_ranks(pp_ranks): """LLava's embedding ranks consist of the decoder's first and last ranks (ie, the ViT has no embeddings). @@ -375,6 +247,64 @@ def llava_position_embedding_ranks(pp_ranks): return [pp_ranks[epp]] + +def run_online_eval(model): + """Run an evaluation benchmark during training.""" + args = get_args() + + # Online evaluation config is not defined. Do nothing. + if not args.online_evaluation_config: + return [] + + with open(args.online_evaluation_config, "r") as f: + config_dict = yaml.safe_load(f) + + config = EvaluationConfig(**config_dict) + + patch_tokenizer(args) + + # The inference code assumes the first rank is the leader. + # Tensorboard writer is on the last rank. + # We must write to a storage space that all ranks see. + output_dir = os.path.join(args.save, "online_eval") + os.makedirs(output_dir, exist_ok=True) + config.output_path = os.path.join(output_dir, f"{config.task}.jsonl") + + if torch.distributed.get_rank() == 0: + output_file = open(config.output_path, "w") + + with torch.no_grad(): + for output in generate_samples(model[0].module, config): + if torch.distributed.get_rank() == 0: + output_file.write(json.dumps(output) + "\n") + + if torch.distributed.get_rank() == 0: + output_file.close() + + # Make sure the first rank is done writing so that the last rank can run eval. 
+ torch.distributed.barrier() + + if not is_last_rank(): + return [] + + if config.task.lower() == "textvqa": + avg_acc = textvqa_eval(config.output_path) + + return [{"textvqa accuracy": avg_acc}] + else: + raise NotImplementedError(f"online evaluation of {config.task} not implemented yet") + + +def write_online_eval_to_tensorboard(data, iteration, writer): + """Write online evaluation data to Tensorboard.""" + if not writer: + return + + for item in data: + for k, v in item.items(): + writer.add_scalar(k, v, iteration) + + if __name__ == "__main__": train_valid_test_dataloaders_provider.is_distributed = True @@ -385,6 +315,8 @@ def llava_position_embedding_ranks(pp_ranks): forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, extra_args_provider=add_multimodal_extra_args, + process_non_loss_data_func=write_online_eval_to_tensorboard, get_embedding_ranks=llava_embedding_ranks, get_position_embedding_ranks=llava_position_embedding_ranks, + non_loss_data_func=run_online_eval ) diff --git a/megatron/core/models/gpt/gpt_model.py b/megatron/core/models/gpt/gpt_model.py index 20f83976c4..b5f7ce51e9 100644 --- a/megatron/core/models/gpt/gpt_model.py +++ b/megatron/core/models/gpt/gpt_model.py @@ -185,12 +185,17 @@ def forward( inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, extra_block_kwargs: dict = None, + runtime_gather_output: Optional[bool] = None, ) -> Tensor: """Forward function of the GPT Model This function passes the input tensors through the embedding layer, and then the decoeder and finally into the post processing layer (optional). It either returns the Loss values if labels are given or the final hidden units + + Args: + runtime_gather_output (bool): Gather output at runtime. Default None means + `parallel_output` arg in the constructor will be used. """ # If decoder_input is provided (not None), then input_ids and position_ids are ignored. # Otherwise, apply embedding layer on input_ids and position_ids to get decoder_input. @@ -230,7 +235,9 @@ def forward( output_weight = None if self.share_embeddings_and_output_weights: output_weight = self.shared_embedding_or_output_weight() - logits, _ = self.output_layer(hidden_states, weight=output_weight) + logits, _ = self.output_layer( + hidden_states, weight=output_weight, runtime_gather_output=runtime_gather_output + ) if has_config_logger_enabled(self.config): payload = OrderedDict( diff --git a/megatron/core/models/multimodal/llava_model.py b/megatron/core/models/multimodal/llava_model.py index a8ddc94ced..68d963bdf9 100644 --- a/megatron/core/models/multimodal/llava_model.py +++ b/megatron/core/models/multimodal/llava_model.py @@ -429,6 +429,7 @@ def forward( inference_params: Optional[InferenceParams] = None, num_image_tiles: Optional[List[int]] = None, image_token_index: Optional[int] = IMAGE_TOKEN_INDEX, + runtime_gather_output: Optional[bool] = None, ) -> torch.Tensor: """Forward function of the LLaVA model. @@ -445,6 +446,8 @@ def forward( inference_params (InferenceParams): Inference-time parameters including KV cache. num_image_tiles (list of int): Number of tiles per image. Default 1 tile per image. image_token_index (int): ID for input images. + runtime_gather_output (bool): Gather output at runtime. Default None means + `parallel_output` arg in the constructor will be used. 
Returns: output (torch.Tensor): Loss of shape [b, s] if labels are provided, @@ -528,6 +531,7 @@ def forward( decoder_input=combined_embeddings, labels=new_labels, inference_params=inference_params, + runtime_gather_output=runtime_gather_output, ) if labels is None or loss_mask is None: diff --git a/megatron/core/tensor_parallel/layers.py b/megatron/core/tensor_parallel/layers.py index ff0be00bb8..61d9c7c34d 100644 --- a/megatron/core/tensor_parallel/layers.py +++ b/megatron/core/tensor_parallel/layers.py @@ -69,6 +69,8 @@ def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride): def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor): + """Set default model parallel attributes if not set explicitly already.""" + def maybe_set(attribute, value): if not hasattr(tensor, attribute): setattr(tensor, attribute, value) @@ -78,6 +80,8 @@ def maybe_set(attribute, value): def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor): + """Copy model parallel attributes from one tensor to another.""" + def maybe_copy(attribute): if hasattr(source_tensor, attribute): setattr(destination_tensor, attribute, getattr(source_tensor, attribute)) @@ -219,6 +223,11 @@ def __init__( _initialize_affine_weight_gpu(self.weight, init_method, partition_dim=0, stride=1) def forward(self, input_): + """Forward. + + Args: + input_ (torch.Tensor): Input tensor. + """ if self.tensor_model_parallel_size > 1: # Build the mask. input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index) @@ -278,6 +287,7 @@ class LinearWithFrozenWeight(torch.autograd.Function): @staticmethod @custom_fwd def forward(ctx, input, weight, bias, allreduce_dgrad): + """Forward with frozen weight.""" ctx.save_for_backward(weight) ctx.allreduce_dgrad = allreduce_dgrad output = torch.matmul(input, weight.t()) @@ -288,6 +298,7 @@ def forward(ctx, input, weight, bias, allreduce_dgrad): @staticmethod @custom_bwd def backward(ctx, grad_output): + """Backward with frozen weight.""" (weight,) = ctx.saved_tensors grad_input = grad_output.matmul(weight) @@ -389,6 +400,7 @@ def forward( grad_output_buffer, wgrad_deferral_limit, ): + """Forward.""" ctx.save_for_backward(input, weight) ctx.use_bias = bias is not None ctx.gradient_accumulation_fusion = gradient_accumulation_fusion @@ -418,6 +430,7 @@ def forward( @staticmethod @custom_bwd def backward(ctx, grad_output): + """Backward.""" input, weight = ctx.saved_tensors use_bias = ctx.use_bias grad_output_buffer = ctx.grad_output_buffer @@ -847,7 +860,12 @@ def __init__( ) ) - def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None): + def forward( + self, + input_: torch.Tensor, + weight: Optional[torch.Tensor] = None, + runtime_gather_output: Optional[bool] = None, + ): """Forward of ColumnParallelLinear Args: @@ -855,6 +873,8 @@ def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None): 3D tensor whose order of dimension is [sequence, batch, hidden] weight (optional): weight tensor to use, compulsory when skip_weight_param_allocation is True. + runtime_gather_output (bool): Gather output at runtime. Default None means + `gather_output` arg in the constructor will be used. Returns: - output @@ -927,7 +947,13 @@ def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None): ), allreduce_dgrad=allreduce_dgrad, ) - if self.gather_output: + + gather_output = self.gather_output + # Use the runtime gather output if it's set explicitly. 
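The override rule applied just below can be summarized in one line: an explicit `runtime_gather_output` takes precedence over the constructor's `gather_output`. A minimal restatement (not the layer code itself):

```python
# Minimal restatement of the precedence rule; not the actual layer implementation.
from typing import Optional

def resolve_gather_output(gather_output: bool, runtime_gather_output: Optional[bool]) -> bool:
    return gather_output if runtime_gather_output is None else runtime_gather_output

assert resolve_gather_output(False, True) is True   # e.g. gather full logits only for generation
assert resolve_gather_output(True, None) is True    # constructor setting still applies by default
```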
+ if runtime_gather_output is not None: + gather_output = runtime_gather_output + + if gather_output: # All-gather across the partitions. assert not self.sequence_parallel output = gather_from_tensor_model_parallel_region(output_parallel) diff --git a/megatron/training/training.py b/megatron/training/training.py index 7d60f41f5c..fbe4ecf079 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -205,6 +205,7 @@ def pretrain( args_defaults={}, get_embedding_ranks=None, get_position_embedding_ranks=None, + non_loss_data_func=None, ): """Main training program. @@ -233,6 +234,10 @@ def pretrain( to it. It is used for programs to add their own arguments. args_defaults: a dictionary from argument-name to argument-value. It to set already parse arguments. + get_embedding_ranks (TODO): + get_position_embedding_ranks (TODO): + non_loss_data_func (callable): A custom function to call during evaluation. + It can run e.g. benchmarks. """ # Initalize and get arguments, timers, and Tensorboard writer. @@ -356,7 +361,8 @@ def pretrain( forward_step_func, model, optimizer, opt_param_scheduler, train_data_iterator, valid_data_iterator, - process_non_loss_data_func, config, checkpointing_context) + process_non_loss_data_func, config, checkpointing_context, + non_loss_data_func) print_datetime('after training is done') @@ -381,14 +387,16 @@ def pretrain( evaluate_and_print_results(prefix, forward_step_func, valid_data_iterator, model, iteration, process_non_loss_data_func, config, - verbose=True, write_to_tensorboard=not args.skip_train) + verbose=True, write_to_tensorboard=not args.skip_train, + non_loss_data_func=non_loss_data_func) if args.do_test: prefix = f'iteration {iteration} on test set' evaluate_and_print_results(prefix, forward_step_func, test_data_iterator, model, iteration, process_non_loss_data_func, config, - verbose=True, write_to_tensorboard=not args.skip_train) + verbose=True, write_to_tensorboard=not args.skip_train, + non_loss_data_func=non_loss_data_func) wandb_writer = get_wandb_writer() if wandb_writer: @@ -1095,7 +1103,7 @@ def save_checkpoint_and_time(iteration, model, optimizer, opt_param_scheduler, def train(forward_step_func, model, optimizer, opt_param_scheduler, train_data_iterator, valid_data_iterator, - process_non_loss_data_func, config, checkpointing_context): + process_non_loss_data_func, config, checkpointing_context, non_loss_data_func): """Train the model function.""" args = get_args() timers = get_timers() @@ -1331,7 +1339,8 @@ def get_e2e_base_metrics(): evaluate_and_print_results(prefix, forward_step_func, valid_data_iterator, model, iteration, process_non_loss_data_func, - config, False) + config, verbose=False, write_to_tensorboard=True, + non_loss_data_func=non_loss_data_func) eval_duration += timers('eval-time').elapsed() eval_iterations += args.eval_iters timers('eval-time').stop() @@ -1456,7 +1465,8 @@ def evaluate(forward_step_func, model, process_non_loss_data_func, config, - verbose=False): + verbose=False, + non_loss_data_func=None): """Evaluation.""" args = get_args() timers = get_timers() @@ -1534,7 +1544,9 @@ def evaluate(forward_step_func, return None, None, True collected_non_loss_data = None - if process_non_loss_data_func is not None and is_last_rank(): + if non_loss_data_func is not None: + collected_non_loss_data = non_loss_data_func(model) + elif process_non_loss_data_func is not None and is_last_rank(): collected_non_loss_data = forward_backward_func( forward_step_func=forward_step_func, data_iterator=data_iterator, @@ 
-1562,7 +1574,7 @@ def evaluate(forward_step_func, def evaluate_and_print_results(prefix, forward_step_func, data_iterator, model, iteration, process_non_loss_data_func, config, - verbose=False, write_to_tensorboard=True): + verbose=False, write_to_tensorboard=True, non_loss_data_func=None): """Helper function to evaluate and dump results on screen.""" args = get_args() if write_to_tensorboard: @@ -1574,7 +1586,7 @@ def evaluate_and_print_results(prefix, forward_step_func, total_loss_dict, collected_non_loss_data, timelimit = evaluate( forward_step_func, data_iterator, model, - process_non_loss_data_func, config, verbose) + process_non_loss_data_func, config, verbose, non_loss_data_func) # Timelimit hit during evaluation if timelimit: return
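To summarize the evaluation hook contract plumbed through `pretrain()`/`evaluate()` above: `non_loss_data_func(model)` runs during evaluation and returns a list of metric dicts (as `run_online_eval` does), and `process_non_loss_data_func(data, iteration, writer)` logs them (as `write_online_eval_to_tensorboard` does). A sketch with placeholder names — `my_benchmark` and the metric key are not repo code:

```python
# Sketch of the hook contract; my_benchmark and the metric name are placeholders.

def my_non_loss_data_func(model):
    # Called during evaluation; mirroring run_online_eval, only the last rank
    # should return data, while other ranks return an empty list.
    accuracy = 42.0  # e.g. my_benchmark(model[0].module)
    return [{"my_benchmark accuracy": accuracy}]

def my_process_non_loss_data_func(data, iteration, writer):
    # Receives the collected data on the last rank and logs it to Tensorboard.
    if not writer:
        return
    for item in data:
        for name, value in item.items():
            writer.add_scalar(name, value, iteration)

# Wired up via:
# pretrain(..., process_non_loss_data_func=my_process_non_loss_data_func,
#          non_loss_data_func=my_non_loss_data_func)
```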