From de18820cdf37341b25ec73701421d2289c336257 Mon Sep 17 00:00:00 2001
From: Tuomas Rintamaki
Date: Fri, 13 Dec 2024 02:46:54 -0800
Subject: [PATCH] ADLR/megatron-lm!2441 - Llava pp > 1 fix

---
 examples/multimodal/train.py                   | 4 ++--
 megatron/core/models/multimodal/llava_model.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/examples/multimodal/train.py b/examples/multimodal/train.py
index 5ff2121b3d..1dc68d1173 100644
--- a/examples/multimodal/train.py
+++ b/examples/multimodal/train.py
@@ -48,7 +48,7 @@ def get_batch(data_iterator):
     pp_size = get_pipeline_model_parallel_world_size()
     if not is_first_or_last_stage(pp_size, args.encoder_pipeline_model_parallel_size):
         # Note these are all set to None above.
-        return tokens, labels, loss_mask, attention_mask, position_ids, imgs, num_tiles
+        return tokens, labels, loss_mask, attention_mask, position_ids, imgs, num_tiles, packed_seq_params
 
     # Broadcast data.
     torch.cuda.nvtx.range_push("get_data")
@@ -66,7 +66,7 @@ def get_batch(data_iterator):
     cu_lengths = tensor_parallel.broadcast_data(["cu_lengths"], data, torch.int32)["cu_lengths"]
     max_lengths = tensor_parallel.broadcast_data(["max_lengths"], data, torch.int32)["max_lengths"]
 
-    # Dummy image, no image.
+    # No image input (text-only sample) if the dataloader produced a dummy image.
     if imgs.shape == torch.Size([1, 1]):
         # FIXME: text-only data can cause a hang if the vision model is on its own pipeline rank and --freeze-ViT is enabled.
         imgs = torch.tensor([], dtype=torch.float32, device=data_text.device)
diff --git a/megatron/core/models/multimodal/llava_model.py b/megatron/core/models/multimodal/llava_model.py
index dafe377456..9c8dcaf97c 100644
--- a/megatron/core/models/multimodal/llava_model.py
+++ b/megatron/core/models/multimodal/llava_model.py
@@ -828,7 +828,7 @@ def forward(
         ).contiguous()  # [b, text_seq_len, h_language]
 
         # Assume 1 tile per image if the number of tiles is not provided.
-        if num_image_tiles is None:
+        if num_image_tiles is None and images is not None:
             num_image_tiles = torch.ones(images.shape[0], dtype=torch.int, device=input_ids.device)
 
         combined_embeddings, new_labels, new_loss_mask = self._preprocess_data(
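
A minimal self-contained sketch of the two invariants these hunks restore, for
readers outside the codebase. All names below (get_batch_sketch,
default_num_image_tiles, is_middle_stage) are hypothetical illustrations, not
the actual Megatron-LM API; only PyTorch is assumed.

    import torch

    def get_batch_sketch(is_middle_stage):
        # Hypothetical stand-in for examples/multimodal/train.py:get_batch.
        # Every return path must yield the same 8-tuple, now including
        # packed_seq_params; otherwise a caller that unpacks the result on an
        # intermediate pipeline stage raises "ValueError: not enough values
        # to unpack" (the first hunk's fix).
        tokens = labels = loss_mask = attention_mask = position_ids = None
        imgs = num_tiles = packed_seq_params = None
        if is_middle_stage:
            return tokens, labels, loss_mask, attention_mask, position_ids, imgs, num_tiles, packed_seq_params
        # ... the real function broadcasts the batch here before returning ...
        return tokens, labels, loss_mask, attention_mask, position_ids, imgs, num_tiles, packed_seq_params

    def default_num_image_tiles(images, device):
        # Hypothetical stand-in for the guarded default in llava_model.forward:
        # assume 1 tile per image only when an image batch actually exists.
        # With pp > 1, images can be None on some ranks (the early return
        # above leaves it unset), so images.shape[0] would raise
        # AttributeError (the last hunk's fix).
        if images is None:
            return None
        return torch.ones(images.shape[0], dtype=torch.int, device=device)

    # Middle-stage unpacking now matches the 8-value signature.
    (tokens, labels, loss_mask, attention_mask,
     position_ids, imgs, num_tiles, packed_seq_params) = get_batch_sketch(True)

    # Text-only batch: no crash, no tiles. Image batch of 2: one tile each.
    assert default_num_image_tiles(None, torch.device("cpu")) is None
    assert default_num_image_tiles(torch.zeros(2, 3, 336, 336), torch.device("cpu")).tolist() == [1, 1]

Keeping every return path of get_batch at a fixed arity lets the
pipeline-parallel forward step unpack the batch identically on every stage,
and the images-is-not-None guard keeps text-only or intermediate-stage batches
from dereferencing a tensor that was never populated on that rank.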