diff --git a/examples/multimodal/evaluate_ai2d.py b/examples/multimodal/evaluate_ai2d.py new file mode 100644 index 0000000000..2d5db67b67 --- /dev/null +++ b/examples/multimodal/evaluate_ai2d.py @@ -0,0 +1,46 @@ +import argparse +import json + +from evaluate_mmmu import get_input_output_paths +from evaluate_vqav2 import compute_vqa_accuracy + + +def merge_input_files(input_path): + """Merge input files to a format compatible with the evaluator.""" + input_file_paths, output_file_path = get_input_output_paths(input_path, task="AI2D") + + results = [] + + for input_file_path in input_file_paths: + with open(input_file_path, "r") as input_file: + for line in input_file: + res = json.loads(line) + results.append( + { + "question_id": res["sample_id"], + "answer": res["answer"], + "gt_answer": res["gt_answer"], + } + ) + + with open(output_file_path, "w") as output_file: + json.dump(results, output_file) + + return output_file_path + + +def ai2d_eval(input_path): + """Run AI2D evaluation.""" + result_file_path = merge_input_files(input_path) + avg_acc = compute_vqa_accuracy(result_file_path, task="AI2D") + return avg_acc + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input-path', type=str, help="Path to input file(s)") + args = parser.parse_args() + + avg_acc = ai2d_eval(args.input_path) + + print(f"===== AI2D Accuracy {avg_acc:.2f}% =====") diff --git a/examples/multimodal/evaluate_chartqa.py b/examples/multimodal/evaluate_chartqa.py index 8ec346d0d1..e9238069d4 100644 --- a/examples/multimodal/evaluate_chartqa.py +++ b/examples/multimodal/evaluate_chartqa.py @@ -28,7 +28,7 @@ def merge_input_files(input_path): def chartqa_eval(input_path): """Run ChartQA evaluation.""" result_file_path = merge_input_files(input_path) - return compute_vqa_accuracy(result_file_path, use_chartqa_metric=True) + return compute_vqa_accuracy(result_file_path, task="ChartQA") if __name__ == "__main__": diff --git a/examples/multimodal/evaluate_mathvista.py b/examples/multimodal/evaluate_mathvista.py new file mode 100644 index 0000000000..3474c5f25e --- /dev/null +++ b/examples/multimodal/evaluate_mathvista.py @@ -0,0 +1,114 @@ +import argparse +import json +import re + +from evaluate_mmmu import get_input_output_paths +from MMMU.mmmu.utils.eval_utils import parse_multi_choice_response +from open_flamingo.eval.vqa_metric import VQAEval + + +def merge_input_files(input_path): + """Merge input files to a format compatible with the evaluator.""" + input_file_paths, output_file_path = get_input_output_paths(input_path, task="MathVista") + + results = [] + + for input_file_path in input_file_paths: + with open(input_file_path, "r") as input_file: + for line in input_file: + res = json.loads(line) + results.append(res) + + with open(output_file_path, "w") as output_file: + json.dump(results, output_file) + + return output_file_path + + +def extra_processing(text): + """Extra processing.""" + # Max decimal point capped to 2 decimal point + regex = re.compile(r'^\d+\.\d+$') + decimal = regex.findall(text) + + if len(decimal) > 0: + non_decimal = len(decimal[0].split(".")[0]) + + # if decimal values are all 0, trim them + decimal_digits = [int(d) for d in decimal[0].split(".")[1]] + if sum(decimal_digits) == 0: + text = decimal[0][:non_decimal] + else: + text = decimal[0][: non_decimal + 3] + + # remove % and trailing . 
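# Illustrative behaviour of extra_processing() as defined above (example values,
# not taken from the dataset):
#   "3.140000" -> "3.14"  (more than two decimals are capped to two)
#   "7.000"    -> "7"     (all-zero decimals are trimmed entirely)
#   "45%."     -> "45"    (percent sign and trailing dot are stripped by the lines below)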
+ text = text.replace("%", "") + if text[-1] == ".": + text = text[:-1] + + return text + + +def extract_answer(text): + """Extract answer.""" + alphabet = re.findall(r'[a-zA-Z]+', text) + if len(alphabet) > 0 and "e+" not in text: + template = re.findall(r'answer is -*\d+\.*\d*', text) + if len(template) > 0: + text = template[0] + + numbers = re.findall(r'-*\d+\.*\d*', text) + text = numbers[0] if len(numbers) > 0 else text + + return text + + +def compute_mathvista_accuracy(result_file): + """Compute MathVista accuracy.""" + merged_results = json.load(open(result_file)) + + vqa = VQAEval(vqa=None, vqaRes=None) + acc = 0 + for res in merged_results: + pred_ans = res["answer"] + if res["question_type"] == "multi_choice": + pred_ans = parse_multi_choice_response(pred_ans, res["all_choices"], res["index2ans"]) + else: + pred_ans = vqa.processPunctuation(pred_ans) + pred_ans = vqa.processDigitArticle(pred_ans) + # Extra processing and extraction. + pred_ans = extra_processing(pred_ans) + pred_ans = extract_answer(pred_ans) + + gt_ans = res["gt_answer"] + if isinstance(gt_ans, list): + assert len(gt_ans) == 1, f"Expected 1 groundtruth, got {gt_ans}" + gt_ans = gt_ans[0] + + if res["question_type"] != "multi_choice": + gt_ans = vqa.processPunctuation(gt_ans) + gt_ans = vqa.processDigitArticle(gt_ans) + + gt_ans = extra_processing(gt_ans) + + if pred_ans == gt_ans: + acc += 1 + acc = acc / len(merged_results) * 100 + return acc + + +def mathvista_eval(input_path): + """Run MathVista evaluation.""" + result_file_path = merge_input_files(input_path) + acc = compute_mathvista_accuracy(result_file_path) + return acc + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input-path', type=str, help="Path to input file(s)") + args = parser.parse_args() + + acc = mathvista_eval(args.input_path) + + print(f"===== MathVista accuracy: {acc} =====") diff --git a/examples/multimodal/evaluate_mmmu.py b/examples/multimodal/evaluate_mmmu.py index 955be95842..66118fa905 100644 --- a/examples/multimodal/evaluate_mmmu.py +++ b/examples/multimodal/evaluate_mmmu.py @@ -40,6 +40,14 @@ def convert_to_mmmu_format(input_path): sample_id = res["sample_id"] prediction = res["prediction"] + if res["question_type"] == "multiple-choice": + from MMMU.mmmu.utils.eval_utils import parse_multi_choice_response + + prediction = parse_multi_choice_response( + prediction, res["all_choices"], res["index2ans"] + ) + + # MMMU eval script expects just a sample_id to prediction mapping. 
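# For illustration, the converted file written below is a flat JSON object such as
#   {"validation_Art_1": "B", "validation_Math_3": "a right triangle"}
# (sample ids here are made up): multiple-choice predictions are reduced to an option
# letter by parse_multi_choice_response above, while open-ended predictions keep the
# generated text.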
output[sample_id] = prediction with open(output_file_path, "w") as output_file: @@ -69,7 +77,7 @@ def mmmu_eval(input_path, groundtruth_path): print(output.stderr) print(output.stdout) - m = re.search("'Overall': {'num': \d, 'acc': (\d.\d+)}", output.stdout) + m = re.search("'Overall': {'num': \d+, 'acc': (\d.\d+)}", output.stdout) return float(m.group(1)) * 100.0 diff --git a/examples/multimodal/evaluate_ocrbench.py b/examples/multimodal/evaluate_ocrbench.py new file mode 100644 index 0000000000..bc2b901065 --- /dev/null +++ b/examples/multimodal/evaluate_ocrbench.py @@ -0,0 +1,129 @@ +import argparse +import json + +from evaluate_mmmu import get_input_output_paths + + +def merge_input_files(input_path): + """Merge input files to a format compatible with the evaluator.""" + input_file_paths, output_file_path = get_input_output_paths(input_path, task="OCRBench") + + results = [] + + for input_file_path in input_file_paths: + with open(input_file_path, "r") as input_file: + for line in input_file: + res = json.loads(line) + results.append(res) + + with open(output_file_path, "w") as output_file: + json.dump(results, output_file) + + return output_file_path + + +def compute_ocrbench_score(result_file): + """Compute OCRBench score.""" + merged_results = json.load(open(result_file)) + + # OCRBench score calculation is adopted from https://github.com/Yuliang-Liu/MultimodalOCR/blob/1b7713f44c91f30f64efb6d3e494c416861ef15f/example.py#L1 + # MIT License. Copyright (c) 2023 Yuliang Liu + score = { + "Regular Text Recognition": 0, + "Irregular Text Recognition": 0, + "Artistic Text Recognition": 0, + "Handwriting Recognition": 0, + "Digit String Recognition": 0, + "Non-Semantic Text Recognition": 0, + "Scene Text-centric VQA": 0, + "Doc-oriented VQA": 0, + "Doc-oriented VQA": 0, + "Key Information Extraction": 0, + "Handwritten Mathematical Expression Recognition": 0, + } + + for res in merged_results: + predict = res["answer"] + answers = res["gt_answer"] + + dataset_name = res["dataset_name"] + ocr_type = res["data_type"] + + if dataset_name == "HME100k": + if isinstance(answers, list): + for j in range(len(answers)): + answer = answers[j].strip().replace("\n", " ").replace(" ", "") + predict = predict.strip().replace("\n", " ").replace(" ", "") + if answer in predict: + score[ocr_type] += 1 + else: + answers = answers.strip().replace("\n", " ").replace(" ", "") + predict = predict.strip().replace("\n", " ").replace(" ", "") + if answers in predict: + score[ocr_type] += 1 + else: + if isinstance(answers, list): + for j in range(len(answers)): + answer = answers[j].lower().strip().replace("\n", " ") + predict = predict.lower().strip().replace("\n", " ") + if answer in predict: + score[ocr_type] += 1 + else: + answers = answers.lower().strip().replace("\n", " ") + predict = predict.lower().strip().replace("\n", " ") + if answers in predict: + score[ocr_type] += 1 + + recognition_score = ( + score['Regular Text Recognition'] + + score['Irregular Text Recognition'] + + score['Artistic Text Recognition'] + + score['Handwriting Recognition'] + + score['Digit String Recognition'] + + score['Non-Semantic Text Recognition'] + ) + final_score = ( + recognition_score + + score['Scene Text-centric VQA'] + + score['Doc-oriented VQA'] + + score['Key Information Extraction'] + + score['Handwritten Mathematical Expression Recognition'] + ) + result_log = f"""###########################OCRBench############################## +Text Recognition(Total 300): {recognition_score} +------------------Details of Recognition 
Score------------------- +Regular Text Recognition(Total 50): {score['Regular Text Recognition']} +Irregular Text Recognition(Total 50): {score['Irregular Text Recognition']} +Artistic Text Recognition(Total 50): {score['Artistic Text Recognition']} +Handwriting Recognition(Total 50): {score['Handwriting Recognition']} +Digit String Recognition(Total 50): {score['Digit String Recognition']} +Non-Semantic Text Recognition(Total 50): {score['Non-Semantic Text Recognition']} +---------------------------------------------------------------- +Scene Text-centric VQA(Total 200): {score['Scene Text-centric VQA']} +---------------------------------------------------------------- +Doc-oriented VQA(Total 200): {score['Doc-oriented VQA']} +---------------------------------------------------------------- +Key Information Extraction(Total 200): {score['Key Information Extraction']} +---------------------------------------------------------------- +Handwritten Mathematical Expression Recognition(Total 100): {score['Handwritten Mathematical Expression Recognition']} +----------------------Final Score------------------------------- +Final Score(Total 1000): {final_score}""" + + return result_log, final_score + + +def ocrbench_eval(input_path): + """Run OCRBench evaluation.""" + result_file_path = merge_input_files(input_path) + result_log, score = compute_ocrbench_score(result_file_path) + return result_log, score + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--input-path', type=str, help="Path to input file(s)") + args = parser.parse_args() + + result_log, _ = ocrbench_eval(args.input_path) + + print(result_log) diff --git a/examples/multimodal/evaluate_textvqa.py b/examples/multimodal/evaluate_textvqa.py index e231b8e2c2..c9bba7134b 100644 --- a/examples/multimodal/evaluate_textvqa.py +++ b/examples/multimodal/evaluate_textvqa.py @@ -35,7 +35,7 @@ def merge_input_files(input_path): def textvqa_eval(input_path): """Run TextVQA evaluation.""" result_file_path = merge_input_files(input_path) - avg_acc = compute_vqa_accuracy(result_file_path) + avg_acc = compute_vqa_accuracy(result_file_path, task="TextVQA") return avg_acc diff --git a/examples/multimodal/evaluate_vqav2.py b/examples/multimodal/evaluate_vqav2.py index 9e3b727501..0b1b9209be 100644 --- a/examples/multimodal/evaluate_vqav2.py +++ b/examples/multimodal/evaluate_vqav2.py @@ -34,7 +34,7 @@ def is_number(n: str): return False -def compute_vqa_accuracy(result_file, use_chartqa_metric=False): +def compute_vqa_accuracy(result_file, task): """Compute VQA accuracy.""" merged_results = json.load(open(result_file)) @@ -51,8 +51,8 @@ def compute_vqa_accuracy(result_file, use_chartqa_metric=False): # ChartQA uses relaxed accuracy: # "We consider an answer to be correct if it is within 5% of the gold answer. - # For non-numeric answers, we still need an exact match to consider an answer to be correct." - if use_chartqa_metric: + # For non-numeric answers, we still need an exact match to consider an answer to be correct." + if task == "ChartQA": acc = 0.0 assert len(gt) == 1, "expected exactly one groundtruth answer." 
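# Illustrative example of the relaxed metric described above: for a gold answer of
# "80", any numeric prediction within 5% of it (76.0 to 84.0) counts as correct, so
# "82" is accepted while "85" is not; non-numeric answers such as "Paris" still
# require an exact string match.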
gt = gt[0] @@ -66,10 +66,16 @@ def compute_vqa_accuracy(result_file, use_chartqa_metric=False): acc = 1.0 all_acc.append(acc) - else: + elif task in ("VQAv2", "TextVQA"): num_match = sum([pred == ans for ans in gt]) acc = min(1.0, num_match / 3.0) all_acc.append(acc) + elif task == "AI2D": + assert len(gt) == 1, f"Expected exactly 1 GT, got {gt}" + acc = pred == gt[0] + all_acc.append(acc) + else: + raise NotImplementedError(f"unknown task {task}") acc_avg = sum(all_acc) / len(all_acc) * 100 @@ -79,7 +85,7 @@ def compute_vqa_accuracy(result_file, use_chartqa_metric=False): def vqav2_eval(input_path): """Run VQAv2 evaluation.""" result_file = merge_input_files(input_path) - avg_acc = compute_vqa_accuracy(result_file) + avg_acc = compute_vqa_accuracy(result_file, task="VQAv2") return avg_acc diff --git a/examples/multimodal/evaluation_datasets.py b/examples/multimodal/evaluation_datasets.py new file mode 100644 index 0000000000..2334cf8344 --- /dev/null +++ b/examples/multimodal/evaluation_datasets.py @@ -0,0 +1,826 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +"""Evaluation datasets.""" +import glob +import itertools +import json +import os +import re +from collections import defaultdict + +import numpy as np +import torch +from image_processing import get_visual_transform +from PIL import Image + +from megatron.training import print_rank_0 + + +def _get_partition_bounds( + total_num_samples, num_samples_per_partition, num_partitions, partition_id +): + if num_samples_per_partition == 0: + samples_per_partition = [ + int(x) for x in np.linspace(0, total_num_samples, num_partitions + 1) + ] + return samples_per_partition[partition_id], samples_per_partition[partition_id + 1] + return num_samples_per_partition * partition_id, num_samples_per_partition * (partition_id + 1) + + +class VQADataset(torch.utils.data.Dataset): + """VQA evaluation dataset.""" + + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + keys, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + ): + samples = json.load(open(gt_path, encoding='utf-8')) + if "data" in samples: + samples = samples["data"] + + # Optionally, process only a subset of the input files. + if num_partitions > 0: + lb, ub = _get_partition_bounds( + len(samples), num_samples_per_partition, num_partitions, partition_id + ) + samples = samples[lb:ub] + + self._keys = keys + self._samples = samples + self._input_image_path = input_image_path + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + + def __len__(self): + return len(self._samples) + + def __getitem__(self, idx): + sample = self._samples[idx] + + img_file = "{}/{}".format(self._input_image_path, sample[self._keys["image_id"]]) + if not os.path.exists(img_file): + img_file += ".jpg" + + if not os.path.exists(img_file): + img_file = img_file.replace('.jpg', '.png') + + img = Image.open(img_file) + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + ) + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + sample_id = idx + if "sample_id" in self._keys: + sample_id = sample[self._keys["sample_id"]] + + metadata = "" # Not used. 
+ + return ( + torch.stack(imgs), + tile_count, + sample_id, + sample[self._keys["question"]], + sample[self._keys["answer"]], + metadata, + ) + + +class CaptioningDataset(torch.utils.data.Dataset): + """Captioning evaluation dataset.""" + + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + ): + image_files = sorted(glob.glob(input_image_path + "/*")) + + # Optionally, process only a subset of the input files. + if num_partitions > 0: + lb, ub = _get_partition_bounds( + len(image_files), num_samples_per_partition, num_partitions, partition_id + ) + image_files = image_files[lb:ub] + + gts = json.load(open(gt_path)) + answers = defaultdict(list) + for gt in gts["annotations"]: + answers[gt["image_id"]].append(gt['caption']) + + self._image_files = image_files + self._answers = answers + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + + def __len__(self): + return len(self._image_files) + + def __getitem__(self, idx): + img_file = self._image_files[idx] + image_id = int(img_file.split("_")[-1].split(".")[0]) + + img = Image.open(img_file) + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + ) + + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + question = "" # Fixed for all samples. + metadata = "" # Not used. + + return torch.stack(imgs), tile_count, image_id, question, self._answers[image_id], metadata + + +class MMMUDataset(torch.utils.data.Dataset): + """MMMU evaluation dataset.""" + + def __init__( + self, + input_image_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + single_image, + ): + import datasets + from MMMU.mmmu.utils.data_utils import CAT_SHORT2LONG, load_yaml + + # The following downloads the MMMU dataset from HuggingFace and uses the API from the MMMU github repo to run MMMU evaluation. + all_mmmu_datasets = [] + + hf_datasets_cache = os.environ["HF_DATASETS_CACHE"] + assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE." + + for subject in CAT_SHORT2LONG.values(): + # Use a local copy of the dataset if exists (can be faster) or the HF one. + if os.path.exists(input_image_path): + subject_dataset = datasets.load_dataset( + os.path.join(input_image_path, subject), + split=datasets.Split.VALIDATION, + cache_dir=hf_datasets_cache, + verification_mode="no_checks", + ) + else: + subject_dataset = datasets.load_dataset( + "MMMU/MMMU", + subject, + split=datasets.Split.VALIDATION, + cache_dir=hf_datasets_cache, + ) + + all_mmmu_datasets.append(subject_dataset) + + dataset = datasets.concatenate_datasets(all_mmmu_datasets) + + dataset = [s for s in dataset if s['id'].startswith("val")] + + # Optionally, process only a subset of the input files. + if num_partitions > 0: + lb, ub = _get_partition_bounds( + len(dataset), num_samples_per_partition, num_partitions, partition_id + ) + dataset = dataset[lb:ub] + + # Using the LLaVA config from the MMMU repo. + config = load_yaml("examples/multimodal/MMMU/mmmu/configs/llava1.5.yaml") + for k, v in config.items(): + if isinstance(v, list): + assert len(v) == 1, "only one value supported." 
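The MMMU loading above concatenates the per-subject validation splits into a single dataset. A minimal sketch of that pattern using the public Hugging Face datasets API (the two subject names are only examples; the real loop iterates over all of CAT_SHORT2LONG):

    import os
    import datasets

    os.environ.setdefault("HF_DATASETS_CACHE", "/path/to/hf_cache")  # hypothetical cache path

    subjects = ["Art", "History"]  # illustrative subset of MMMU subjects
    parts = [
        datasets.load_dataset("MMMU/MMMU", subject, split="validation")
        for subject in subjects
    ]
    combined = datasets.concatenate_datasets(parts)
    print(len(combined), combined[0]["id"])  # ids look like "validation_<Subject>_<n>"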
+ config[k] = v[0] + + self._config = config + + self._dataset = dataset + + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._single_image = single_image + + def __len__(self): + return len(self._dataset) + + def __getitem__(self, idx): + from MMMU.mmmu.utils.data_utils import construct_prompt, process_single_sample + + sample = self._dataset[idx] + + # Use the single image approach from the MMMU repo. + if self._single_image: + sample = process_single_sample(sample) + sample = construct_prompt(sample, self._config) + + img = sample["image"] + sample_imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + ) + sample_num_tiles = [len(sample_imgs)] + else: + sample = construct_prompt(sample, self._config) + + sample_imgs = [] + sample_num_tiles = [] + + img_indices = re.findall(r"" + + img = sample[img_key] + assert img is not None, f"{img_str} is in prompt but not in sample images" + + # Note: Only replace the current image tag. + sample["final_input_prompt"] = sample["final_input_prompt"].replace( + img_str, "", 1 + ) + + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + adjusted_max_num_tiles, + self._use_thumbnail, + augment=False, + ) # List of tiles. + + sample_imgs.extend(imgs) + sample_num_tiles.append(len(imgs)) + + # Sanity check. + for i in range(1, 8): + assert ( + f"" not in sample["final_input_prompt"] + ), "prompt contains unhandled image tags" + + # MMMU specific metadata. + metadata = {"question_type": sample["question_type"]} + if sample["question_type"] == "multiple-choice": + metadata["index2ans"] = sample["index2ans"] + metadata["all_choices"] = sample["all_choices"] + + prompt = sample['final_input_prompt'] + if self._single_image: + for i in range(8): + prompt = prompt.replace(f"", "") + prompt = f"\n{prompt}" + + tile_count = torch.tensor(sample_num_tiles, dtype=torch.int) + + return ( + torch.stack(sample_imgs), + tile_count, + sample["id"], + prompt, + sample["answer"], + metadata, + ) + + +class VideoMMMEDataset(torch.utils.data.Dataset): + "Video MME evaluation dataset." 
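The index2ans / all_choices metadata returned above is what evaluate_mmmu.py (and evaluate_mathvista.py) later feed to the MMMU helper to map a free-form generation back to an option letter. A small, hedged example of that call (the choice texts are made up):

    from MMMU.mmmu.utils.eval_utils import parse_multi_choice_response

    all_choices = ["A", "B", "C", "D"]
    index2ans = {"A": "a convex lens", "B": "a mirror", "C": "a prism", "D": "a filter"}

    prediction = parse_multi_choice_response(
        "The correct option is (B), a mirror.", all_choices, index2ans
    )
    print(prediction)  # expected: "B"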
+ + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + num_frames, + ): + ground_truth_original = json.load(open(gt_path)) + ground_truth = [] + for gt in ground_truth_original: + video_path = gt["url"] + video_path = video_path.replace("https://www.youtube.com/watch?v=", "") + video_path = video_path.replace("https://m.youtube.com/watch?v=", "") + video_path = os.path.join(input_image_path, video_path + ".mp4") + if not os.path.exists(video_path): + continue + gt["video_path"] = video_path + ground_truth.append(gt) + + ground_truth = sorted(ground_truth, key=lambda gt: gt["video_path"]) + print_rank_0(f"Found {len(ground_truth)} videos to process.") + + if num_partitions > 0: + start_idx, end_idx = _get_partition_bounds( + len(ground_truth), num_samples_per_partition, num_partitions, partition_id + ) + ground_truth = ground_truth[start_idx:end_idx] + + self._ground_truth = ground_truth + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._num_frames = num_frames + + def __len__(self): + return len(self._ground_truth) + + def __getitem__(self, idx): + from torchvision.io import read_video + + gt = self._ground_truth[idx] + + video, _, _ = read_video(gt["video_path"], start_pts=0, end_pts=None, pts_unit='sec') + video = video.numpy() + selected_frames = torch.linspace(0, video.shape[0] - 1, self._num_frames).long() + video_frames = video[selected_frames] + if self._num_frames == 1: + video_frames = video_frames[None] + + imgs = list( + itertools.chain.from_iterable( + get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + ) + for img in video_frames + ) + ) + + for question in gt["questions"]: + # Very hacky, but we essentially re-create gt holding only the + # question of interest. This is the make this generation script + # compatible with the Video MME evaluation script. 
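The frame selection above samples self._num_frames evenly spaced frames from the decoded video. For example (illustrative frame count):

    import torch

    total_frames, num_frames = 120, 8
    frame_idx = torch.linspace(0, total_frames - 1, num_frames).long()
    print(frame_idx.tolist())  # [0, 17, 34, 51, 68, 85, 102, 119]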
+ question_dict = { + "video_id": gt["video_id"], + "duration_category": gt["duration_category"], + "video_category": gt["video_category"], + "video_subcategory": gt["video_subcategory"], + "url": gt["url"], + "questions": [question], + } + + num_tiles = torch.tensor([len(imgs)], dtype=torch.int) + + answer = "" + metadata = "" + + return ( + torch.stack(imgs), + num_tiles, + question["question_id"], + question_dict, + answer, + metadata, + ) + + +class OCRBenchDataset(torch.utils.data.Dataset): + """OCRBench evaluation dataset.""" + + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + ): + gt = json.load(open(gt_path, encoding='utf-8')) + + if num_partitions > 0: + start_idx, end_idx = _get_partition_bounds( + len(gt), num_samples_per_partition, num_partitions, partition_id + ) + gt = gt[start_idx:end_idx] + + self._input_image_path = input_image_path + self._gt = gt + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + + def __len__(self): + return len(self._gt) + + def __getitem__(self, idx): + img_path = os.path.join(self._input_image_path, self._gt[idx]['image_path']) + + img = Image.open(img_path) + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + ) + + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + metadata = { + "dataset_name": self._gt[idx]["dataset_name"], + "data_type": self._gt[idx]["type"], + } + + return ( + torch.stack(imgs), + tile_count, + idx, + self._gt[idx]["question"], + self._gt[idx]["answers"], + metadata, + ) + + +class MathVistaDataset(torch.utils.data.Dataset): + """MathVista evaluation dataset.""" + + def __init__( + self, + input_image_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + ): + import datasets + + hf_datasets_cache = os.environ["HF_DATASETS_CACHE"] + assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE." + + if os.path.exists(input_image_path): + dataset = datasets.load_dataset( + input_image_path, cache_dir=hf_datasets_cache, verification_mode="no_checks" + ) + else: + dataset = datasets.load_dataset( + "AI4Math/MathVista", split="testmini", cache_dir=hf_datasets_cache + ) + + if num_partitions > 0: + start_idx, end_idx = _get_partition_bounds( + len(dataset), num_samples_per_partition, num_partitions, partition_id + ) + dataset = dataset[start_idx:end_idx] + + self._dataset = dataset + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + + def __len__(self): + return len(self._dataset["pid"]) + + def __getitem__(self, idx): + # Already a PIL object. 
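For reference, OCRBenchDataset above expects each ground-truth entry to carry at least the keys read in __getitem__; an illustrative entry (field values are made up, only the keys mirror the code):

    ocrbench_gt_entry = {
        "dataset_name": "HME100k",
        # "type" should match one of the categories scored in compute_ocrbench_score.
        "type": "Handwritten Mathematical Expression Recognition",
        "image_path": "images/hme_00001.jpg",
        "question": "Write out the expression shown in the image.",
        "answers": ["x^{2}+1"],
    }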
+ img = self._dataset['decoded_image'][idx] + + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + ) + + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + question_id = self._dataset["pid"][idx] + question = self._dataset["question"][idx] + question_type = self._dataset["question_type"][idx] # free_form or multi_choice + query = self._dataset["query"][idx] + choices = self._dataset["choices"][idx] + answer = self._dataset["answer"][idx] + + if question_type == 'multi_choice': + start_chr = 'A' + choices_str = '' + index2ans = {} + all_choices = [] + for choice in choices: + all_choices.append(start_chr) + index2ans[start_chr] = choice + choices_str += f"{start_chr}. {choice}\n" + start_chr = chr(ord(start_chr) + 1) + + question = question + '\n' + choices_str + question = question + "Answer with the option's letter from the given choices directly." + answer = chr(ord('A') + choices.index(answer)) + else: + question = query.replace("Hint: ", "") + index2ans = {} + all_choices = [] + + metadata = { + "question_type": question_type, + "index2ans": index2ans, + "all_choices": all_choices, + } + + return torch.stack(imgs), tile_count, question_id, question, answer, metadata + + +class AI2DDataset(torch.utils.data.Dataset): + """AI2D evaluation dataset.""" + + def __init__( + self, + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + no_mask, + ): + with open(gt_path, 'r') as f: + jsonl = list(f) + + gt = [json.loads(json_str) for json_str in jsonl] + + if num_partitions > 0: + start_idx, end_idx = _get_partition_bounds( + len(gt), num_samples_per_partition, num_partitions, partition_id + ) + gt = gt[start_idx:end_idx] + + self._gt = gt + self._input_image_path = input_image_path + self._img_h = img_h + self._img_w = img_w + self._use_tiling = use_tiling + self._max_num_tiles = max_num_tiles + self._use_thumbnail = use_thumbnail + self._no_mask = no_mask + + def __len__(self): + return len(self._gt) + + def __getitem__(self, idx): + img_path = os.path.join(self._input_image_path, self._gt[idx]['image']) + if self._no_mask: + img_path.replace("AI2D_TEST", "AI2D_TEST_NO_MASK_IMAGES") + + img = Image.open(img_path) + imgs = get_visual_transform( + img, + self._img_h, + self._img_w, + self._use_tiling, + self._max_num_tiles, + self._use_thumbnail, + augment=False, + ) + + tile_count = torch.tensor([len(imgs)], dtype=torch.int) + + metadata = "" # Not used. 
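The multiple-choice handling above rewrites the raw MathVista choices into lettered options and converts the ground-truth answer to its letter. The same loop run standalone (illustrative choices):

    choices = ["10", "20", "30"]
    answer = "20"

    start_chr, choices_str, index2ans, all_choices = "A", "", {}, []
    for choice in choices:
        all_choices.append(start_chr)
        index2ans[start_chr] = choice
        choices_str += f"{start_chr}. {choice}\n"
        start_chr = chr(ord(start_chr) + 1)

    letter_answer = chr(ord("A") + choices.index(answer))
    print(choices_str, letter_answer)  # "A. 10\nB. 20\nC. 30\n" and "B"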
+ + return ( + torch.stack(imgs), + tile_count, + self._gt[idx]["question_id"], + self._gt[idx]["question"], + self._gt[idx]["answer"], + metadata, + ) + + +def get_evaluation_dataset( + task, + input_image_path, + gt_path, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + num_samples_per_partition, + num_partitions, + partition_id, + num_frames, +): + """Get an evaluation dataset.""" + if task == "TextVQA": + keys = { + "image_id": "image_id", + "sample_id": "question_id", + "question": "question", + "answer": "answers", + } + + dataset = VQADataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + keys, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + ) + elif task == "VQAv2": + keys = { + "image_id": "image", + "sample_id": "question_id", + "question": "question", + "answer": "answer", + } + + dataset = VQADataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + keys, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + ) + elif task == "ChartQA": + keys = {"image_id": "imgname", "question": "query", "answer": "label"} + + dataset = VQADataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + keys, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + ) + elif task == "captioning": + dataset = CaptioningDataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + ) + elif task == 'MMMU': + # Note: single_image=True uses only one image like in the MMMU repo example. + # single_image=False uses all images in the sample. + dataset = MMMUDataset( + input_image_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + single_image=True, + ) + elif task == "VideoMME": + dataset = VideoMMMEDataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + num_frames, + ) + elif task == "OCRBench": + dataset = OCRBenchDataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + ) + elif task == "MathVista": + dataset = MathVistaDataset( + input_image_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + ) + elif task == "AI2D": + dataset = AI2DDataset( + input_image_path, + gt_path, + num_samples_per_partition, + num_partitions, + partition_id, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + no_mask=False, + ) + else: + raise NotImplementedError(f"unsupported task {task}") + + return dataset diff --git a/examples/multimodal/run_text_generation.py b/examples/multimodal/run_text_generation.py index fb3f2f14e5..47c7378e0e 100644 --- a/examples/multimodal/run_text_generation.py +++ b/examples/multimodal/run_text_generation.py @@ -1,13 +1,9 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. """Generate text using a vision language model.""" -import glob -import itertools import json import logging import os -import re import sys -from collections import defaultdict from functools import partial # Add megatron to the path. 
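A minimal usage sketch of the get_evaluation_dataset dispatcher above (all paths and image sizes are hypothetical, and the import assumes examples/multimodal plus Megatron are on PYTHONPATH):

    from evaluation_datasets import get_evaluation_dataset

    dataset = get_evaluation_dataset(
        task="TextVQA",
        input_image_path="/data/textvqa/train_images",   # hypothetical path
        gt_path="/data/textvqa/TextVQA_0.5.1_val.json",  # hypothetical path
        img_h=336,
        img_w=336,
        use_tiling=False,
        max_num_tiles=1,
        use_thumbnail=False,
        num_samples_per_partition=0,
        num_partitions=1,
        partition_id=0,
        num_frames=1,
    )
    imgs, tile_count, sample_id, question, answers, metadata = dataset[0]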
@@ -15,30 +11,18 @@ os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)) ) -import datasets -import numpy as np import torch import yaml from config import EvaluationConfig -from image_processing import get_visual_transform -from MMMU.mmmu.utils.data_utils import ( - CAT_SHORT2LONG, - construct_prompt, - load_yaml, - process_single_sample, -) -from MMMU.mmmu.utils.eval_utils import parse_multi_choice_response +from evaluation_datasets import get_evaluation_dataset from model import model_provider from multimodal_args import add_multimodal_extra_args -from PIL import Image -from torchvision.io import read_video from megatron.core import parallel_state -from megatron.core.models.multimodal.llava_model import IMAGE_TOKEN from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings from megatron.inference.text_generation.api import generate_and_post_process from megatron.inference.text_generation.forward_step import ForwardStep -from megatron.training import get_args, get_model, get_tokenizer, print_rank_0 +from megatron.training import get_args, get_model from megatron.training.checkpointing import load_checkpoint from megatron.training.initialize import initialize_megatron @@ -63,7 +47,17 @@ def add_text_generation_args(parser): group.add_argument( "--task", type=str, - choices=["captioning", "TextVQA", "VQAv2", "ChartQA", "MMMU", "VideoMME"], + choices=[ + "captioning", + "TextVQA", + "VQAv2", + "ChartQA", + "MMMU", + "VideoMME", + "OCRBench", + "MathVista", + "AI2D", + ], help="Generation task to run", ) group.add_argument( @@ -77,410 +71,6 @@ def add_text_generation_args(parser): return parser -def _get_partition_bounds( - total_num_samples, num_samples_per_partition, num_partitions, partition_id -): - if num_samples_per_partition == 0: - samples_per_partition = [ - int(x) for x in np.linspace(0, total_num_samples, num_partitions + 1) - ] - return samples_per_partition[partition_id], samples_per_partition[partition_id + 1] - return num_samples_per_partition * partition_id, num_samples_per_partition * (partition_id + 1) - - -class VQADataset(torch.utils.data.Dataset): - def __init__( - self, - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - keys, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - ): - samples = json.load(open(gt_path, encoding='utf-8')) - if "data" in samples: - samples = samples["data"] - - # Optionally, process only a subset of the input files. 
- if num_partitions > 0: - lb, ub = _get_partition_bounds( - len(samples), num_samples_per_partition, num_partitions, partition_id - ) - samples = samples[lb:ub] - - self._keys = keys - self._samples = samples - self._input_image_path = input_image_path - self._img_h = img_h - self._img_w = img_w - self._use_tiling = use_tiling - self._max_num_tiles = max_num_tiles - self._use_thumbnail = use_thumbnail - - def __len__(self): - return len(self._samples) - - def __getitem__(self, idx): - sample = self._samples[idx] - - img_file = "{}/{}".format(self._input_image_path, sample[self._keys["image_id"]]) - if not os.path.exists(img_file): - img_file += ".jpg" - - if not os.path.exists(img_file): - img_file = img_file.replace('.jpg', '.png') - - img = Image.open(img_file) - imgs = get_visual_transform( - img, - self._img_h, - self._img_w, - self._use_tiling, - self._max_num_tiles, - self._use_thumbnail, - augment=False, - ) - tile_count = torch.tensor([len(imgs)], dtype=torch.int) - - sample_id = idx - if "sample_id" in self._keys: - sample_id = sample[self._keys["sample_id"]] - - metadata = "" # Not used. - - return ( - torch.stack(imgs), - tile_count, - sample_id, - sample[self._keys["question"]], - sample[self._keys["answer"]], - metadata, - ) - - -class CaptioningDataset(torch.utils.data.Dataset): - def __init__( - self, - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - ): - image_files = sorted(glob.glob(input_image_path + "/*")) - - # Optionally, process only a subset of the input files. - if num_partitions > 0: - lb, ub = _get_partition_bounds( - len(image_files), num_samples_per_partition, num_partitions, partition_id - ) - image_files = image_files[lb:ub] - - gts = json.load(open(gt_path)) - answers = defaultdict(list) - for gt in gts["annotations"]: - answers[gt["image_id"]].append(gt['caption']) - - self._image_files = image_files - self._answers = answers - self._img_h = img_h - self._img_w = img_w - self._use_tiling = use_tiling - self._max_num_tiles = max_num_tiles - self._use_thumbnail = use_thumbnail - - def __len__(self): - return len(self._image_files) - - def __getitem__(self, idx): - img_file = self._image_files[idx] - image_id = int(img_file.split("_")[-1].split(".")[0]) - - img = Image.open(img_file) - imgs = get_visual_transform( - img, - self._img_h, - self._img_w, - self._use_tiling, - self._max_num_tiles, - self._use_thumbnail, - augment=False, - ) - - tile_count = torch.tensor([len(imgs)], dtype=torch.int) - - question = "" # Fixed for all samples. - metadata = "" # Not used. - - return torch.stack(imgs), tile_count, image_id, question, self._answers[image_id], metadata - - -class MMMUDataset(torch.utils.data.Dataset): - def __init__( - self, - input_image_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - single_image, - ): - # The following downloads the MMMU dataset from HuggingFace and uses the API from the MMMU github repo to run MMMU evaluation. - all_mmmu_datasets = [] - - hf_datasets_cache = os.environ["HF_DATASETS_CACHE"] - assert hf_datasets_cache != "", "Please set the environment variable HF_DATASETS_CACHE." - - for subject in CAT_SHORT2LONG.values(): - # Use a local copy of the dataset if exists (can be faster) or the HF one. 
- if os.path.exists(input_image_path): - subject_dataset = datasets.load_dataset( - os.path.join(input_image_path, subject), - split=datasets.Split.VALIDATION, - cache_dir=hf_datasets_cache, - verification_mode="no_checks", - ) - else: - subject_dataset = datasets.load_dataset( - "MMMU/MMMU", - subject, - split=datasets.Split.VALIDATION, - cache_dir=hf_datasets_cache, - ) - - all_mmmu_datasets.append(subject_dataset) - - dataset = datasets.concatenate_datasets(all_mmmu_datasets) - - dataset = [s for s in dataset if s['id'].startswith("val")] - - # Optionally, process only a subset of the input files. - if num_partitions > 0: - lb, ub = _get_partition_bounds( - len(dataset), num_samples_per_partition, num_partitions, partition_id - ) - dataset = dataset[lb:ub] - - # Using the LLaVA config from the MMMU repo. - config = load_yaml("examples/multimodal/MMMU/mmmu/configs/llava1.5.yaml") - for k, v in config.items(): - if isinstance(v, list): - assert len(v) == 1, "only one value supported." - config[k] = v[0] - - self._config = config - - self._dataset = dataset - - self._img_h = img_h - self._img_w = img_w - self._use_tiling = use_tiling - self._max_num_tiles = max_num_tiles - self._use_thumbnail = use_thumbnail - self._single_image = single_image - - def __len__(self): - return len(self._dataset) - - def __getitem__(self, idx): - sample = self._dataset[idx] - - # Use the single image approach from the MMMU repo. - if self._single_image: - sample = process_single_sample(sample) - sample = construct_prompt(sample, self._config) - - img = sample["image"] - sample_imgs = get_visual_transform( - img, - self._img_h, - self._img_w, - self._use_tiling, - self._max_num_tiles, - self._use_thumbnail, - augment=False, - ) - sample_num_tiles = [len(sample_imgs)] - else: - sample = construct_prompt(sample, self._config) - - sample_imgs = [] - sample_num_tiles = [] - - img_indices = re.findall(r"" - - img = sample[img_key] - assert img is not None, f"{img_str} is in prompt but not in sample images" - - # Note: Only replace the current image tag. - sample["final_input_prompt"] = sample["final_input_prompt"].replace( - img_str, "", 1 - ) - - imgs = get_visual_transform( - img, - self._img_h, - self._img_w, - self._use_tiling, - adjusted_max_num_tiles, - self._use_thumbnail, - augment=False, - ) # List of tiles. - - sample_imgs.extend(imgs) - sample_num_tiles.append(len(imgs)) - - # Sanity check. - for i in range(1, 8): - assert ( - f"" not in sample["final_input_prompt"] - ), "prompt contains unhandled image tags" - - # MMMU specific metadata. 
- metadata = {"question_type": sample["question_type"]} - if sample["question_type"] == "multiple-choice": - metadata["index2ans"] = sample["index2ans"] - metadata["all_choices"] = sample["all_choices"] - - prompt = sample['final_input_prompt'] - if self._single_image: - for i in range(8): - prompt = prompt.replace(f"", "") - prompt = f"\n{prompt}" - - tile_count = torch.tensor(sample_num_tiles, dtype=torch.int) - - return ( - torch.stack(sample_imgs), - tile_count, - sample["id"], - prompt, - sample["answer"], - metadata, - ) - - -class VideoMMMEDataset(torch.utils.data.Dataset): - def __init__( - self, - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - num_frames, - ): - ground_truth_original = json.load(open(gt_path)) - ground_truth = [] - for gt in ground_truth_original: - video_path = gt["url"] - video_path = video_path.replace("https://www.youtube.com/watch?v=", "") - video_path = video_path.replace("https://m.youtube.com/watch?v=", "") - video_path = os.path.join(input_image_path, video_path + ".mp4") - if not os.path.exists(video_path): - continue - gt["video_path"] = video_path - ground_truth.append(gt) - - ground_truth = sorted(ground_truth, key=lambda gt: gt["video_path"]) - print_rank_0(f"Found {len(ground_truth)} videos to process.") - - if num_partitions > 0: - start_idx, end_idx = _get_partition_bounds( - len(ground_truth), num_samples_per_partition, num_partitions, partition_id - ) - ground_truth = ground_truth[start_idx:end_idx] - - self._ground_truth = ground_truth - self._img_h = img_h - self._img_w = img_w - self._use_tiling = use_tiling - self._max_num_tiles = max_num_tiles - self._use_thumbnail = use_thumbnail - self._num_frames = num_frames - - def __len__(self): - return len(self._ground_truth) - - def __getitem__(self, idx): - gt = self._ground_truth[idx] - - video, _, _ = read_video(gt["video_path"], start_pts=0, end_pts=None, pts_unit='sec') - video = video.numpy() - selected_frames = torch.linspace(0, video.shape[0] - 1, self._num_frames).long() - video_frames = video[selected_frames] - if self._num_frames == 1: - video_frames = video_frames[None] - - imgs = list( - itertools.chain.from_iterable( - get_visual_transform( - img, - self._img_h, - self._img_w, - self._use_tiling, - self._max_num_tiles, - self._use_thumbnail, - augment=False, - ) - for img in video_frames - ) - ) - - for question in gt["questions"]: - # Very hacky, but we essentially re-create gt holding only the - # question of interest. This is the make this generation script - # compatible with the Video MME evaluation script. 
- question_dict = { - "video_id": gt["video_id"], - "duration_category": gt["duration_category"], - "video_category": gt["video_category"], - "video_subcategory": gt["video_subcategory"], - "url": gt["url"], - "questions": [question], - } - - num_tiles = torch.tensor([len(imgs)], dtype=torch.int) - - answer = "" - metadata = "" - - return ( - torch.stack(imgs), - num_tiles, - question["question_id"], - question_dict, - answer, - metadata, - ) - - def get_evaluation_dataloader( task, input_image_path, @@ -497,108 +87,20 @@ def get_evaluation_dataloader( num_workers, ): """Build evaluation dataset.""" - if task == "TextVQA": - keys = { - "image_id": "image_id", - "sample_id": "question_id", - "question": "question", - "answer": "answers", - } - - dataset = VQADataset( - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - keys, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - ) - elif task == "VQAv2": - keys = { - "image_id": "image", - "sample_id": "question_id", - "question": "question", - "answer": "answer", - } - - dataset = VQADataset( - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - keys, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - ) - elif task == "ChartQA": - keys = {"image_id": "imgname", "question": "query", "answer": "label"} - - dataset = VQADataset( - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - keys, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - ) - elif task == "captioning": - dataset = CaptioningDataset( - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - ) - elif task == 'MMMU': - # Note: single_image=True uses only one image like in the MMMU repo example. - # single_image=False uses all images in the sample. 
- dataset = MMMUDataset( - input_image_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - single_image=True, - ) - elif task == "VideoMME": - dataset = VideoMMMEDataset( - input_image_path, - gt_path, - num_samples_per_partition, - num_partitions, - partition_id, - img_h, - img_w, - use_tiling, - max_num_tiles, - use_thumbnail, - num_frames, - ) - else: - raise NotImplementedError(f"unsupported task {task}") + dataset = get_evaluation_dataset( + task, + input_image_path, + gt_path, + img_h, + img_w, + use_tiling, + max_num_tiles, + use_thumbnail, + num_samples_per_partition, + num_partitions, + partition_id, + num_frames, + ) dp_rank = parallel_state.get_data_parallel_rank() dp_world_size = parallel_state.get_data_parallel_world_size() @@ -635,7 +137,12 @@ def generate_samples(model, config: EvaluationConfig, print_output): ) num_img_embeddings_per_tile = get_num_image_embeddings( - args.img_h, args.img_w, args.patch_dim, args.vision_model_type, args.disable_vision_class_token, 1 + args.img_h, + args.img_w, + args.patch_dim, + args.vision_model_type, + args.disable_vision_class_token, + 1, ) for idx, (imgs, num_tiles, sample_id, question, answers, metadata) in enumerate(dataloader): @@ -670,13 +177,22 @@ def generate_samples(model, config: EvaluationConfig, print_output): output_name = "" if config.task == "captioning": output_name = "caption" - elif config.task in ("TextVQA", "VQAv2", "ChartQA"): + elif config.task in ( + "TextVQA", + "VQAv2", + "ChartQA", + "OCRBench", + "MathVista", + "AI2D", + ): output_name = "answer" elif config.task in ("MMMU"): output_name = "text" elif config.task == "VideoMME": output_name = "response" output = question + else: + raise NotImplementedError("no output name defined for", config.task) prompt, generated = get_prompt_and_generated( generation, args.tokenizer_prompt_format @@ -689,18 +205,25 @@ def generate_samples(model, config: EvaluationConfig, print_output): if config.task == "captioning": output["ground_truth"] = answers - elif config.task in ("TextVQA", "VQAv2"): - output["gt_answer"] = [ans for ans in answers] - elif config.task == "ChartQA": - output["gt_answer"] = [answers] + elif config.task in ( + "TextVQA", + "VQAv2", + "ChartQA", + "OCRBench", + "MathVista", + "AI2D", + ): + if isinstance(answers, str): + answers = [answers] + output["gt_answer"] = answers + + if len(metadata) > 0: + output.update(metadata) elif config.task == "MMMU": - prediction = generated - if metadata["question_type"] == "multiple-choice": - prediction = parse_multi_choice_response( - generated, metadata["all_choices"], metadata["index2ans"] - ) - - output["prediction"] = prediction + output["prediction"] = generated + output.update(metadata) + else: + raise NotImplementedError("no output processing defined for", config.task) if print_output: print(output) @@ -747,6 +270,7 @@ def get_evaluation_config(): def is_first_rank(): + """First tensor and pipeline parallel rank.""" return ( parallel_state.is_pipeline_first_stage(ignore_virtual=True) and parallel_state.get_tensor_model_parallel_rank() == 0 @@ -754,6 +278,7 @@ def is_first_rank(): def get_output_path(config, dp_rank): + """Generation output path.""" return ( f"{config.output_path}-{config.task}-dprank={dp_rank}-partition={config.partition_id}.jsonl" ) @@ -825,6 +350,7 @@ def __call__(self, tokens, position_ids, attention_mask): def get_conversation(task, question): + """Get a conversation for a given task and evaluation 
question."""
     conversation = []
 
     # In all cases, the tokenizer adds possible header tokens for the assistant.
@@ -844,6 +370,11 @@
                 "content": f"<image>\n{question}\nAnswer the question using a single word or phrase.",
             },
         ]
+    elif task in ("OCRBench", "MathVista", "AI2D"):
+        conversation = [
+            {"role": "system", "content": "Answer the questions."},
+            {"role": "user", "content": f"<image>\n{question}"},
+        ]
     elif task == "MMMU":
         conversation = [
             {"role": "system", "content": "Answer the questions."},
diff --git a/examples/multimodal/train.py b/examples/multimodal/train.py
index c3e8b13a30..9ebae0e68a 100644
--- a/examples/multimodal/train.py
+++ b/examples/multimodal/train.py
@@ -25,9 +25,6 @@ def get_batch(data_iterator):
     """Generate a batch"""
-
-    args = get_args()
-
     imgs = None
     tokens = None
     labels = None
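End to end, the new tasks follow the same two-step flow as the existing ones: run_text_generation.py writes per-rank/per-partition .jsonl files (named via get_output_path above), and the matching evaluate_*.py script merges and scores them. A hedged sketch of the scoring step (the paths are hypothetical and should point at the generation output used for --input-path):

    from evaluate_ai2d import ai2d_eval
    from evaluate_mathvista import mathvista_eval

    print(f"AI2D accuracy: {ai2d_eval('/results/ai2d-output'):.2f}%")             # hypothetical path
    print(f"MathVista accuracy: {mathvista_eval('/results/mathvista-output')}")   # hypothetical path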