Merge branch 'multi_image_dataloader' into 'main'
Multi image dataloader

See merge request ADLR/megatron-lm!2438
ericharper committed Dec 19, 2024
2 parents 7449d66 + 7e99c5b commit aff6e38
Showing 2 changed files with 138 additions and 66 deletions.
203 changes: 137 additions & 66 deletions examples/multimodal/dataset_helpers.py
@@ -2,16 +2,19 @@
import bisect
import dataclasses
import json
import re
import sys
import traceback
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union

from image_processing import get_visual_transform
from PIL import Image
from torchvision.transforms import ToPILImage
import numpy as np
import torch

from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, IMAGE_TOKEN
from megatron.core.models.multimodal.llava_model import IGNORE_INDEX, IMAGE_TOKEN, VIDEO_TOKEN
from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
from megatron.energon import (
Batch,
Expand Down Expand Up @@ -175,6 +178,10 @@ def __init__(

self.img_h, self.img_w = self.args.img_h, self.args.img_w

# This map is used to reduce the number of tiles used per image if the number of tokens is
# larger than the decoder_seq_length.
self.num_tiles_degradation_map = {12:8, 8:6, 6:4, 4:2, 2:1, 1:1}
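# For example, if the image tokens overflow the budget at 12 tiles per image, the sample is
# retried with at most 8 tiles per image, then 6, 4, 2, and finally 1 (see the loop in
# encode_llava_sft below).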

def _get_total_seq_length(self, input_ids, num_tiles):
"""Calculate expected sequence length given text tokens length and number of tiles."""
total_num_images = len(num_tiles)
Expand Down Expand Up @@ -237,7 +244,7 @@ def encode_captioning(self, sample: CaptioningSample):

prompt_idx = np.random.randint(len(prompt_list))
cur_prompt = prompt_list[prompt_idx]
cur_prompt = "<image>\n" + cur_prompt + "\n"
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt + "\n"
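# i.e. the sampled prompt is prefixed with the image token: "<image>\n<prompt>\n".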

caption = sample.caption.strip()

Expand Down Expand Up @@ -282,7 +289,7 @@ def encode_llava_pretrain(self, sample: VQASample):
# LLAVA training: override text-prompt with just the image.
conv = [
# Note: no system message.
{"role": "user", "content": "<image>\n"},
{"role": "user", "content": IMAGE_TOKEN + "\n"},
{"role": "assistant", "content": sample.answers},
]

@@ -307,66 +314,130 @@ def encode_llava_sft(self, sample: SimilarityInterleavedSample):
"""Encode SFT sample."""
augment = sample.__subflavors__['augmentation'] if 'augmentation' in sample.__subflavors__ else False
has_video = sample.__subflavors__['has_video'] if 'has_video' in sample.__subflavors__ else False
has_image = sample.__subflavors__['has_image'] if 'has_image' in sample.__subflavors__ else False
has_image = has_image or (hasattr(sample, "images") and len(sample.images) > 0)

if has_video:
# Grab the selected frames of the video as a tensor with shape
# fhwc: (num_frames, height, width, num_channels).
video_fhwc = sample.images[0].permute(0, 2, 3, 1)
selected_frames = torch.linspace(
0, video_fhwc.shape[0] - 1, self.args.num_frames).long()
video_frame_fhwc = video_fhwc[selected_frames]
imgs = []
for video_frame_hwc in video_frame_fhwc:
imgs += get_visual_transform(
video_frame_hwc, self.img_h, self.img_w,
self.args.use_tiling, self.args.max_num_tiles,
self.args.use_thumbnail, augment, self.args.vision_model_type)
num_tiles = [len(imgs)]
elif has_image:
imgs = get_visual_transform(
sample.images[0], self.img_h, self.img_w, self.args.use_tiling, self.args.max_num_tiles, self.args.use_thumbnail, augment,
self.args.vision_model_type,
)
num_tiles = [len(imgs)]
else:
imgs = num_tiles = []
sample.__key__ = "{}-{}".format("no-image", sample.__key__)
has_image = False
if hasattr(sample, "images"):
# If this is a text-only sample and we are freezing the LM,
# then use a dummy input image.
if len(sample.images) == 0 and self.args.freeze_LM:
empty_img = Image.new('RGB', (self.args.img_w, self.args.img_h), (255, 255, 255))
sample.images.append(empty_img)
if len(sample.images) > 0 and not has_video:
has_image = True

conversation = []
# Note: Some tokenizers may ignore the system prompt.
conversation.append({"role": "system", "content": "Answer the questions."})

has_image_token = False

conversation = [{"role": "system", "content": "Answer the questions."}]
# Format the conversation as a list of "user" / "assistant" turns.
for text in sample.texts:
if IMAGE_TOKEN in text["value"]:
has_image_token = True

if text["from"] == "human":
role = "user"
elif text["from"] == "gpt":
role = "assistant"
else:
raise RuntimeError(f"unexpected role {text['from']} in {sample.texts}")

turn = {"role": role, "content": text["value"]}
conversation.append(turn)

# If the sample contains an image but none of the user messages has an image token,
# then add it to the first user message.
if len(imgs) > 0 and not has_image_token:
error_msg = f"unexpected role {text['from']} in {sample.texts}"
assert text["from"] in ["human", "gpt"], error_msg
conversation.append({
"role": "user" if text["from"] == "human" else "assistant",
"content": text["value"]})

# Replace the image tags <image-idx> with IMAGE_TOKEN and count the number of image tags
number_image_tags = 0
image_tag_ids_list = []
for turn in conversation:
if turn["role"] == "user":
image_tag_ids = [int(x) - 1 for x in re.findall(r"<image-(\d+)>", turn["content"])]
image_tag_ids_list.extend(image_tag_ids)
turn["content"] = re.sub(r"<image-\d+>", IMAGE_TOKEN, turn["content"])
number_image_tags += turn["content"].count(IMAGE_TOKEN)
# For videos, we replace the image tag with the video tag
if has_video:
turn["content"] = turn["content"].replace(IMAGE_TOKEN, VIDEO_TOKEN)

# We re-order the images in sample.images according to how they appear in the conversation.
if len(image_tag_ids_list) > 0:
sample.images = [sample.images[idx] for idx in image_tag_ids_list]

# If there is only one image, but several image tags, we assume all the tags refer to the
# same image and duplicate the image:
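# (E.g. one image referenced by three <image> tags is repeated three times so that tags
# and images stay aligned.)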
if len(sample.images) == 1 and number_image_tags > 1:
sample.images = sample.images * number_image_tags

number_of_images = len(sample.images)
# Fail if there are more image or video tags than images or videos:
error_msg = (
f"Found {number_image_tags} image tags for {number_of_images} images. {sample.texts}")
assert number_image_tags <= number_of_images, error_msg

# If there are fewer image or video tags than images or videos, prepend the tags to the first
# user message:
if number_image_tags < number_of_images:
for turn in conversation:
if turn["role"] == "user":
turn["content"] = f"{IMAGE_TOKEN}\n" + turn["content"]
tag_to_add = VIDEO_TOKEN if has_video else IMAGE_TOKEN
turn["content"] = tag_to_add*(number_of_images-number_image_tags) + "\n" + turn["content"]
break

input_ids, target = self.tokenizer.tokenize_conversation(conversation, True, False)

if has_image:
imgs = []
num_tiles = []
max_num_tiles = self.args.max_num_tiles
# We keep a buffer of 4 tokens for the question;
# the rest can be used for image tokens.
max_image_token_allowed = self.args.decoder_seq_length - len(input_ids) - 4
# We start by extracting as many tiles per image as possible, and decrease the max
# number of tiles if there are too many image tokens.
while True:
imgs = []
num_tiles = []
for img in sample.images:
img_tiles = get_visual_transform(
img, self.img_h, self.img_w, self.args.use_tiling, max_num_tiles,
self.args.use_thumbnail, augment, self.args.vision_model_type)
imgs += img_tiles
num_tiles += [len(img_tiles)]
if max_num_tiles == 1:
break
if sum(num_tiles) * self.token_per_img_tile > max_image_token_allowed:
if max_num_tiles in self.num_tiles_degradation_map:
max_num_tiles = self.num_tiles_degradation_map[max_num_tiles]
else:
raise RuntimeError(
f"Tried to decrease the number of tiles {max_num_tiles} but it's not "
f"defined in the degradation map {self.num_tiles_degradation_map}.")
else:
break
elif has_video:
# We don't use tiling for videos to limit the number of tokens.
use_tiling = False
# Grab the selected frames of the video as a tensor with shape
# fchw: (num_frames, num_channels, height, width).
video_fchw = sample.images[0].permute(0, 1, 2, 3)
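# Pick self.args.num_frames frame indices evenly spaced across the clip.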
selected_frames = torch.linspace(
0, video_fchw.shape[0] - 1, self.args.num_frames).long()
video_fchw = video_fchw[selected_frames]
imgs = []
for video_chw in video_fchw:
to_pil = ToPILImage()
video_chw = to_pil(video_chw)
imgs += get_visual_transform(
video_chw, self.img_h, self.img_w, use_tiling, self.args.max_num_tiles,
self.args.use_thumbnail, augment, self.args.vision_model_type)
num_tiles = [len(imgs)]
else:
imgs = num_tiles = []

if self.is_packing_enabled:
input_ids, target = self._truncate_for_packing(input_ids, target, num_tiles)

# Some final checks with respect to the number of image tokens and images in the tokenized
# conversation. There can still be errors, for instance if a non-video sample happens to
# have our pre-defined video token, or if the packing truncation removed a necessary image
# tag.
number_image_token = np.sum(input_ids == self.img_token_id)
error_msg = (
f"Found {number_image_token} image tokens for len({num_tiles}) = {len(num_tiles)} image tiles in {conversation}.")
assert number_image_token == len(num_tiles), error_msg
error_msg = (
f"Found sum({num_tiles}) = {np.sum(num_tiles)} tiles for {len(imgs)} images in {conversation}.")
assert np.sum(num_tiles) == len(imgs), error_msg

return ImageTaskSample(
__key__=sample.__key__,
__restore_key__=sample.__restore_key__,
Expand Down Expand Up @@ -407,8 +478,8 @@ def encode_any_single_turn_vqa(self, sample):

if isinstance(sample, MultiChoiceVQASample):
cur_prompt = format_multichoice_question(sample.context, sample.choices)
if "<image>" not in cur_prompt:
cur_prompt = "<image>\n" + cur_prompt
if IMAGE_TOKEN not in cur_prompt:
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
cur_answer = format_multichoice_answer(sample.correct_choice_idx)
elif isinstance(sample, VQASample):
if 'docvqa' in sample.__key__:
@@ -423,8 +494,8 @@ def encode_any_single_turn_vqa(self, sample):

cur_prompt = cur_prompt.format(sample.context)

if "<image>" not in cur_prompt:
cur_prompt = "<image>\n" + cur_prompt
if IMAGE_TOKEN not in cur_prompt:
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt

if isinstance(sample.answers, list):
answer_list = sample.answers
Expand Down Expand Up @@ -505,11 +576,11 @@ def encode_pdf_prompt(self, sample: OCRSample) -> ImageTaskSample:
prompt_list = self.manual_prompts["DocPretraining"]["raw"]
prompt_idx = np.random.randint(len(prompt_list))
cur_prompt = prompt_list[prompt_idx]
if "<image>" not in cur_prompt:
cur_prompt = "<image>\n" + cur_prompt
if IMAGE_TOKEN not in cur_prompt:
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt

# Make sure there is no extra <image> tag.
sample.text = sample.text.replace("<image>", "")
# Make sure there is no extra IMAGE_TOKEN tag.
sample.text = sample.text.replace(IMAGE_TOKEN, "")

caption = sample.text.strip()

@@ -526,8 +597,8 @@ def encode_ocr_ref_prompt(self, sample: OCRSample) -> ImageTaskSample:
ref = sample.text
region = sample.words_boxes

# Make sure there is no extra <image> tag
ref = ref.replace("<image>", "")
# Make sure there is no extra IMAGE_TOKEN tag
ref = ref.replace(IMAGE_TOKEN, "")

if len(region) == 4:
region = f"<box>({region[0]},{region[1]}),({region[2]},{region[3]})</box>"
@@ -550,17 +621,17 @@ def encode_ocr_ref_prompt(self, sample: OCRSample) -> ImageTaskSample:
prompt_idx = np.random.randint(len(prompt_list))
cur_prompt = prompt_list[prompt_idx]
cur_prompt = cur_prompt.format(prompt_content)
if "<image>" not in cur_prompt:
cur_prompt = "<image>\n" + cur_prompt
if IMAGE_TOKEN not in cur_prompt:
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt

return sample, cur_prompt, answer

def bbox_coord_to_label(self, text, bbox):
"""Format bbox coordinates as text."""
assert len(bbox) == 4 or len(bbox) == 8

# Make sure there is no extra <image> tag
text = text.replace("<image>", "")
# Make sure there is no extra IMAGE_TOKEN tag
text = text.replace(IMAGE_TOKEN, "")

if len(bbox) == 4:
label_str = f"<ref>{text}</ref><box>({bbox[0]},{bbox[1]}),({bbox[2]},{bbox[3]})</box>"
@@ -582,8 +653,8 @@ def encode_ocr_prompt(self, sample: OCRSample) -> ImageTaskSample:
prompt_idx = np.random.randint(len(prompt_list))
cur_prompt = prompt_list[prompt_idx]

if "<image>" not in cur_prompt:
cur_prompt = "<image>\n" + cur_prompt
if IMAGE_TOKEN not in cur_prompt:
cur_prompt = IMAGE_TOKEN + "\n" + cur_prompt
cur_answer = answer

return sample, cur_prompt, cur_answer
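
As a condensed illustration of the new multi-image handling in encode_llava_sft above, the <image-N> tag normalization and image re-ordering boil down to the following sketch (normalize_image_tags is a hypothetical helper, not part of this commit):

import re

IMAGE_TOKEN = "<image>"  # mirrors megatron.core.models.multimodal.llava_model.IMAGE_TOKEN

def normalize_image_tags(content, images):
    """Replace numbered <image-N> tags with IMAGE_TOKEN and re-order images to match."""
    # Tags are 1-based, so convert them to 0-based indices into `images`.
    tag_ids = [int(x) - 1 for x in re.findall(r"<image-(\d+)>", content)]
    content = re.sub(r"<image-\d+>", IMAGE_TOKEN, content)
    if tag_ids:
        images = [images[idx] for idx in tag_ids]
    return content, images

# "<image-2> left, <image-1> right" with images [a, b] becomes
# "<image> left, <image> right" with images [b, a].

The committed version applies this per user turn, duplicates a single image when it is referenced by several tags, and swaps IMAGE_TOKEN for VIDEO_TOKEN when the sample is a video.
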
1 change: 1 addition & 0 deletions megatron/core/models/multimodal/llava_model.py
@@ -36,6 +36,7 @@
# Image token index can be tokenizer dependent so the default value does not work in all cases.
DEFAULT_IMAGE_TOKEN_INDEX = -200
IMAGE_TOKEN = "<image>"
VIDEO_TOKEN = "<video>"


# Note: This is under development and may be missing features.
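
For context, here is a minimal sketch of how the num_tiles_degradation_map added in dataset_helpers.py is meant to bound image tokens against the decoder sequence budget. pick_max_num_tiles is a hypothetical helper that uses a worst-case token count per image; the committed loop instead re-runs get_visual_transform and counts the actual tiles produced.

NUM_TILES_DEGRADATION_MAP = {12: 8, 8: 6, 6: 4, 4: 2, 2: 1, 1: 1}

def pick_max_num_tiles(num_images, tokens_per_tile, token_budget, start=12):
    """Lower the per-image tile cap until the worst-case image token count fits the budget."""
    max_num_tiles = start
    while max_num_tiles > 1:  # with 1 tile per image there is nothing left to reduce
        if num_images * max_num_tiles * tokens_per_tile <= token_budget:
            break
        if max_num_tiles not in NUM_TILES_DEGRADATION_MAP:
            raise RuntimeError(f"{max_num_tiles} is not in {NUM_TILES_DEGRADATION_MAP}")
        max_num_tiles = NUM_TILES_DEGRADATION_MAP[max_num_tiles]
    return max_num_tiles

# E.g. 4 images, 256 tokens per tile and a 4096-token budget:
# 12 tiles -> 12288 tokens, 8 -> 8192, 6 -> 6144, 4 -> 4096, which fits, so 4 is returned.
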
