Commit
Merge branch 'pmannan/llava_debug' into 'main'
LLaVA Multimodal SP support

See merge request ADLR/megatron-lm!2038
ericharper committed Oct 19, 2024
2 parents db6cb4e + 2c950a5 commit 739177e
Showing 4 changed files with 154 additions and 24 deletions.
108 changes: 91 additions & 17 deletions megatron/core/models/multimodal/llava_model.py
@@ -6,14 +6,17 @@

import torch

from megatron.core import InferenceParams
from megatron.core import InferenceParams, tensor_parallel
from megatron.core.config_logger import has_config_logger_enabled, log_config_to_disk
from megatron.core.extensions.transformer_engine import TEDotProductAttention
from megatron.core.models.gpt import GPTModel
from megatron.core.models.vision.clip_vit_model import CLIPViTModel, get_num_image_embeddings
from megatron.core.models.vision.multimodal_projector import MultimodalProjector
from megatron.core.parallel_state import get_tensor_model_parallel_world_size
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import is_te_min_version

IMAGE_TOKEN_INDEX = -200 # ID for images in the input sequence.
IGNORE_INDEX = -100 # ID for labels that should be ignored.
@@ -98,6 +101,14 @@ def __init__(
self.vision_projection = None
self.language_model = None

self.sequence_parallel_lm = language_transformer_config.sequence_parallel
if self.sequence_parallel_lm:
assert (
language_transformer_layer_spec.submodules.self_attention.submodules.core_attention
== TEDotProductAttention
), "Sequence Parallelism is supported only with Transformer Engine DotProductAttention."
self.tp_comm_overlap_lm = language_transformer_config.tp_comm_overlap

# This attribute is needed to check if an all-reduce is required
# on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`.
self.share_embeddings_and_output_weights = False
@@ -232,6 +243,7 @@ def _preprocess_data(
use_inference_kv_cache,
image_token_index,
num_image_tiles,
attention_mask,
):
"""Preprocess input data before input to language model.
@@ -273,11 +285,11 @@
# No pre- or postprocessing needed.
# With pipeline parallel > 2, this means a chunk in the middle of the model.
if not self.pre_process and not self.post_process:
return language_embeddings, loss_mask, labels
return language_embeddings, loss_mask, labels, attention_mask

# If using the inference KV cache, the image tokens are already computed.
if use_inference_kv_cache:
return language_embeddings, loss_mask, labels
return language_embeddings, loss_mask, labels, attention_mask

img_seq_len = self._img_seq_len
batch_size, text_seq_len = input_ids.shape
@@ -311,6 +323,20 @@
):
max_seq_len = self._language_max_sequence_length

if self.sequence_parallel_lm:
if self.tp_comm_overlap_lm:
# If shorter: Pad to language_max_sequence_length to use TP Comm overlap.
# If longer: Gets truncated later.
if max_seq_len < self._language_max_sequence_length:
padded_seq_len = self._language_max_sequence_length
else:
# Pad to multiple of tp size for sequence parallelism
tp_world_size = get_tensor_model_parallel_world_size()
padded_seq_len = int(
(max_seq_len + (tp_world_size - 1)) // tp_world_size * tp_world_size
)
sp_padding_needed = padded_seq_len - max_seq_len
max_seq_len = padded_seq_len
batch_indices, non_image_indices = torch.where(input_ids != image_token_index)

# New position ids for the text tokens, shifted by the image sequence length.
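
The branch above rounds the combined image+text length up to the next multiple of the tensor-parallel world size (or pads all the way to language_max_sequence_length when TP comm overlap is enabled). A standalone sketch of that rounding, not from the commit, with illustrative numbers:

```python
# Sketch (not from the commit): ceil-to-multiple padding used for sequence parallelism.
def pad_to_multiple(seq_len: int, tp_world_size: int) -> int:
    """Round seq_len up to the nearest multiple of tp_world_size."""
    return (seq_len + tp_world_size - 1) // tp_world_size * tp_world_size

# e.g. a combined image+text sequence of 4673 tokens with a TP world size of 8
padded_seq_len = pad_to_multiple(4673, 8)      # 4680
sp_padding_needed = padded_seq_len - 4673      # 7
assert padded_seq_len % 8 == 0 and sp_padding_needed == 7
```
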
@@ -420,23 +446,44 @@ def _preprocess_data(
final_embedding.shape[:2] == final_labels.shape == final_loss_mask.shape
), "unexpected shapes after data preprocessing"

if final_embedding is not None:
final_embedding = final_embedding.transpose(1, 0).contiguous()

# Truncate if exceeding the language model's max sequence length.
truncate_embedding = (
final_embedding is not None
and final_embedding.shape[0] > self._language_max_sequence_length
)
if truncate_embedding:
final_embedding = final_embedding[: self._language_max_sequence_length]

truncate_labels = has_labels and final_labels.shape[1] > self._language_max_sequence_length
if truncate_labels:
final_labels = final_labels[:, : self._language_max_sequence_length]
final_loss_mask = final_loss_mask[:, : self._language_max_sequence_length]

return final_embedding, final_labels, final_loss_mask
if final_embedding is not None:
final_embedding = final_embedding.transpose(1, 0).contiguous()
# Truncate if exceeding the language model's max sequence length.
if final_embedding.shape[0] > self._language_max_sequence_length:
final_embedding = final_embedding[: self._language_max_sequence_length]
if self.sequence_parallel_lm:
# Create an attention mask. This ensures correct computation.
# This is done even when no padding was done as we set mask_type to
# 'padding' or 'padding_causal' when using SP.
if attention_mask is None:
# Create base attention mask with original seq len to indicate valid tokens
attention_mask = (
torch.ones(
(
final_embedding.shape[1],
final_embedding.shape[0] - sp_padding_needed,
),
device=final_embedding.device,
)
.unsqueeze(1)
.unsqueeze(1)
) # [b, 1, 1, final seq len - sp_padding_needed]
if sp_padding_needed > 0:
# Add the padding portion of the mask
attention_mask = torch.nn.functional.pad(attention_mask, (0, sp_padding_needed))
if is_te_min_version("1.7.0"):
# Attention mask True/False meaning flipped in 1.7.0
attention_mask = attention_mask < 0.5
final_embedding = tensor_parallel.scatter_to_sequence_parallel_region(
final_embedding
)

return final_embedding, final_labels, final_loss_mask, attention_mask
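
A rough, self-contained sketch of the padding mask built above when sequence parallelism is on: ones over real tokens in shape [b, 1, 1, seq], zeros appended over the SP padding, and the boolean sense flipped for Transformer Engine >= 1.7.0, where True means masked. The tensors below are illustrative stand-ins, not the commit's exact code:

```python
# Sketch (not from the commit): building the SP padding attention mask.
import torch

batch_size, real_len, sp_padding_needed = 2, 4673, 7

# Ones over valid tokens -> [b, 1, 1, real_len]
mask = torch.ones(batch_size, real_len).unsqueeze(1).unsqueeze(1)
# Zero-pad the last dim up to the SP-padded length -> [b, 1, 1, real_len + pad]
mask = torch.nn.functional.pad(mask, (0, sp_padding_needed))

# Transformer Engine >= 1.7.0 flips the convention: True = masked out.
te_mask = mask < 0.5
assert te_mask.shape == (batch_size, 1, 1, real_len + sp_padding_needed)
assert bool(te_mask[..., -sp_padding_needed:].all())   # padding positions are masked
assert not bool(te_mask[..., :real_len].any())         # real tokens are attended to
```
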

def forward(
self,
@@ -460,7 +507,7 @@ def forward(
input_ids (torch.Tensor): input text ids [batch, text_seq_len].
position_ids (torch.Tensor): input text position ids [batch, text_seq_len].
attention_mask (torch.Tensor): Language model attention mask
[batch, 1, combined_seq_len, combined_seq_len].
[batch, 1, 1, combined_seq_len].
labels (torch.Tensor): Optional target text labels [batch, combined_seq_len].
loss_mask (torch.Tensor): Text loss mask [batch, text_seq_len].
inference_params (InferenceParams): Inference-time parameters including KV cache.
@@ -523,9 +570,35 @@ def forward(
# Note: This adds absolute position embedding but not RoPE.
# Each image is counted as one position.
# RoPE is added in language_model forward. Each image embedding is one position.
if self.sequence_parallel_lm:
# Pad to nearest multiple of TP world size for embedding.
tp_world_size = get_tensor_model_parallel_world_size()
padded_seq_len = (
int(
(input_ids_text.shape[1] + tp_world_size - 1)
// tp_world_size
* tp_world_size
)
- input_ids_text.shape[1]
)
if padded_seq_len != 0:
input_ids_text = torch.nn.functional.pad(input_ids_text, (0, padded_seq_len))
if position_ids is not None:
position_ids = torch.nn.functional.pad(position_ids, (0, padded_seq_len))
language_embeddings = self.language_model.embedding(
input_ids=input_ids_text, position_ids=position_ids
) # [text_seq_len, b, h_language]
if self.sequence_parallel_lm:
# Gather the language embeddings back.
# We use the full embedding to insert image embeddings
# and then scatter to avoid load imbalance.
language_embeddings = tensor_parallel.gather_from_sequence_parallel_region(
language_embeddings, tensor_parallel_output_grad=False
)
# Remove the padding done for SP as we'll need new padding calculation
# after image embeddings are inserted.
if padded_seq_len != 0:
language_embeddings = language_embeddings[:-padded_seq_len]
language_embeddings = language_embeddings.transpose(
1, 0
).contiguous() # [b, text_seq_len, h_language]
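
A schematic of the round-trip in the forward path above: pad the text ids to a multiple of the TP size, embed, gather the full sequence back so image embeddings can be spliced in, then drop the temporary padding. The embedding and the gather are stubbed with a placeholder tensor because they need an initialized tensor-parallel group; shapes are made up:

```python
# Sketch (not from the commit): pad -> embed -> gather -> unpad, single process.
import torch

tp_world_size = 4
input_ids_text = torch.randint(0, 32000, (2, 57))             # [b, text_seq_len]

text_len = input_ids_text.shape[1]
padded_seq_len = (text_len + tp_world_size - 1) // tp_world_size * tp_world_size - text_len  # 3
if padded_seq_len != 0:
    input_ids_text = torch.nn.functional.pad(input_ids_text, (0, padded_seq_len))

# language_embeddings = language_model.embedding(input_ids_text, position_ids)
#   -> each TP rank holds [padded_len / tp, b, h] when sequence parallel is on
# language_embeddings = tensor_parallel.gather_from_sequence_parallel_region(
#     language_embeddings, tensor_parallel_output_grad=False)  # full [padded_len, b, h]
language_embeddings = torch.randn(input_ids_text.shape[1], 2, 128)  # placeholder

if padded_seq_len != 0:
    language_embeddings = language_embeddings[:-padded_seq_len]  # back to the true text length
assert language_embeddings.shape[0] == 57
```
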
@@ -535,7 +608,7 @@ def forward(
num_image_tiles = torch.ones(images.shape[0], dtype=torch.int, device=input_ids.device)

# Preprocess input, labels and loss mask.
combined_embeddings, new_labels, new_loss_mask = self._preprocess_data(
combined_embeddings, new_labels, new_loss_mask, attention_mask = self._preprocess_data(
image_embeddings,
language_embeddings,
input_ids,
Expand All @@ -544,6 +617,7 @@ def forward(
use_inference_kv_cache,
image_token_index,
num_image_tiles,
attention_mask,
) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len]

output = self.language_model(
6 changes: 5 additions & 1 deletion megatron/training/initialize.py
@@ -111,6 +111,7 @@ def finish_mpu_init():
_compile_dependencies()

if args.tp_comm_overlap:
#TODO: Should this be activated with just decoder-tp-comm-overlap too?
_initialize_tp_communicators()

# No continuation function
@@ -211,7 +212,10 @@ def _initialize_tp_communicators():
else:
ub_cfgs = {}

input_shape = [(args.seq_length * args.micro_batch_size) // args.context_parallel_size , args.hidden_size]
if getattr(args, 'decoder_tp_comm_overlap', False):
input_shape = [(args.decoder_seq_length * args.micro_batch_size) // args.context_parallel_size , args.hidden_size]
else:
input_shape = [(args.seq_length * args.micro_batch_size) // args.context_parallel_size , args.hidden_size]
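
With the decoder-only overlap flag, the userbuffer input shape is sized from the decoder (language model) sequence length rather than the encoder/vision one. A small sketch of the same selection, with made-up argument values:

```python
# Sketch (not from the commit): userbuffer input-shape selection; args values are made up.
from types import SimpleNamespace

args = SimpleNamespace(
    decoder_tp_comm_overlap=True,
    decoder_seq_length=4680,    # padded language-model sequence length
    seq_length=577,             # vision/encoder sequence length
    micro_batch_size=2,
    context_parallel_size=1,
    hidden_size=4096,
)

seq = args.decoder_seq_length if getattr(args, 'decoder_tp_comm_overlap', False) else args.seq_length
input_shape = [(seq * args.micro_batch_size) // args.context_parallel_size, args.hidden_size]
assert input_shape == [9360, 4096]
```
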

if is_te_min_version("1.9.0"):
# The process group with the target bootstrap backend is created in Transformer Engine.
60 changes: 55 additions & 5 deletions pretrain_vlm.py
@@ -11,6 +11,7 @@
from megatron.core.datasets.multimodal_dataset import MockMultimodalDataset, MultimodalDatasetConfig
from megatron.core.enums import ModelType
from megatron.core.models.vision.clip_vit_model import get_num_image_embeddings
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.models.multimodal.llava_model import LLaVAModel, IMAGE_TOKEN_INDEX
from megatron.core.models.multimodal.llava_spec import (
decoder_model_with_transformer_engine_default_spec,
@@ -53,20 +54,47 @@ def model_provider(
)

old_seq_length = args.seq_length
# dataloader-seq-length is required to determine the length of text seq len
if args.dataloader_seq_length is None:
args.dataloader_seq_length = args.seq_length

# decoder_seq_length denotes the language model sequence length.
args.decoder_seq_length = args.seq_length + num_image_embeddings
decoder_seq_len = args.seq_length + num_image_embeddings

# seq_length and encoder_seq_length denote the vision model sequence length. Override if the user provided something else.
args.seq_length = args.encoder_seq_length = num_image_embeddings
if torch.distributed.get_rank() == 0 and old_seq_length != args.seq_length:
warnings.warn(
f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})"
)
# Pad to a multiple of the tensor-parallel size when using sequence parallelism
sp_padding_needed = 0
tp_size = args.tensor_model_parallel_size
if args.sequence_parallel:
assert args.transformer_impl == "transformer_engine", \
"TransformerEngine is needed to support Sequence Parallelism implementation"
if not args.decoder_tp_comm_overlap:
args.decoder_seq_length = decoder_seq_len
sp_padding_needed = int((args.decoder_seq_length + (tp_size-1)) // tp_size * tp_size) - args.decoder_seq_length
if sp_padding_needed > 0:
args.decoder_seq_length += sp_padding_needed
else:
# If TP Comm Overlap is enabled for LM backbone,
# user needs to provide decoder_seq_length with any potential padding needed
assert args.decoder_seq_length is not None, \
"Please provide --decoder-seq-length when using TP Comm overlap for LM backbone"
sp_padding_needed = args.decoder_seq_length - decoder_seq_len
else:
args.decoder_seq_length = decoder_seq_len

args.max_position_embeddings = max(args.max_position_embeddings, args.decoder_seq_length)
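
A worked example of the sequence-length bookkeeping above, assuming sequence parallelism without decoder TP comm overlap; all numbers are illustrative:

```python
# Sketch (not from the commit): sequence-length bookkeeping with illustrative numbers.
seq_length = 4096              # original text sequence length from the command line
num_image_embeddings = 577     # e.g. one image tile's worth of vision tokens
tp_size = 8

decoder_seq_len = seq_length + num_image_embeddings                                          # 4673
sp_padding_needed = (decoder_seq_len + tp_size - 1) // tp_size * tp_size - decoder_seq_len   # 7
decoder_seq_length = decoder_seq_len + sp_padding_needed       # 4680 -> args.decoder_seq_length
encoder_seq_length = num_image_embeddings                      # new args.seq_length / encoder_seq_length
max_position_embeddings = max(4096, decoder_seq_length)        # 4680
assert (decoder_seq_length, sp_padding_needed, max_position_embeddings) == (4680, 7, 4680)
```
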

print_rank_0('building a multimodal model ...')
language_transformer_config = core_transformer_config_from_args(get_args())
if args.decoder_tp_comm_overlap:
assert args.transformer_impl == "transformer_engine", \
"TransformerEngine is needed to support Decoder TP Comm overlap"
language_transformer_config.tp_comm_overlap = args.decoder_tp_comm_overlap

if args.spec is not None:
language_transformer_layer_spec = import_module(args.spec)
@@ -78,7 +106,13 @@
language_transformer_layer_spec = decoder_model_with_local_default_spec(
args.num_experts, args.moe_grouped_gemm
)


if sp_padding_needed > 0:
if language_transformer_layer_spec.submodules.self_attention.params.get('attn_mask_type', '') == AttnMaskType.causal:
language_transformer_layer_spec.submodules.self_attention.params['attn_mask_type'] = AttnMaskType.padding_causal
elif language_transformer_layer_spec.submodules.self_attention.params.get('attn_mask_type', '') == AttnMaskType.no_mask:
language_transformer_layer_spec.submodules.self_attention.params['attn_mask_type'] = AttnMaskType.padding

if args.transformer_impl == "transformer_engine":
vision_transformer_layer_spec = get_vit_layer_with_transformer_engine_spec()
else: # transformer_impl == "local"
@@ -90,9 +124,21 @@
vision_transformer_config.first_pipeline_num_layers = None
vision_transformer_config.last_pipeline_num_layers = None
vision_transformer_config.vision_model_type = vision_model_type

if vision_transformer_config.sequence_parallel:
print_rank_0("> Disabling Sequence parallelism in Vision Transformer. Not yet supported")
vision_transformer_config.sequence_parallel = False
if vision_transformer_config.tp_comm_overlap:
print_rank_0("> Disabling TP Comm overlap in Vision Transformer. Not yet supported")
vision_transformer_config.tp_comm_overlap = False

vision_projection_type = "mlp"
vision_projection_config = deepcopy(language_transformer_config)
if vision_projection_config.sequence_parallel:
print_rank_0("> Disabling Sequence parallelism in Vision Projection. Not yet supported")
vision_projection_config.sequence_parallel = False
if vision_projection_config.tp_comm_overlap:
print_rank_0("> Disabling TP Comm overlap in Vision Projection. Not yet supported")
vision_projection_config.tp_comm_overlap = False

if args.encoder_pipeline_model_parallel_size > 0:
assert (
@@ -121,7 +167,7 @@ def model_provider(
language_transformer_config=language_transformer_config,
language_transformer_layer_spec=language_transformer_layer_spec,
language_vocab_size=args.padded_vocab_size,
language_max_sequence_length=args.max_position_embeddings,
language_max_sequence_length=args.decoder_seq_length,
vision_transformer_config=vision_transformer_config,
vision_transformer_layer_spec=vision_transformer_layer_spec,
drop_vision_class_token=args.disable_vision_class_token,
@@ -164,7 +210,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
config = MultimodalDatasetConfig(
random_seed=args.seed,
split=args.split,
sequence_length=args.decoder_seq_length - args.seq_length,
sequence_length=args.dataloader_seq_length,
tokenizer=get_tokenizer(),
reset_position_ids=args.reset_position_ids,
reset_attention_mask=args.reset_attention_mask,
@@ -292,6 +338,10 @@ def add_vlm_extra_args(parser):
default=False,
help="Drop vision model class token",
)
group.add_argument("--dataloader-seq-length", type=int, help="Make dataloader to produce sequences of specific length.")
group.add_argument("--decoder-tp-comm-overlap", action="store_true", default=False, help="Enables the overlap of "
"Tensor parallel communication and GEMM kernels in Decoder only. "
"Please provide decoder-seq-length when using this feature.")
return parser
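
A minimal, standalone illustration of how the two new flags parse; the real parser carries many more Megatron arguments and the sample values are made up:

```python
# Sketch (not from the commit): the two new VLM flags in isolation; sample values are made up.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="vision language model specific arguments")
group.add_argument("--dataloader-seq-length", type=int,
                   help="Make dataloader to produce sequences of specific length.")
group.add_argument("--decoder-tp-comm-overlap", action="store_true", default=False,
                   help="Overlap tensor-parallel communication and GEMMs in the decoder only; "
                        "provide --decoder-seq-length when using this.")

args = parser.parse_args(["--dataloader-seq-length", "2048", "--decoder-tp-comm-overlap"])
assert args.dataloader_seq_length == 2048 and args.decoder_tp_comm_overlap is True
```
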


4 changes: 3 additions & 1 deletion tests/unit_tests/models/test_llava_model.py
@@ -124,8 +124,9 @@ def test_preprocess_data(self):
num_image_tiles = torch.tensor([1, 2, 1, 2, 1], dtype=torch.int).cuda()

use_inference_kv_cache = False
attention_mask = None

embeddings, labels, loss_mask = self.model._preprocess_data(
embeddings, labels, loss_mask, attention_mask = self.model._preprocess_data(
image_embeddings,
language_embeddings,
input_ids,
@@ -134,6 +135,7 @@
use_inference_kv_cache,
image_token_index,
num_image_tiles,
attention_mask,
)

img_seq_len = 577
