Merge branch 'trintamaki/multi-image-mmmu' into 'main'
MMMU multi-image support

See merge request ADLR/megatron-lm!1973
jon-barker committed Sep 19, 2024
2 parents 835af44 + 2c9bcac commit 905de33
Showing 1 changed file with 81 additions and 38 deletions.
119 changes: 81 additions & 38 deletions examples/multimodal/run_text_generation.py
@@ -5,6 +5,7 @@
import json
import logging
import os
import re
import sys
from collections import defaultdict
from functools import partial
@@ -257,23 +258,69 @@ def get_evaluation_dataset(

for idx in range(start_idx, end_idx):
sample = dataset[idx]
sample = process_single_sample(sample)
sample = construct_prompt(sample, config)

img = sample["image"]
imgs = get_visual_transform(
img, img_h, img_w, use_tiling, max_num_tiles, use_thumbnail, augment=False
)
single_image = True
# Use the single image approach from the MMMU repo.
if single_image:
sample = process_single_sample(sample)
sample = construct_prompt(sample, config)

images.append(imgs)
tile_counts.append(torch.tensor([len(imgs)], dtype=torch.int))
img = sample["image"]
sample_imgs = get_visual_transform(
img, img_h, img_w, use_tiling, max_num_tiles, use_thumbnail, augment=False
)
sample_num_tiles = [len(sample_imgs)]
else:
sample = construct_prompt(sample, config)

sample_imgs = []
sample_num_tiles = []

img_indices = re.findall(r"<image (\d+)", sample["final_input_prompt"])
# If there are multiple input images, we need to avoid the number of image embeddings getting too large.
adjusted_max_num_tiles = max(1, max_num_tiles // len(img_indices))

for img_idx in img_indices:
img_key = f"image_{img_idx}"
img_str = f"<image {img_idx}>"

img = sample[img_key]
assert img is not None, f"{img_str} is in prompt but not in sample images"

# Note: Only replace the current image tag.
sample["final_input_prompt"] = sample["final_input_prompt"].replace(
img_str, "<image>", 1
)

imgs = get_visual_transform(
img,
img_h,
img_w,
use_tiling,
adjusted_max_num_tiles,
use_thumbnail,
augment=False,
) # List of tiles.

sample_imgs.extend(imgs)
sample_num_tiles.append(len(imgs))

# Sanity check.
for i in range(1, 8):
assert (
f"<image {i}>" not in sample["final_input_prompt"]
), "prompt contains unhandled image tags"

images.append(sample_imgs)
tile_counts.append(torch.tensor(sample_num_tiles, dtype=torch.int))

sample_ids.append(sample['id'])

# TODO: Support multiple input images and the original image position. Note: <image> is added back in the prompt construction below.
prompt = sample['final_input_prompt']
for i in range(8):
prompt = prompt.replace(f"<image {i}>", "")
if single_image:
for i in range(8):
prompt = prompt.replace(f"<image {i}>", "")
prompt = f"<image>\n{prompt}"
questions.append(prompt)

answers.append(sample['answer'])
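
For context, the multi-image branch above can be read as the following standalone sketch. Here get_visual_transform is replaced by a stub, max_num_tiles=12 is an assumed example value, and the helper name is illustrative; final_input_prompt and the image_<N> keys follow the diff.

import re

def collect_prompt_images(final_input_prompt, sample, max_num_tiles=12):
    # Find every "<image N>" tag referenced by the prompt.
    img_indices = re.findall(r"<image (\d+)", final_input_prompt)
    # Cap tiles per image so the total number of image embeddings stays bounded.
    adjusted_max_num_tiles = max(1, max_num_tiles // len(img_indices))

    sample_imgs, sample_num_tiles = [], []
    for img_idx in img_indices:
        img = sample[f"image_{img_idx}"]
        assert img is not None, f"<image {img_idx}> is in prompt but not in sample images"
        # Replace only the first occurrence of this tag with the generic <image> token.
        final_input_prompt = final_input_prompt.replace(f"<image {img_idx}>", "<image>", 1)
        tiles = [img]  # stand-in for get_visual_transform(img, ..., adjusted_max_num_tiles, ...)
        sample_imgs.extend(tiles)
        sample_num_tiles.append(len(tiles))

    return final_input_prompt, sample_imgs, sample_num_tiles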
@@ -359,9 +406,6 @@ def generate_samples(model, config: EvaluationConfig):
args.num_frames,
)

num_image_embeddings_per_tile = get_num_image_embeddings(
args.img_h, args.img_w, args.patch_dim, args.disable_vision_class_token, 1
)
num_img_embeddings_per_tile = get_num_image_embeddings(
args.img_h, args.img_w, args.patch_dim, args.disable_vision_class_token, 1
)
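
The duplicated call above is removed and the surviving variable is named num_img_embeddings_per_tile. As a rough, assumed illustration only of what the per-tile count represents (the real get_num_image_embeddings is defined elsewhere in Megatron-LM and its exact formula may differ):

# Assumed approximation: a ViT-style tile of img_h x img_w with patch_dim-sized
# patches yields one embedding per patch, plus an optional class token.
def approx_num_img_embeddings_per_tile(img_h, img_w, patch_dim, disable_vision_class_token):
    num_patches = (img_h // patch_dim) * (img_w // patch_dim)
    return num_patches if disable_vision_class_token else num_patches + 1

# e.g. 336x336 tiles with 14x14 patches -> 576 patch embeddings (+1 class token).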
@@ -404,7 +448,7 @@ def generate_samples(model, config: EvaluationConfig):
output_name = "response"
output = questions[idx]

generated = get_generated(prompt, config.prompt_format, generation)
generated = get_generated(generation, args.prompt_format)
if config.task == "VideoMME":
output["questions"][0][output_name] = generated
else:
@@ -513,11 +557,11 @@ def __call__(self, tokens, position_ids, attention_mask):

# On the first inference iteration, we compute image tokens.
# Update the sequence length offset by the number of image tokens.
num_images = (tokens == -200).sum().item()
num_image_tokens = (tokens == -200).sum().item()
num_tokens = tokens.size(1)
if num_tokens > 1 and num_images > 0:
if num_tokens > 1 and num_image_tokens > 0:
self.inference_params.sequence_len_offset += (
self.inference_params.key_value_memory_dict["image_tokens_count"] - num_images
self.inference_params.key_value_memory_dict["image_tokens_count"] - num_image_tokens
)

return logits
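
To make the offset arithmetic above concrete, here is a small worked example with assumed numbers (577 embeddings per tile matches the 336/14 illustration earlier; image_tokens_count is the value cached in key_value_memory_dict):

num_image_tokens = 2            # tokens equal to the image token index (-200) in the prompt
num_tokens = 50                 # raw prompt length in tokens
image_tokens_count = 2 * 577    # total image embeddings actually inserted for those two images

# After the first iteration the KV cache holds the expanded sequence, so the
# offset grows by the extra length contributed by the images.
sequence_len_offset = num_tokens + (image_tokens_count - num_image_tokens)
print(sequence_len_offset)  # 1202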
@@ -529,7 +573,9 @@ def get_prompt(task, questions, idx, prompt_format):
if prompt_format == "llama3":
prompt = "<|start_header_id|>system<|end_header_id|>\n\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n<image>\nProvide a one-sentence caption for provided image.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
elif prompt_format == "mistral":
prompt = "<image>Give a short and clear explanation of the subsequent image.\n"
prompt = (
"[INST] <image>Give a short and clear explanation of the subsequent image. [/INST]"
)
elif task == "TextVQA":
question = questions[idx]

@@ -538,7 +584,7 @@ def get_prompt(task, questions, idx, prompt_format):
question
)
elif prompt_format == "mistral":
prompt = "<image>\n{}\nAnswer the question using a single word or phrase.".format(
prompt = "[INST] <image>\n{}\nAnswer the question using a single word or phrase. [/INST]".format(
question
)
elif task == "VQAv2":
@@ -549,7 +595,7 @@ def get_prompt(task, questions, idx, prompt_format):
question
)
elif prompt_format == "mistral":
prompt = "<image>\n{}\nAnswer the question using a single word or phrase.".format(
prompt = "[INST] <image>\n{}\nAnswer the question using a single word or phrase. [/INST]".format(
question
)
elif task == "ChartQA":
@@ -560,19 +606,17 @@ def get_prompt(task, questions, idx, prompt_format):
questions
)
elif prompt_format == "mistral":
prompt = "<image>\n{}\nAnswer the question using a single word or phrase.".format(
prompt = "[INST] <image>\n{}\nAnswer the question using a single word or phrase. [/INST]".format(
question
)
elif task == "MMMU":
question = questions[idx]

if prompt_format == "llama3":
prompt = "<|start_header_id|>system<|end_header_id|>\n\nAnswer the questions.<|eot_id|>{}<|start_header_id|>user<|end_header_id|>\n\n<image>\n{}\nAnswer the question using a single word or phrase.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
prompt = prompt.format("", question)
prompt = "<|start_header_id|>system<|end_header_id|>\n\nAnswer the questions.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
prompt = prompt.format(question)
elif prompt_format == "mistral":
prompt = "<image>\n{}\nAnswer the question using a single word or phrase.".format(
question
)
prompt = "[INST] {} [/INST]".format(question)
elif task == "VideoMME":
question = (
"Select the best answer to the following multiple-choice "
@@ -594,19 +638,17 @@ def get_prompt(task, questions, idx, prompt_format):
return prompt
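
The Mistral branches above now wrap each prompt in [INST] ... [/INST] markers. A minimal sketch of that wrapping for the single-word-answer tasks (the helper name is illustrative, not part of the diff):

def wrap_mistral_single_word(question: str) -> str:
    # Matches the instruction format used for TextVQA / VQAv2 / ChartQA above.
    return "[INST] <image>\n{}\nAnswer the question using a single word or phrase. [/INST]".format(
        question
    )

print(wrap_mistral_single_word("What color is the bus?"))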


def get_generated(prompt, prompt_format, prompt_and_generation):
def get_generated(prompt_and_generation, prompt_format):
"""Strip prompt and other unnecessary text from generation."""
start = len(prompt.replace("<image>", ""))
if prompt_format == "llama3":
start += len("<|begin_of_text|>")
start += 1
generated = prompt_and_generation.split(
"<|start_header_id|>assistant<|end_header_id|>\n\n"
)[-1]
generated = generated.split("<|eot_id|>")[0]
elif prompt_format == "mistral":
start += len("<s><unk><s> ")
generated = prompt_and_generation.split("[/INST]")[-1]
generated = generated.split("</s>")[0]

generated = prompt_and_generation[start:]
generated = generated.replace("<s> ", "")
generated = generated.split("<|eot_id|>")[0]
generated = generated.split("</s>")[0]
generated = generated.strip()
generated = generated.split("\n\n")[0]
generated = generated.split("\n")[0]
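
With the reworked get_generated above, the response is recovered by splitting on the prompt format's assistant marker instead of slicing by prompt length. An example call (the model output string below is made up for illustration):

full_output = "<s><unk><s> [INST] <image>\nWhat animal is shown? [/INST] A golden retriever.</s>"
answer = get_generated(full_output, "mistral")
print(answer)  # "A golden retriever."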
@@ -621,15 +663,16 @@ def _decorate_tokenize(f):
# When tokenizing, replace <image> with the image token index (-200)
def wrapper(prompt):
tokens = tokenizer_image_token(args, prompt, f)

return tokens

return wrapper

def _decorate_detokenize(f):
# When detokenizing, replace image token index (-200) with a dummy value.
# When detokenizing, skip image token index.
def wrapper(tokens):
tokens = np.array(tokens)
tokens[tokens == IMAGE_TOKEN_INDEX] = 0
tokens = tokens[tokens != IMAGE_TOKEN_INDEX]
tokens = tokens.tolist()

return f(tokens)
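
The detokenize wrapper now drops image token indices instead of mapping them to a dummy id. A small illustration of that filtering step in isolation:

import numpy as np

IMAGE_TOKEN_INDEX = -200  # value used for <image> placeholders in the token stream
tokens = np.array([1, 15, IMAGE_TOKEN_INDEX, 42, 7])
tokens = tokens[tokens != IMAGE_TOKEN_INDEX]
print(tokens.tolist())  # [1, 15, 42, 7]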
