ADLR/megatron-lm!2404 - move get_batch_on_this_cp_rank to mcore utils
xrennvidia authored and ko3n1g committed Dec 11, 2024
1 parent f3e1afb commit 40fb590
Showing 3 changed files with 44 additions and 37 deletions.
4 changes: 1 addition & 3 deletions megatron/core/models/multimodal/llava_model.py
@@ -16,7 +16,7 @@
from megatron.core.transformer import MegatronModule
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import log_single_rank
from megatron.core.utils import get_batch_on_this_cp_rank, log_single_rank

try:
import transformer_engine # pylint: disable=unused-import
@@ -636,8 +636,6 @@ def _process_embedding_token_parallel(

        if self.context_parallel_lm > 1:
            # Distribute sequence across CP ranks
            from megatron.training.utils import get_batch_on_this_cp_rank

            batch = get_batch_on_this_cp_rank(
                {
                    "combined_embeddings": combined_embeddings,
38 changes: 38 additions & 0 deletions megatron/core/utils.py
@@ -1413,3 +1413,41 @@ def __exit__(
def is_float8tensor(tensor: torch.Tensor) -> bool:
"""Check if a tensor is a Transformer Engine Float8Tensor"""
return HAVE_TE_FLOAT8TENSOR and isinstance(tensor, Float8Tensor)


########################
### context parallel ###
########################


def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
"""Slice batch input along sequence dimension into multiple chunks,
which are parallelized across GPUs in a context parallel group.
"""

# With causal masking, each token only attends to its prior tokens. Simply split
# sequence into CP chunks can result in severe load imbalance. That's to say, chunks
# at the end of sequence have bigger workload than others. To address this issue,
# we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0
# and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so
# that we can get balanced workload among GPUs in a context parallel group.
    cp_size = parallel_state.get_context_parallel_world_size()
    if cp_size > 1:
        cp_rank = parallel_state.get_context_parallel_rank()
        for key, val in batch.items():
            if val is not None:
                seq_dim = 1 if key != 'attention_mask' else 2
                val = val.view(
                    *val.shape[0:seq_dim],
                    2 * cp_size,
                    val.shape[seq_dim] // (2 * cp_size),
                    *val.shape[(seq_dim + 1) :],
                )
                index = torch.tensor(
                    [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True
                ).cuda(non_blocking=True)
                val = val.index_select(seq_dim, index)
                val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
                batch[key] = val

    return batch
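
For illustration, a standalone, CPU-only sketch of the slicing scheme above, assuming cp_size=2 and cp_rank=0; the toy tokens tensor is hypothetical and the real function derives cp_size and cp_rank from parallel_state and moves the index tensor to GPU.

import torch

# Toy sketch of the 2*CP chunking (hypothetical values, CPU only).
# With cp_size=2 the sequence is split into 2*cp_size=4 chunks; rank r keeps
# chunks r and (2*cp_size - 1 - r), so rank 0 keeps chunks 0 and 3.
cp_size, cp_rank = 2, 0
tokens = torch.arange(8).unsqueeze(0)  # shape [1, 8]: one dummy sequence of 8 tokens

seq_dim = 1
val = tokens.view(
    *tokens.shape[0:seq_dim],
    2 * cp_size,
    tokens.shape[seq_dim] // (2 * cp_size),
    *tokens.shape[(seq_dim + 1) :],
)  # shape [1, 4, 2]
index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])  # chunks 0 and 3
val = val.index_select(seq_dim, index)
val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
print(val)  # tensor([[0, 1, 6, 7]])

Pairing the first and last chunk on the same rank balances work under causal masking: the last chunk attends to the most prior tokens, which offsets the first chunk's light attention workload.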
39 changes: 5 additions & 34 deletions megatron/training/utils.py
@@ -36,7 +36,11 @@
from megatron.core import mpu
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate
from megatron.core.utils import get_data_parallel_group_if_dtensor, to_local_if_dtensor
from megatron.core.utils import (
    get_batch_on_this_cp_rank,
    get_data_parallel_group_if_dtensor,
    to_local_if_dtensor,
)
from megatron.legacy.model import Float16Module
from megatron.legacy.model.module import param_is_not_shared

@@ -254,39 +258,6 @@ def get_ltor_masks_and_position_ids(data,
    return attention_mask, loss_mask, position_ids


def get_batch_on_this_cp_rank(batch):
    """ Slice batch input along sequence dimension into multiple chunks,
    which are parallelized across GPUs in a context parallel group.
    """

    # With causal masking, each token only attends to its prior tokens. Simply split
    # sequence into CP chunks can result in severe load imbalance. That's to say, chunks
    # at the end of sequence have bigger workload than others. To address this issue,
    # we split sequence into 2*CP ranks. Assuming CP=2, we then get 4 chunks, chunk_0
    # and chunk_3 are assigned to GPU0, chunk_1 and chunk_2 are assigned to GPU1, so
    # that we can get balanced workload among GPUs in a context parallel group.
    args = get_args()
    cp_size = args.context_parallel_size
    if cp_size > 1:
        cp_rank = mpu.get_context_parallel_rank()
        for key, val in batch.items():
            if val is not None:
                seq_dim = 1 if key != 'attention_mask' else 2
                val = val.view(
                    *val.shape[0:seq_dim],
                    2 * cp_size,
                    val.shape[seq_dim] // (2 * cp_size),
                    *val.shape[(seq_dim + 1) :],
                )
                index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)],
                                     device="cpu", pin_memory=True).cuda(non_blocking=True)
                val = val.index_select(seq_dim, index)
                val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
                batch[key] = val

    return batch


def print_rank_0(message):
"""If distributed is initialized, print only on rank 0."""
if torch.distributed.is_initialized():
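
One consequence of this move: since megatron/training/utils.py now re-imports the helper from megatron.core.utils, existing call sites that import it from the training module continue to resolve to the same function object. A minimal check sketch, assuming a working Megatron-LM installation:

# Both import paths should resolve to the same function object after this change.
from megatron.core.utils import get_batch_on_this_cp_rank as core_fn
from megatron.training.utils import get_batch_on_this_cp_rank as training_fn

assert core_fn is training_fn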
