From 086cd85cf37da83006bc9bcd04cfaa39f6f586ff Mon Sep 17 00:00:00 2001 From: Tuomas Rintamaki Date: Fri, 23 Aug 2024 11:26:25 -0700 Subject: [PATCH] ADLR/megatron-lm!1943 - Support multi-image multi-tile input in LLaVA --- .../core/models/multimodal/llava_model.py | 150 ++++++++------ tests/unit_tests/models/test_llava_model.py | 189 ++++++++++++------ 2 files changed, 216 insertions(+), 123 deletions(-) diff --git a/megatron/core/models/multimodal/llava_model.py b/megatron/core/models/multimodal/llava_model.py index f15418e4b6..f1ca4ba7b2 100644 --- a/megatron/core/models/multimodal/llava_model.py +++ b/megatron/core/models/multimodal/llava_model.py @@ -2,7 +2,7 @@ import logging from collections import namedtuple from functools import partial -from typing import List +from typing import List, Optional import torch @@ -204,6 +204,7 @@ def _preprocess_data( labels, use_inference_kv_cache, image_token_index, + num_image_tiles, ): """Preprocess input data before input to language model. @@ -217,7 +218,8 @@ def _preprocess_data( - final_labels = [1, -100, 2, 3, 4] - final_loss_mask = [1, 0, 0, 1, 1] - This function also handles the case where the input does not contain an image (text-only sample). + This function also handles the case where the input does not contain an image (text-only sample). It also handles the case where a single input + image is split into multiple tiles. If pipeline parallelism is not used, then self.pre_process and self.post_process are both True and we update both input embeddings, labels and loss masks (if available). @@ -244,9 +246,7 @@ def _preprocess_data( if use_inference_kv_cache: return language_embeddings, loss_mask, labels - img_seq_len = ( - self._img_seq_len - 1 - ) # Adjust by -1 to account for the removed image token index. + img_seq_len = self._img_seq_len batch_size, text_seq_len = input_ids.shape has_labels = labels is not None @@ -255,41 +255,60 @@ def _preprocess_data( labels.shape == loss_mask.shape ), f"mismatching labels shape {labels.shape} and loss mask shape {loss_mask.shape}" + # Create indices for new text and label positions. with torch.no_grad(): image_token_mask = input_ids == image_token_index - num_image_tokens = torch.sum(image_token_mask, dim=-1) + num_images_per_sample = torch.sum(image_token_mask, dim=-1) - max_seq_len = (num_image_tokens.max() * img_seq_len) + text_seq_len + # Number of tiles per sample. + num_image_tiles_batch = num_image_tiles.split(num_images_per_sample.tolist(), dim=0) + num_image_tiles_batch = torch.tensor( + [x.sum() for x in num_image_tiles_batch], device=input_ids.device + ) + + # Sequence length for each sample is the image sequence length multiplied by the number of tiles for that image, minus image token indices, + # plus text sequence length. + seq_lens = num_image_tiles_batch * img_seq_len - num_images_per_sample + text_seq_len + max_seq_len = seq_lens.max() batch_indices, non_image_indices = torch.where(input_ids != image_token_index) # New position ids for the text tokens, shifted by the image sequence length. # E.g. for input_ids = [-200, 1, 2, 3] and img_seq_len = 576, we get new_position_ids = [576, 577, 578, 579]. # text_position_ids are then [577, 578, 579]. + image_token_mask_lens = image_token_mask.int().clone() + # -1 is for the removed image token index. + image_token_mask_lens[image_token_mask] = num_image_tiles * img_seq_len - 1 # +1 is needed here for the cumulative sum. -1 is adjusting for zero-based indexing. 
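            # E.g. if the image in the example above has 2 tiles (num_image_tiles = [2]), then
            # image_token_mask_lens = [2 * 576 - 1, 0, 0, 0] = [1151, 0, 0, 0], so
            # new_position_ids = cumsum([1152, 1, 1, 1]) - 1 = [1151, 1152, 1153, 1154] and
            # text_position_ids = [1152, 1153, 1154], i.e. the image tiles occupy positions 0-1151.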
- new_position_ids = torch.cumsum((image_token_mask * img_seq_len + 1), dim=-1) - 1 + new_position_ids = torch.cumsum((image_token_mask_lens + 1), dim=-1) - 1 text_position_ids = new_position_ids[batch_indices, non_image_indices] - # Repeat the same for labels, which have the image token index shifted to left by one. - # An exception is an input sequence starting with an image token in which case - # the image token is not present in labels so we correct for it. + # Labels are shifted to left by one. So, shift text position ids and non-image indices to left by one. if has_labels: - edge = input_ids[:, 0] == image_token_index - label_image_token_mask = labels == image_token_index - label_batch_indices, label_non_image_indices = torch.where( - labels != image_token_index - ) + label_text_position_ids = text_position_ids - 1 + valid_label_text_position_ids = label_text_position_ids >= 0 + label_text_position_ids = label_text_position_ids[valid_label_text_position_ids] - new_label_position_ids = ( - torch.cumsum((label_image_token_mask * img_seq_len + 1), dim=-1) - 1 - ) - # If the input sequence starts with an image token, then that image token is not present in the labels - # and we need to shift the label position ids by the image sequence length. - new_label_position_ids[edge] += img_seq_len - label_text_position_ids = new_label_position_ids[ - label_batch_indices, label_non_image_indices - ] + label_batch_indices = batch_indices[valid_label_text_position_ids] - # Initialize output tensors. + label_non_image_indices = non_image_indices - 1 + valid_label_non_image_indices = label_non_image_indices >= 0 + label_non_image_indices = label_non_image_indices[valid_label_non_image_indices] + + # Create a mask for the image embedding positions. + images_mask = torch.full( + (batch_size, max_seq_len), True, dtype=torch.bool, device=input_ids.device + ) + # No images in the text positions. + images_mask[batch_indices, text_position_ids] = False + # Samples can have different amount of images tokens. new_position_ids[:, -1] gives the last text position id for each sample. + # Padding is needed when the number of image tokens differs. + first_padding_idx = new_position_ids[:, -1] + 1 + images_mask[ + torch.arange(max_seq_len, device=first_padding_idx.device).repeat(batch_size, 1) + >= first_padding_idx.unsqueeze(1) + ] = False + + # Create the final input embedding (if this is the first language model stage). final_embedding = None if self.pre_process: embed_dim = language_embeddings.shape[-1] @@ -301,6 +320,15 @@ def _preprocess_data( device=image_embeddings.device, ) + # Put text embeddings to the text positions in the result tensor. + final_embedding[batch_indices, text_position_ids] = language_embeddings[ + batch_indices, non_image_indices + ] + + # Put image embeddings to image positions. + final_embedding[images_mask] = image_embeddings.reshape(-1, embed_dim).contiguous() + + # Create the final labels and loss mask (if this is the last language model stage). final_labels, final_loss_mask = None, None if has_labels: final_labels = torch.full( @@ -310,46 +338,36 @@ def _preprocess_data( (batch_size, max_seq_len), 0, dtype=loss_mask.dtype, device=loss_mask.device ) - # Put text embeddings to the text positions in the result tensor. - if self.pre_process: - final_embedding[batch_indices, text_position_ids] = language_embeddings[ - batch_indices, non_image_indices - ] - - # Put text labels and loss mask to the text positions. - if has_labels: + # Put text labels and loss mask to the text positions. 
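            # Image tile and padding positions are not written below, so they keep the
            # IGNORE_INDEX label and zero loss mask they were initialized with above.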
final_labels[label_batch_indices, label_text_position_ids] = labels[ label_batch_indices, label_non_image_indices ] + final_loss_mask[batch_indices, text_position_ids] = loss_mask[ batch_indices, non_image_indices ] - with torch.no_grad(): - # Create a mask for the image embedding positions. - images_mask = torch.full( - (batch_size, max_seq_len), True, dtype=torch.bool, device=input_ids.device - ) - images_mask[batch_indices, text_position_ids] = ( - False # No images in the text positions. - ) - # Samples can have different amount of images tokens. new_position_ids[:, -1] gives the last text position id for each sample. - # Padding is needed when the number of image tokens differs. Compute the number of padding tokens on the right for each sample. - padding = max_seq_len - 1 - new_position_ids[:, -1] - # Mark the padding tokens on the right as False in the images mask. -1 adjusts cumulative sum to be zero-based. - images_mask &= images_mask.cumsum(dim=-1) - 1 >= padding[:, None] - - if self.pre_process: - final_embedding[images_mask] = image_embeddings.reshape(-1, embed_dim).contiguous() + # For labels, we need to pick the last label index that got dropped by the shift to left. + label_extra_text_position_ids = seq_lens - 1 + batch_range = torch.arange(len(label_extra_text_position_ids)) + final_labels[batch_range, label_extra_text_position_ids] = labels[batch_range, -1] - if has_labels: # Loss mask the image positions. final_loss_mask[images_mask] = 0 # Loss mask last text position just before an image so that text token does not need to predict the first image token. batch_image_indices, image_indices = torch.where(image_token_mask) - text_before_image_indices = torch.maximum(image_indices - 1, torch.tensor(0)) - final_loss_mask[batch_image_indices, text_before_image_indices] = 0 + # Indices just before image tokens. If it's -1, skip it. + before_image_indices = image_indices - 1 + valid = before_image_indices >= 0 + valid_batch_image_indices = batch_image_indices[valid] + valid_before_image_indices = before_image_indices[valid] + # Map those indices those position ids. + valid_before_image_indices = new_position_ids[ + valid_batch_image_indices, valid_before_image_indices + ] + + final_loss_mask[valid_batch_image_indices, valid_before_image_indices] = 0 if final_embedding is not None and has_labels: assert ( @@ -367,21 +385,23 @@ def forward( input_ids: torch.Tensor, position_ids: torch.Tensor, attention_mask: torch.Tensor, - labels: torch.Tensor = None, - loss_mask: torch.Tensor = None, - inference_params: InferenceParams = None, - image_token_index: int = IMAGE_TOKEN_INDEX, + labels: Optional[torch.Tensor] = None, + loss_mask: Optional[torch.Tensor] = None, + inference_params: Optional[InferenceParams] = None, + num_image_tiles: Optional[List[int]] = None, + image_token_index: Optional[int] = IMAGE_TOKEN_INDEX, ) -> torch.Tensor: """Forward function of the LLaVA model. Args: - images (torch.Tensor): input image of shape [batch, img_h, img_w]. + images (torch.Tensor): input image of shape [num_tiles, img_h, img_w]. num_tiles means the number of image tiles in this batch. input_ids (torch.Tensor): input text ids [batch, text_seq_len]. position_ids (torch.Tensor): input text position ids [batch, text_seq_len]. attention_mask (torch.Tensor): Attention mask for the language model [batch, 1, combined_seq_len, combined_seq_len]. labels (torch.Tensor): Optional target text labels [batch, combined_seq_len]. loss_mask (torch.Tensor): Text loss mask [batch, text_seq_len]. 
inference_params (InferenceParams): Inference-time parameters including KV cache. + num_image_tiles (list of int): Number of tiles per image. Default None assumes 1 tile per image. image_token_index (int): ID for input images. Returns: @@ -396,24 +416,25 @@ def forward( if use_inference_kv_cache: image_embeddings = None elif self.add_encoder: - image_embeddings = self.vision_model(images) # [b, img_seq_len, h_vision] + image_embeddings = self.vision_model(images) # [num_tiles, img_seq_len, h_vision] if self._drop_vision_class_token: image_embeddings = image_embeddings[:, self.vision_model.class_token_len :, :] # contiguous() call required as `permute` can sparsify the tensor and this breaks pipelining image_embeddings = image_embeddings.permute( 1, 0, 2 - ).contiguous() # [img_seq_len, b, h_vision] + ).contiguous() # [img_seq_len, num_tiles, h_vision] # map vision model output size to language model input size. image_embeddings = self.vision_projection( image_embeddings - ) # [img_seq_len, b, h_vision] + ) # [img_seq_len, num_tiles, h_language] + # TODO: Support batched inference. # If running inference, the language model KV cache will be updated for image token positions. # Here we store the image tokens sequence length, which can be used as an offset to the KV cache later. if inference_params is not None: inference_params.key_value_memory_dict["image_tokens_count"] = ( - image_embeddings.shape[0] + image_embeddings.shape[0] * image_embeddings.shape[1] ) else: image_embeddings = self.encoder_hidden_state @@ -434,6 +455,10 @@ def forward( 1, 0 ).contiguous() # [b, text_seq_len, h_language] + # Assume 1 tile per image if the number of tiles is not provided. + if num_image_tiles is None: + num_image_tiles = torch.ones(images.shape[0], dtype=torch.int, device=input_ids.device) + # Preprocess input, labels and loss mask. combined_embeddings, new_labels, new_loss_mask = self._preprocess_data( image_embeddings, @@ -443,6 +468,7 @@ def forward( labels, use_inference_kv_cache, image_token_index, + num_image_tiles, ) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len] output = self.language_model( diff --git a/tests/unit_tests/models/test_llava_model.py b/tests/unit_tests/models/test_llava_model.py index d503f6783b..cb035b864d 100644 --- a/tests/unit_tests/models/test_llava_model.py +++ b/tests/unit_tests/models/test_llava_model.py @@ -19,17 +19,17 @@ def setup_method(self, method): model_parallel_cuda_manual_seed(123) language_config = TransformerConfig( - num_layers=3, hidden_size=128, num_attention_heads=8, use_cpu_initialization=True + num_layers=3, hidden_size=128, num_attention_heads=8, use_cpu_initialization=False ) vision_config = TransformerConfig( - num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=True + num_layers=2, hidden_size=64, num_attention_heads=4, use_cpu_initialization=False ) vision_projection_config = TransformerConfig( num_layers=2, hidden_size=128, ffn_hidden_size=72, num_attention_heads=1, - use_cpu_initialization=True, + use_cpu_initialization=False, ) language_layer_spec = get_gpt_layer_with_transformer_engine_spec() @@ -74,27 +74,35 @@ def test_preprocess_data(self): self.model.cuda() image_embedding_value = torch.tensor(123.0) - image_embeddings = image_embedding_value * torch.ones((577, 3, 128)).cuda() + # 3 images with 1 tile and 2 image with 2 tiles = 7 tiles. 
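        # Simulated vision projection output, one embedding sequence per tile:
        # [img_seq_len, num_tiles, h_language] = [577, 7, 128].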
+ image_embeddings = image_embedding_value * torch.ones((577, 7, 128)).cuda() image_token_index = -200 - input_ids = torch.arange(0, 1024, dtype=torch.int).expand(4, 1024).cuda() + input_ids = torch.arange(0, 1024, dtype=torch.int).expand(5, 1024).cuda() input_ids[0, 0] = image_token_index # image before text input_ids[1, 100] = image_token_index # image in between input_ids[2, -1] = image_token_index # image at the end # input_ids[3] - no image + input_ids[4, 50] = image_token_index # two images in between + input_ids[4, 150] = image_token_index language_embedding_value = torch.tensor(999.0) - language_embeddings = language_embedding_value * torch.ones((4, 1024, 128)).cuda() + language_embeddings = language_embedding_value * torch.ones((5, 1024, 128)).cuda() # Labels are input_ids shifted to left by one. - labels = torch.arange(1, 1025, dtype=torch.int).expand(4, 1024).cuda() + labels = torch.arange(1, 1025, dtype=torch.int).expand(5, 1024).cuda() labels[1, 99] = image_token_index labels[2, -2] = image_token_index + labels[4, 49] = image_token_index + labels[4, 149] = image_token_index - loss_mask = torch.ones((4, 1024), dtype=torch.int).cuda() + loss_mask = torch.ones((5, 1024), dtype=torch.float).cuda() # Mask some text inputs (the text mask should carry over) - loss_mask[:2, :10] = 0 - loss_mask[:2, 110:120] = 0 + loss_mask[:2, :10] = 0.0 + loss_mask[:2, 110:120] = 0.0 + + # Number of tiles for each image in the batch. + num_image_tiles = torch.tensor([1, 2, 1, 2, 1], dtype=torch.int).cuda() use_inference_kv_cache = False @@ -106,134 +114,192 @@ def test_preprocess_data(self): labels, use_inference_kv_cache, image_token_index, + num_image_tiles, ) - assert embeddings.shape == torch.Size((1600, 4, 128)) - assert labels.shape == torch.Size((4, 1600)) + img_seq_len = 577 + # The fifth sample has 2 images with 3 tiles and 1024 text tokens. + max_seq_len = 3 * img_seq_len - 2 + 1024 + + assert embeddings.shape == torch.Size((max_seq_len, 5, 128)) + assert labels.shape == torch.Size((5, max_seq_len)) assert loss_mask.shape == labels.shape # First sample where image is before text (index 0). - expected_embeddings = torch.empty(1600).cuda() + expected_embeddings = torch.empty(max_seq_len).cuda() expected_embeddings[:577] = image_embedding_value - expected_embeddings[577:] = language_embedding_value + expected_embeddings[577:1600] = language_embedding_value + expected_embeddings[1600:] = 0 # padding - expected_labels = torch.empty(1600, dtype=torch.int).cuda() - expected_labels[:576] = -100 - expected_labels[576:] = torch.arange(1, 1025, dtype=torch.int) + expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda() + expected_labels[:576] = -100 # image + expected_labels[576:1600] = torch.arange(1, 1025, dtype=torch.int) + expected_labels[1600:] = -100 # padding - expected_loss_mask = torch.empty(1600, dtype=torch.int).cuda() + expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda() expected_loss_mask[:577] = 0 expected_loss_mask[577:586] = 0 expected_loss_mask[586:686] = 1 expected_loss_mask[686:696] = 0 - expected_loss_mask[696:] = 1 + expected_loss_mask[696:1600] = 1 + expected_loss_mask[1600:] = 0 assert torch.allclose(embeddings[:, 0], expected_embeddings.unsqueeze(1)) assert torch.allclose(labels[0], expected_labels) assert torch.allclose(loss_mask[0], expected_loss_mask) - # Second sample where image is in between (index 100). - expected_embeddings = torch.empty(1600).cuda() + # Second sample where image is in between (index 100). The image has 2 tiles. 
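        # With 2 tiles, the image token at index 100 expands to 2 * 577 = 1154 positions (100-1253);
        # the remaining 923 text tokens fill positions 1254-2176 and the rest is padding.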
+ expected_embeddings = torch.empty(max_seq_len).cuda() expected_embeddings[:100] = language_embedding_value - expected_embeddings[100:677] = image_embedding_value - expected_embeddings[677:] = language_embedding_value + expected_embeddings[100:1254] = image_embedding_value + expected_embeddings[1254:2177] = language_embedding_value + expected_embeddings[2177:] = 0 # padding - expected_labels = torch.empty(1600, dtype=torch.int).cuda() + expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda() expected_labels[:99] = torch.arange(1, 100) - expected_labels[99:676] = -100 - expected_labels[676:] = torch.arange(101, 1025) + expected_labels[99:1253] = -100 # image + expected_labels[1253:2177] = torch.arange(101, 1025) + expected_labels[2177:] = -100 # padding - expected_loss_mask = torch.empty(1600, dtype=torch.int).cuda() + expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda() expected_loss_mask[:10] = 0 expected_loss_mask[10:99] = 1 - expected_loss_mask[99] = ( - 0 # Last text position before the image is not required to predict the first image embedding. - ) - expected_loss_mask[100:677] = 0 - expected_loss_mask[677:686] = 1 - expected_loss_mask[686:696] = 0 - expected_loss_mask[696:] = 1 + # Last text position before the image is not required to predict the first image embedding. + expected_loss_mask[99] = 0 + expected_loss_mask[100:1254] = 0 + expected_loss_mask[1254:1263] = 1 + expected_loss_mask[1263:1273] = 0 + expected_loss_mask[1273:2177] = 1 + expected_loss_mask[2177:] = 0 # padding assert torch.allclose(embeddings[:, 1], expected_embeddings.unsqueeze(1)) assert torch.allclose(labels[1], expected_labels) assert torch.allclose(loss_mask[1], expected_loss_mask) # Third sample where image is at the end. - expected_embeddings = torch.empty(1600).cuda() + expected_embeddings = torch.empty(max_seq_len).cuda() expected_embeddings[:1023] = language_embedding_value - expected_embeddings[1023:] = image_embedding_value + expected_embeddings[1023:1600] = image_embedding_value + expected_embeddings[1600:] = 0 # padding - expected_labels = torch.empty(1600, dtype=torch.int).cuda() + expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda() expected_labels[:1022] = torch.arange(1, 1023) expected_labels[1022:1599] = -100 expected_labels[1599] = 1024 + expected_labels[1600:] = -100 # padding - expected_loss_mask = torch.empty(1600, dtype=torch.int).cuda() + expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda() expected_loss_mask[:1022] = 1 - expected_loss_mask[1022] = ( - 0 # Last text position before the image is not required to predict the first image embedding. - ) - expected_loss_mask[1023:] = 0 + # Last text position before the image is not required to predict the first image embedding. + expected_loss_mask[1022] = 0 + expected_loss_mask[1023:1600] = 0 + expected_loss_mask[1600:] = 0 # padding assert torch.allclose(embeddings[:, 2], expected_embeddings.unsqueeze(1)) assert torch.allclose(labels[2], expected_labels) assert torch.allclose(loss_mask[2], expected_loss_mask) # Fourth sample where there is no image. 
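        # Text-only sample: all 1024 positions are text embeddings, followed by zero padding up to max_seq_len.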
- expected_embeddings = torch.empty(1600).cuda() + expected_embeddings = torch.empty(max_seq_len).cuda() expected_embeddings[:1024] = language_embedding_value expected_embeddings[1024:] = 0 # padding - expected_labels = torch.empty(1600, dtype=torch.int).cuda() + expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda() expected_labels[:1024] = torch.arange(1, 1025) - expected_labels[1024:] = -100 + expected_labels[1024:] = -100 # padding - expected_loss_mask = torch.empty(1600, dtype=torch.int).cuda() + expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda() expected_loss_mask[:1024] = 1 - expected_loss_mask[1024:] = 0 + expected_loss_mask[1024:] = 0 # padding assert torch.allclose(embeddings[:, 3], expected_embeddings.unsqueeze(1)) assert torch.allclose(labels[3], expected_labels) assert torch.allclose(loss_mask[3], expected_loss_mask) + # Fifth sample has two images in between. The first image has two tiles. + expected_embeddings = torch.empty(max_seq_len).cuda() + expected_embeddings[:50] = language_embedding_value + expected_embeddings[50:1204] = image_embedding_value # two tiles + expected_embeddings[1204:1303] = language_embedding_value + expected_embeddings[1303:1880] = image_embedding_value + expected_embeddings[1880:] = language_embedding_value + + expected_labels = torch.empty(max_seq_len, dtype=torch.int).cuda() + expected_labels[:49] = torch.arange(1, 50) + expected_labels[49:1203] = -100 # image + expected_labels[1203:1302] = torch.arange(51, 150) + expected_labels[1302:1879] = -100 # image + expected_labels[1879:] = torch.arange(151, 1025) + + expected_loss_mask = torch.empty(max_seq_len, dtype=torch.float).cuda() + expected_loss_mask[:49] = 1 + expected_loss_mask[49:1204] = 0 + expected_loss_mask[1204:1302] = 1 + expected_loss_mask[1302:1880] = 0 + expected_loss_mask[1880:] = 1 + + assert torch.allclose(embeddings[:, 4], expected_embeddings.unsqueeze(1)) + assert torch.allclose(labels[4], expected_labels) + assert torch.allclose(loss_mask[4], expected_loss_mask) + @pytest.mark.internal def test_forward(self): self.model.cuda() - img = torch.randn((3, 3, 336, 336)).cuda() + # 3 images with 1 tile and 2 images with 2 tiles. + img = torch.randn((7, 3, 336, 336)).cuda() image_token_index = -200 - input_ids = torch.randint(0, 2048, (4, 1024)).cuda() + input_ids = torch.randint(0, 2048, (5, 1024)).cuda() input_ids[0, 0] = image_token_index # image before text input_ids[1, 100] = image_token_index # image in between input_ids[2, -1] = image_token_index # image at the end # input_ids[3] - no image + input_ids[4, 50] = image_token_index + input_ids[4, 150] = image_token_index - position_ids = torch.arange(0, 1024, dtype=torch.int).expand(4, 1024).cuda() + position_ids = torch.arange(0, 1024, dtype=torch.int).expand(5, 1024).cuda() - loss_mask = torch.ones((4, 1024)).cuda() + loss_mask = torch.ones((5, 1024)).cuda() attention_mask = None # Causal. - labels = torch.randint(0, 2048, (4, 1024)).cuda() + labels = torch.randint(0, 2048, (5, 1024)).cuda() labels[1, 99] = image_token_index labels[2, -2] = image_token_index + num_image_tiles = torch.tensor([1, 2, 1, 2, 1], dtype=torch.int).cuda() + # Try with labels. loss, new_loss_mask = self.model.forward( - img, input_ids, position_ids, attention_mask, labels, loss_mask + img, + input_ids, + position_ids, + attention_mask, + labels, + loss_mask, + num_image_tiles=num_image_tiles, ) - # The final sequence length 1600 comes from 577 image tokens and 1023 text tokens. 
- assert loss.shape == new_loss_mask.shape == torch.Size((4, 1600)) + + # The maximum sequence length is given by the sample with 2 images in 3 tiles, minus two image token indices, plus other text tokens. + img_seq_len = 577 + max_seq_len = img_seq_len * 3 - 2 + 1024 + assert loss.shape == new_loss_mask.shape == torch.Size((5, max_seq_len)) # Try without labels and without inference params. logits = self.model.forward( - img, input_ids, position_ids, attention_mask, labels=None, loss_mask=None + img, + input_ids, + position_ids, + attention_mask, + labels=None, + loss_mask=None, + num_image_tiles=num_image_tiles, ) - assert logits.shape == torch.Size((4, 1600, 2048)) + assert logits.shape == torch.Size((5, max_seq_len, 2048)) # Try without labels and with inference params. - inference_params = InferenceParams(4, 1600) + inference_params = InferenceParams(5, max_seq_len) logits = self.model.forward( img, input_ids, @@ -241,18 +307,19 @@ def test_forward(self): attention_mask, labels=None, loss_mask=None, + num_image_tiles=num_image_tiles, inference_params=inference_params, ) - assert logits.shape == torch.Size((4, 1600, 2048)) + assert logits.shape == torch.Size((5, max_seq_len, 2048)) # Check KV cache got populated correctly. kv_dict = inference_params.key_value_memory_dict - assert kv_dict["image_tokens_count"] == 577 + assert kv_dict["image_tokens_count"] == 577 * 7 for layer_no in range(1, 4): # 3 layers in the model. layer_kv = kv_dict[layer_no] # Expected shape is [sequence_len, batch_size, num_heads, hidden_size_per_head] - assert layer_kv[0].shape == layer_kv[1].shape == torch.Size((1600, 4, 8, 16)) + assert layer_kv[0].shape == layer_kv[1].shape == torch.Size((max_seq_len, 5, 8, 16)) @pytest.mark.internal def test_save_load(self, tmp_path):