Commit

Merge branch 'tp_bootstrap_backend' into 'main'
Add the interface to set TP communication bootstrap backend

See merge request ADLR/megatron-lm!2153
ericharper committed Oct 3, 2024
2 parents 065260b + f76b465 commit 25f7da2
Showing 3 changed files with 36 additions and 13 deletions.
24 changes: 17 additions & 7 deletions megatron/core/model_parallel_config.py
@@ -182,8 +182,8 @@ class ModelParallelConfig:

tp_comm_atomic_ag: bool = False
"""Deprecated from TransformerEngine v1.6.0.
If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather both
done atomically. Don't care if tp_comm_overlap is False.
If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
both done atomically. Don't care if tp_comm_overlap is False.
"""

tp_comm_split_rs: bool = True
@@ -213,6 +213,11 @@ class ModelParallelConfig:
If true, the AllGather -> Gemm overlap for FC1 layer of MLP gets disabled
"""

tp_comm_bootstrap_backend: str = 'nccl'
"""
Set the bootstrapping backend out of 'nccl', 'mpi', and 'gloo'
"""

###################
# Pipeline Parallel
###################
@@ -257,7 +262,8 @@ class ModelParallelConfig:

wgrad_deferral_limit: int = 0
"""This value tunes the number of micro-batches for which the embedding weight gradient compute
needs to be deferred to pipeline flush, this argument is invalid if `defer_embedding_wgrad_compute` is False.
needs to be deferred to pipeline flush, this argument is invalid if
`defer_embedding_wgrad_compute` is False.
Defaults to 0, which means all micro-batches are deferred.
"""

@@ -276,7 +282,9 @@
"""Tells the number of transformer layers for which activations has to be offloaded."""

_cpu_offloading_context: ContextManager = (
None # Used for internal use only, not to be set by the user. TODO: Need to move to the 'right' place when possible.
None
# Used for internal use only, not to be set by a user.
# TODO: Need to move to the 'right' place when possible.
)
"""For internal use only, do not set."""

@@ -297,7 +305,8 @@ class ModelParallelConfig:

def __post_init__(self):
"""Python dataclass method that is used to modify attributes after initialization.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
details.
"""
if self.sequence_parallel:
if self.tensor_model_parallel_size <= 1:
@@ -324,11 +333,12 @@ def __post_init__(self):

if self.defer_embedding_wgrad_compute and self.wgrad_deferral_limit < 0:
raise ValueError(
"Wgrad deferral limit should be greater than or equal to 0 when this optimization is enabled!"
"Wgrad deferral limit should be greater than or equal to 0 when it is enabled!"
)

if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1:
if self.sequence_parallel is False:
raise ValueError(
"When using expert parallelism and tensor parallelism, sequence parallelism must be used"
"When using expert parallelism and tensor parallelism, sequence parallelism "
"must be used"
)
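
As a usage note (not part of this commit), the new field is set like any other ModelParallelConfig attribute. The sketch below only uses fields already referenced in the diff above, with illustrative values.

# Minimal sketch: selecting the TP communication bootstrap backend.
# Values are illustrative; 'nccl' remains the default.
from megatron.core.model_parallel_config import ModelParallelConfig

config = ModelParallelConfig(
    tensor_model_parallel_size=8,
    sequence_parallel=True,            # required once EP > 1 and TP > 1 (see __post_init__)
    tp_comm_bootstrap_backend='gloo',  # one of 'nccl' (default), 'mpi', 'gloo'
)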
3 changes: 3 additions & 0 deletions megatron/training/arguments.py
@@ -1160,6 +1160,9 @@ def _add_training_args(parser):
group.add_argument('--disable-tp-comm-bulk-wgrad', action='store_false',
help='Disables the Reduce-Scatter overlap with bprop weight gradient GEMM.',
dest='tp_comm_bulk_wgrad')
group.add_argument('--tp-comm-bootstrap-backend', default='nccl', type=str,
choices=['nccl', 'mpi', 'gloo'],
help='Set the bootstrapping backend of Tensor parallel communications.')
group.add_argument('--use-cpu-initialization', action='store_true',
default=None,
help='If set, initialize weights on the CPU. This eliminates init differences based on tensor parallelism.')
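
For reference, here is a standalone sketch of how the new flag behaves in isolation (plain argparse, not Megatron's full parser setup): the choices list rejects any value other than the three supported backends, and the default stays 'nccl'.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--tp-comm-bootstrap-backend', default='nccl', type=str,
                    choices=['nccl', 'mpi', 'gloo'],
                    help='Set the bootstrapping backend of Tensor parallel communications.')

args = parser.parse_args(['--tp-comm-bootstrap-backend', 'gloo'])
assert args.tp_comm_bootstrap_backend == 'gloo'  # dest is derived from the flag name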
22 changes: 16 additions & 6 deletions megatron/training/initialize.py
@@ -22,6 +22,7 @@
from megatron.core.fusions.fused_bias_dropout import bias_dropout_add_fused_train
from megatron.core.fusions.fused_bias_gelu import bias_gelu
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu
from megatron.core.utils import get_te_version, is_te_min_version

logger = logging.getLogger(__name__)

@@ -211,12 +212,21 @@ def _initialize_tp_communicators():

input_shape = [(args.seq_length * args.micro_batch_size) // args.context_parallel_size , args.hidden_size]

#We create a MPI process group, which is needed to bootstrap the pipelined
#tensor-model-parallel communication overlap
torch.distributed.new_group(backend='mpi')

te_module.base.initialize_ub(shape = input_shape, tp_size = args.tensor_model_parallel_size,
use_fp8 = (args.fp8 is not None) , ub_cfgs = ub_cfgs,)
if is_te_min_version("1.9.0"):
# The process group with the target bootstrap backend is created in Transformer Engine.
te_module.base.initialize_ub(shape = input_shape, tp_size = args.tensor_model_parallel_size,
use_fp8 = (args.fp8 is not None) , ub_cfgs = ub_cfgs,
bootstrap_backend = args.tp_comm_bootstrap_backend)
else:
if args.tp_comm_bootstrap_backend != 'mpi':
warnings.warn(
f"Transformer Engine v{get_te_version()} supports only MPI bootstrap backend."
)
# Create a MPI process group to help with TP communication overlap bootstrap.
torch.distributed.new_group(backend='mpi')

te_module.base.initialize_ub(shape = input_shape, tp_size = args.tensor_model_parallel_size,
use_fp8 = (args.fp8 is not None) , ub_cfgs = ub_cfgs)

def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
"""Initialize torch.distributed and core model parallel."""
