
Merge branch 'trintamaki/multi-image-multi-tile' into 'main'
Support multi-image multi-tile input in LLaVA

See merge request ADLR/megatron-lm!1943
jaredcasper committed Aug 23, 2024
2 parents 4ff9e66 + 086cd85 commit e32b60b
Showing 2 changed files with 216 additions and 123 deletions.
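For orientation, a minimal usage sketch of the new interface follows; it is an assumption-laden illustration, not code from this commit. The model instance, tensor shapes, and the variable names model, input_ids, position_ids, attention_mask, labels, and loss_mask are placeholders, and num_image_tiles is passed as a tensor to match how the default is constructed inside forward, despite the List[int] annotation.

# Hypothetical call with two images split into 1 and 3 tiles (4 tiles total).
images = torch.rand(4, 3, 336, 336)                       # all tiles stacked along dim 0 (shape illustrative)
num_image_tiles = torch.tensor([1, 3], dtype=torch.int)   # tiles per image, in batch order
output = model(
    images=images,
    input_ids=input_ids,            # contains one image token (IMAGE_TOKEN_INDEX) per image
    position_ids=position_ids,
    attention_mask=attention_mask,
    labels=labels,
    loss_mask=loss_mask,
    num_image_tiles=num_image_tiles,
)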
150 changes: 88 additions & 62 deletions megatron/core/models/multimodal/llava_model.py
@@ -2,7 +2,7 @@
import logging
from collections import namedtuple
from functools import partial
from typing import List
from typing import List, Optional

import torch

@@ -204,6 +204,7 @@ def _preprocess_data(
labels,
use_inference_kv_cache,
image_token_index,
num_image_tiles,
):
"""Preprocess input data before input to language model.
@@ -217,7 +218,8 @@ def _preprocess_data(
- final_labels = [1, -100, 2, 3, 4]
- final_loss_mask = [1, 0, 0, 1, 1]
This function also handles the case where the input does not contain an image (text-only sample).
This function also handles the case where the input does not contain an image (text-only sample). It also handles the case where a single input
image is split into multiple tiles.
If pipeline parallelism is not used, then self.pre_process and self.post_process are both True and we update both
input embeddings, labels and loss masks (if available).
@@ -244,9 +246,7 @@ def _preprocess_data(
if use_inference_kv_cache:
return language_embeddings, loss_mask, labels

img_seq_len = (
self._img_seq_len - 1
) # Adjust by -1 to account for the removed image token index.
img_seq_len = self._img_seq_len
batch_size, text_seq_len = input_ids.shape

has_labels = labels is not None
@@ -255,41 +255,60 @@ def _preprocess_data(
labels.shape == loss_mask.shape
), f"mismatching labels shape {labels.shape} and loss mask shape {loss_mask.shape}"

# Create indices for new text and label positions.
with torch.no_grad():
image_token_mask = input_ids == image_token_index
num_image_tokens = torch.sum(image_token_mask, dim=-1)
num_images_per_sample = torch.sum(image_token_mask, dim=-1)

max_seq_len = (num_image_tokens.max() * img_seq_len) + text_seq_len
# Number of tiles per sample.
num_image_tiles_batch = num_image_tiles.split(num_images_per_sample.tolist(), dim=0)
num_image_tiles_batch = torch.tensor(
[x.sum() for x in num_image_tiles_batch], device=input_ids.device
)

# Sequence length for each sample is the image sequence length multiplied by the number of tiles for that image,
# minus the number of image token indices, plus the text sequence length.
seq_lens = num_image_tiles_batch * img_seq_len - num_images_per_sample + text_seq_len
max_seq_len = seq_lens.max()
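# Illustrative example with assumed values (not part of this change): img_seq_len = 3, text_seq_len = 5,
# sample 0 holds 2 images, sample 1 holds 1 image, and num_image_tiles = [1, 2, 4] in batch order.
# Then num_images_per_sample = [2, 1], num_image_tiles_batch = [1 + 2, 4] = [3, 4],
# seq_lens = [3 * 3 - 2 + 5, 4 * 3 - 1 + 5] = [12, 16], and max_seq_len = 16.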
batch_indices, non_image_indices = torch.where(input_ids != image_token_index)

# New position ids for the text tokens, shifted by the image sequence length.
# E.g. for input_ids = [-200, 1, 2, 3] and img_seq_len = 576, we get new_position_ids = [576, 577, 578, 579].
# text_position_ids are then [577, 578, 579].
image_token_mask_lens = image_token_mask.int().clone()
# -1 is for the removed image token index.
image_token_mask_lens[image_token_mask] = num_image_tiles * img_seq_len - 1
# +1 is needed here for the cumulative sum. -1 is adjusting for zero-based indexing.
new_position_ids = torch.cumsum((image_token_mask * img_seq_len + 1), dim=-1) - 1
new_position_ids = torch.cumsum((image_token_mask_lens + 1), dim=-1) - 1
text_position_ids = new_position_ids[batch_indices, non_image_indices]
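# Illustrative example with assumed values: input_ids = [-200, 1, 2, 3], img_seq_len = 3, one image with 2 tiles.
# image_token_mask_lens = [2 * 3 - 1, 0, 0, 0] = [5, 0, 0, 0],
# new_position_ids = cumsum([6, 1, 1, 1]) - 1 = [5, 6, 7, 8], text_position_ids = [6, 7, 8],
# so the 6 tile embeddings occupy positions 0-5 and the text tokens follow at positions 6-8.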

# Repeat the same for labels, which have the image token index shifted to left by one.
# An exception is an input sequence starting with an image token in which case
# the image token is not present in labels so we correct for it.
# Labels are shifted to left by one. So, shift text position ids and non-image indices to left by one.
if has_labels:
edge = input_ids[:, 0] == image_token_index
label_image_token_mask = labels == image_token_index
label_batch_indices, label_non_image_indices = torch.where(
labels != image_token_index
)
label_text_position_ids = text_position_ids - 1
valid_label_text_position_ids = label_text_position_ids >= 0
label_text_position_ids = label_text_position_ids[valid_label_text_position_ids]

new_label_position_ids = (
torch.cumsum((label_image_token_mask * img_seq_len + 1), dim=-1) - 1
)
# If the input sequence starts with an image token, then that image token is not present in the labels
# and we need to shift the label position ids by the image sequence length.
new_label_position_ids[edge] += img_seq_len
label_text_position_ids = new_label_position_ids[
label_batch_indices, label_non_image_indices
]
label_batch_indices = batch_indices[valid_label_text_position_ids]

# Initialize output tensors.
label_non_image_indices = non_image_indices - 1
valid_label_non_image_indices = label_non_image_indices >= 0
label_non_image_indices = label_non_image_indices[valid_label_non_image_indices]
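# Illustrative continuation of the example above: text_position_ids = [6, 7, 8] gives
# label_text_position_ids = [5, 6, 7] (all valid) and non_image_indices = [1, 2, 3] gives
# label_non_image_indices = [0, 1, 2]. A text token sitting at final position 0 would yield
# a label position of -1 and be dropped by the validity masks.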

# Create a mask for the image embedding positions.
images_mask = torch.full(
(batch_size, max_seq_len), True, dtype=torch.bool, device=input_ids.device
)
# No images in the text positions.
images_mask[batch_indices, text_position_ids] = False
# Samples can have different amounts of image tokens. new_position_ids[:, -1] gives the last text position id for each sample.
# Padding is needed when the number of image tokens differs.
first_padding_idx = new_position_ids[:, -1] + 1
images_mask[
torch.arange(max_seq_len, device=first_padding_idx.device).repeat(batch_size, 1)
>= first_padding_idx.unsqueeze(1)
] = False
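# Illustrative continuation (assumed values): for the single-sample example above,
# seq_len = 2 * 3 - 1 + 4 = 9, so images_mask starts as 9 True entries, the text positions 6-8
# become False, and first_padding_idx = 8 + 1 = 9 leaves no padding to mask. The remaining True
# positions 0-5 hold exactly 2 tiles * 3 image embeddings. In a batch where another sample is
# longer, positions >= first_padding_idx are set to False so padded slots receive no image embeddings.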

# Create the final input embedding (if this is the first language model stage).
final_embedding = None
if self.pre_process:
embed_dim = language_embeddings.shape[-1]
@@ -301,6 +320,15 @@ def _preprocess_data(
device=image_embeddings.device,
)

# Put text embeddings to the text positions in the result tensor.
final_embedding[batch_indices, text_position_ids] = language_embeddings[
batch_indices, non_image_indices
]

# Put image embeddings to image positions.
final_embedding[images_mask] = image_embeddings.reshape(-1, embed_dim).contiguous()

# Create the final labels and loss mask (if this is the last language model stage).
final_labels, final_loss_mask = None, None
if has_labels:
final_labels = torch.full(
@@ -310,46 +338,36 @@ def _preprocess_data(
(batch_size, max_seq_len), 0, dtype=loss_mask.dtype, device=loss_mask.device
)

# Put text embeddings to the text positions in the result tensor.
if self.pre_process:
final_embedding[batch_indices, text_position_ids] = language_embeddings[
batch_indices, non_image_indices
]

# Put text labels and loss mask to the text positions.
if has_labels:
# Put text labels and loss mask to the text positions.
final_labels[label_batch_indices, label_text_position_ids] = labels[
label_batch_indices, label_non_image_indices
]

final_loss_mask[batch_indices, text_position_ids] = loss_mask[
batch_indices, non_image_indices
]

with torch.no_grad():
# Create a mask for the image embedding positions.
images_mask = torch.full(
(batch_size, max_seq_len), True, dtype=torch.bool, device=input_ids.device
)
images_mask[batch_indices, text_position_ids] = (
False # No images in the text positions.
)
# Samples can have different amounts of image tokens. new_position_ids[:, -1] gives the last text position id for each sample.
# Padding is needed when the number of image tokens differs. Compute the number of padding tokens on the right for each sample.
padding = max_seq_len - 1 - new_position_ids[:, -1]
# Mark the padding tokens on the right as False in the images mask. -1 adjusts cumulative sum to be zero-based.
images_mask &= images_mask.cumsum(dim=-1) - 1 >= padding[:, None]

if self.pre_process:
final_embedding[images_mask] = image_embeddings.reshape(-1, embed_dim).contiguous()
# For labels, we need to pick the last label index that got dropped by the shift to left.
label_extra_text_position_ids = seq_lens - 1
batch_range = torch.arange(len(label_extra_text_position_ids))
final_labels[batch_range, label_extra_text_position_ids] = labels[batch_range, -1]
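# Illustrative example: with seq_lens = [12, 16] as in the earlier two-sample illustration,
# final_labels[0, 11] and final_labels[1, 15] receive labels[0, -1] and labels[1, -1], the targets
# that the left shift would otherwise drop.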

if has_labels:
# Loss mask the image positions.
final_loss_mask[images_mask] = 0

# Loss mask last text position just before an image so that text token does not need to predict the first image token.
batch_image_indices, image_indices = torch.where(image_token_mask)
text_before_image_indices = torch.maximum(image_indices - 1, torch.tensor(0))
final_loss_mask[batch_image_indices, text_before_image_indices] = 0
# Indices just before image tokens. If it's -1, skip it.
before_image_indices = image_indices - 1
valid = before_image_indices >= 0
valid_batch_image_indices = batch_image_indices[valid]
valid_before_image_indices = before_image_indices[valid]
# Map those indices to the corresponding position ids.
valid_before_image_indices = new_position_ids[
valid_batch_image_indices, valid_before_image_indices
]

final_loss_mask[valid_batch_image_indices, valid_before_image_indices] = 0
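# Illustrative example (same 2-tile, img_seq_len = 3 setup as above): for input_ids = [1, -200, 2, 3],
# the image token sits at index 1, so before_image_indices = [0], which new_position_ids maps to
# final position 0; final_loss_mask[0, 0] = 0, so the text token at position 0 is not asked to
# predict the first image embedding. An image token at index 0 would give -1 and be skipped.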

if final_embedding is not None and has_labels:
assert (
@@ -367,21 +385,23 @@ def forward(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: torch.Tensor = None,
loss_mask: torch.Tensor = None,
inference_params: InferenceParams = None,
image_token_index: int = IMAGE_TOKEN_INDEX,
labels: Optional[torch.Tensor] = None,
loss_mask: Optional[torch.Tensor] = None,
inference_params: Optional[InferenceParams] = None,
num_image_tiles: Optional[List[int]] = None,
image_token_index: Optional[int] = IMAGE_TOKEN_INDEX,
) -> torch.Tensor:
"""Forward function of the LLaVA model.
Args:
images (torch.Tensor): input image of shape [batch, img_h, img_w].
images (torch.Tensor): input image of shape [num_tiles, img_h, img_w]. num_tiles means the number of image tiles in this batch.
input_ids (torch.Tensor): input text ids [batch, text_seq_len].
position_ids (torch.Tensor): input text position ids [batch, text_seq_len].
attention_mask (torch.Tensor): Attention mask for the language model [batch, 1, combined_seq_len, combined_seq_len].
labels (torch.Tensor): Optional target text labels [batch, combined_seq_len].
loss_mask (torch.Tensor): Text loss mask [batch, text_seq_len].
inference_params (InferenceParams): Inference-time parameters including KV cache.
num_image_tiles (list of int): Number of tiles per image. Default None assumes 1 tile per image.
image_token_index (int): ID for input images.
Returns:
@@ -396,24 +416,25 @@ def forward(
if use_inference_kv_cache:
image_embeddings = None
elif self.add_encoder:
image_embeddings = self.vision_model(images) # [b, img_seq_len, h_vision]
image_embeddings = self.vision_model(images) # [num_tiles, img_seq_len, h_vision]
if self._drop_vision_class_token:
image_embeddings = image_embeddings[:, self.vision_model.class_token_len :, :]
# contiguous() call required as `permute` can sparsify the tensor and this breaks pipelining
image_embeddings = image_embeddings.permute(
1, 0, 2
).contiguous() # [img_seq_len, b, h_vision]
).contiguous() # [img_seq_len, num_tiles, h_vision]

# map vision model output size to language model input size.
image_embeddings = self.vision_projection(
image_embeddings
) # [img_seq_len, b, h_vision]
) # [img_seq_len, num_tiles, h_language]

# TODO: Support batched inference.
# If running inference, the language model KV cache will be updated for image token positions.
# Here we store the image tokens sequence length, which can be used as an offset to the KV cache later.
if inference_params is not None:
inference_params.key_value_memory_dict["image_tokens_count"] = (
image_embeddings.shape[0]
image_embeddings.shape[0] * image_embeddings.shape[1]
)
else:
image_embeddings = self.encoder_hidden_state
@@ -434,6 +455,10 @@ def forward(
1, 0
).contiguous() # [b, text_seq_len, h_language]

# Assume 1 tile per image if the number of tiles is not provided.
if num_image_tiles is None:
num_image_tiles = torch.ones(images.shape[0], dtype=torch.int, device=input_ids.device)
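# Illustrative example (assumed values): if images has shape [3, ...] and num_image_tiles is None,
# the default is [1, 1, 1], i.e. three images with one tile each.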

# Preprocess input, labels and loss mask.
combined_embeddings, new_labels, new_loss_mask = self._preprocess_data(
image_embeddings,
Expand All @@ -443,6 +468,7 @@ def forward(
labels,
use_inference_kv_cache,
image_token_index,
num_image_tiles,
) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len]

output = self.language_model(
