From 40fb590e4bb4aa01053f1c09d6d5f58992f8cf53 Mon Sep 17 00:00:00 2001
From: Xiaowei Ren
Date: Tue, 10 Dec 2024 16:44:06 -0800
Subject: [PATCH] ADLR/megatron-lm!2404 - move get_batch_on_this_cp_rank to mcore utils

---
 .../core/models/multimodal/llava_model.py |  4 +-
 megatron/core/utils.py                    | 38 ++++++++++++++++++
 megatron/training/utils.py                | 39 +++----------------
 3 files changed, 44 insertions(+), 37 deletions(-)

diff --git a/megatron/core/models/multimodal/llava_model.py b/megatron/core/models/multimodal/llava_model.py
index 576cb2acc6..5e3e357e84 100644
--- a/megatron/core/models/multimodal/llava_model.py
+++ b/megatron/core/models/multimodal/llava_model.py
@@ -16,7 +16,7 @@
 from megatron.core.transformer import MegatronModule
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_config import TransformerConfig
-from megatron.core.utils import log_single_rank
+from megatron.core.utils import get_batch_on_this_cp_rank, log_single_rank
 
 try:
     import transformer_engine  # pylint: disable=unused-import
@@ -636,8 +636,6 @@ def _process_embedding_token_parallel(
 
         if self.context_parallel_lm > 1:
             # Distribute sequence across CP ranks
-            from megatron.training.utils import get_batch_on_this_cp_rank
-
             batch = get_batch_on_this_cp_rank(
                 {
                     "combined_embeddings": combined_embeddings,
diff --git a/megatron/core/utils.py b/megatron/core/utils.py
index 6b46f292d5..3bb28042b8 100644
--- a/megatron/core/utils.py
+++ b/megatron/core/utils.py
@@ -1413,3 +1413,41 @@ def __exit__(
 def is_float8tensor(tensor: torch.Tensor) -> bool:
     """Check if a tensor is a Transformer Engine Float8Tensor"""
     return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor)
+
+
+########################
+### context parallel ###
+########################
+
+
+def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
+    """Slice batch input along sequence dimension into multiple chunks,
+    which are parallelized across GPUs in a context parallel group.
+    """
+
+    # With causal masking, each token only attends to its prior tokens. Simply splitting
+    # the sequence into CP chunks can result in severe load imbalance: chunks at the end
+    # of the sequence have a bigger workload than the others. To address this, we split
+    # the sequence into 2*CP chunks. Assuming CP=2, we get 4 chunks; chunk_0 and chunk_3
+    # are assigned to GPU0 and chunk_1 and chunk_2 to GPU1, so that the workload is
+    # balanced among the GPUs in a context parallel group.
+    cp_size = parallel_state.get_context_parallel_world_size()
+    if cp_size > 1:
+        cp_rank = parallel_state.get_context_parallel_rank()
+        for key, val in batch.items():
+            if val is not None:
+                seq_dim = 1 if key != 'attention_mask' else 2
+                val = val.view(
+                    *val.shape[0:seq_dim],
+                    2 * cp_size,
+                    val.shape[seq_dim] // (2 * cp_size),
+                    *val.shape[(seq_dim + 1) :],
+                )
+                index = torch.tensor(
+                    [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True
+                ).cuda(non_blocking=True)
+                val = val.index_select(seq_dim, index)
+                val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
+                batch[key] = val
+
+    return batch
diff --git a/megatron/training/utils.py b/megatron/training/utils.py
index 4b3f2b683a..540400c0ba 100644
--- a/megatron/training/utils.py
+++ b/megatron/training/utils.py
@@ -36,7 +36,11 @@
 from megatron.core import mpu
 from megatron.core.datasets.utils import get_blend_from_list
 from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
-from megatron.core.utils import get_data_parallel_group_if_dtensor, to_local_if_dtensor
+from megatron.core.utils import (
+    get_batch_on_this_cp_rank,
+    get_data_parallel_group_if_dtensor,
+    to_local_if_dtensor,
+)
 from megatron.legacy.model import Float16Module
 from megatron.legacy.model.module import param_is_not_shared
 
@@ -254,39 +258,6 @@ def get_ltor_masks_and_position_ids(data,
     return attention_mask, loss_mask, position_ids
 
 
-def get_batch_on_this_cp_rank(batch):
-    """ Slice batch input along sequence dimension into multiple chunks,
-    which are parallelized across GPUs in a context parallel group.
-    """
-
-    # With causal masking, each token only attends to its prior tokens. Simply split
-    # sequence into CP chunks can result in severe load imbalance. That's to say, chunks
-    # at the end of sequence have bigger workload than others. To address this issue,
-    # we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0
-    # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so
-    # that we can get balanced workload among GPUs in a context parallel group.
-    args = get_args()
-    cp_size = args.context_parallel_size
-    if cp_size > 1:
-        cp_rank = mpu.get_context_parallel_rank()
-        for key, val in batch.items():
-            if val is not None:
-                seq_dim = 1 if key != 'attention_mask' else 2
-                val = val.view(
-                    *val.shape[0:seq_dim],
-                    2 * cp_size,
-                    val.shape[seq_dim] // (2 * cp_size),
-                    *val.shape[(seq_dim + 1) :],
-                )
-                index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)],
-                                     device="cpu", pin_memory=True).cuda(non_blocking=True)
-                val = val.index_select(seq_dim, index)
-                val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
-                batch[key] = val
-
-    return batch
-
-
 def print_rank_0(message):
     """If distributed is initialized, print only on rank 0."""
     if torch.distributed.is_initialized():
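
The comment block in the new get_batch_on_this_cp_rank() describes the load-balanced slicing used with causal attention: the sequence is cut into 2*CP chunks and each CP rank keeps one chunk from the front plus the mirrored chunk from the back. The standalone sketch below illustrates only that index arithmetic on a dummy tensor. It runs in a single CPU process; the helper name slice_for_cp_rank, the tensor shape, and the hard-coded cp_size/cp_rank values are illustrative assumptions, not part of this patch.

# Minimal single-process sketch of the 2*CP load-balanced slicing; Megatron's
# parallel_state lookups are replaced by explicit cp_size/cp_rank arguments.
import torch


def slice_for_cp_rank(seq: torch.Tensor, cp_size: int, cp_rank: int, seq_dim: int = 1) -> torch.Tensor:
    """Return the two mirrored chunks of `seq` owned by `cp_rank` out of `cp_size` ranks."""
    chunks = 2 * cp_size
    # Split the sequence dimension into 2*cp_size equal chunks.
    seq = seq.view(*seq.shape[:seq_dim], chunks, seq.shape[seq_dim] // chunks, *seq.shape[seq_dim + 1 :])
    # Rank r keeps chunk r (from the front) and chunk 2*cp_size-1-r (from the back),
    # pairing cheap and expensive chunks so causal-attention work is balanced across ranks.
    index = torch.tensor([cp_rank, chunks - cp_rank - 1])
    seq = seq.index_select(seq_dim, index)
    # Merge the two selected chunks back into one contiguous slice along seq_dim.
    return seq.view(*seq.shape[:seq_dim], -1, *seq.shape[seq_dim + 2 :])


if __name__ == "__main__":
    tokens = torch.arange(8).unsqueeze(0)  # batch of 1, sequence positions 0..7
    print(slice_for_cp_rank(tokens, cp_size=2, cp_rank=0))  # tensor([[0, 1, 6, 7]])
    print(slice_for_cp_rank(tokens, cp_size=2, cp_rank=1))  # tensor([[2, 3, 4, 5]])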