diff --git a/megatron/core/model_parallel_config.py b/megatron/core/model_parallel_config.py
index caae41cb4a..f2751673e4 100644
--- a/megatron/core/model_parallel_config.py
+++ b/megatron/core/model_parallel_config.py
@@ -182,8 +182,8 @@ class ModelParallelConfig:
 
     tp_comm_atomic_ag: bool = False
     """Deprecated from TransformerEngine v1.6.0.
-    If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather both
-    done atomically. Don't care if tp_comm_overlap is False.
+    If true, allows All-Gather overlap with Fprop GEMM by pipelining the GEMM and All-Gather
+    both done atomically. Don't care if tp_comm_overlap is False.
     """
 
     tp_comm_split_rs: bool = True
@@ -213,6 +213,11 @@ class ModelParallelConfig:
     If true, the AllGather -> Gemm overlap for FC1 layer of MLP gets disabled
     """
 
+    tp_comm_bootstrap_backend: str = 'nccl'
+    """
+    Set the bootstrapping backend out of 'nccl', 'mpi', and 'gloo'
+    """
+
     ###################
     # Pipeline Parallel
     ###################
@@ -257,7 +262,8 @@ class ModelParallelConfig:
 
     wgrad_deferral_limit: int = 0
     """This value tunes the number of micro-batches for which the embedding weight gradient compute
-    needs to be deferred to pipeline flush, this argument is invalid if `defer_embedding_wgrad_compute` is False.
+    needs to be deferred to pipeline flush, this argument is invalid if
+    `defer_embedding_wgrad_compute` is False.
     Defaults to 0, which means all micro-batches are deferred.
     """
 
@@ -276,7 +282,9 @@ class ModelParallelConfig:
     """Tells the number of transformer layers for which activations has to be offloaded."""
 
     _cpu_offloading_context: ContextManager = (
-        None  # Used for internal use only, not to be set by the user. TODO: Need to move to the 'right' place when possible.
+        None
+        # Used for internal use only, not to be set by a user.
+        # TODO: Need to move to the 'right' place when possible.
     )
     """For internal use only, do not set."""
 
@@ -297,7 +305,8 @@ class ModelParallelConfig:
     def __post_init__(self):
         """Python dataclass method that is used to modify attributes after initialization.
 
-        See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more details.
+        See https://docs.python.org/3/library/dataclasses.html#post-init-processing for more
+        details.
         """
         if self.sequence_parallel:
             if self.tensor_model_parallel_size <= 1:
@@ -324,11 +333,12 @@ def __post_init__(self):
 
         if self.defer_embedding_wgrad_compute and self.wgrad_deferral_limit < 0:
             raise ValueError(
-                "Wgrad deferral limit should be greater than or equal to 0 when this optimization is enabled!"
+                "Wgrad deferral limit should be greater than or equal to 0 when it is enabled!"
             )
 
         if self.expert_model_parallel_size > 1 and self.tensor_model_parallel_size > 1:
             if self.sequence_parallel is False:
                 raise ValueError(
-                    "When using expert parallelism and tensor parallelism, sequence parallelism must be used"
+                    "When using expert parallelism and tensor parallelism, sequence parallelism "
+                    "must be used"
                 )
diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
index 162d719314..e3d876a5f2 100644
--- a/megatron/training/arguments.py
+++ b/megatron/training/arguments.py
@@ -1160,6 +1160,9 @@ def _add_training_args(parser):
     group.add_argument('--disable-tp-comm-bulk-wgrad', action='store_false',
                        help='Disables the Reduce-Scatter overlap with bprop weight gradient GEMM.',
                        dest='tp_comm_bulk_wgrad')
+    group.add_argument('--tp-comm-bootstrap-backend', default='nccl', type=str,
+                       choices=['nccl', 'mpi', 'gloo'],
+                       help='Set the bootstrapping backend of Tensor parallel communications.')
     group.add_argument('--use-cpu-initialization', action='store_true',
                        default=None,
                        help='If set, initialize weights on the CPU. This eliminates init differences based on tensor parallelism.')
diff --git a/megatron/training/initialize.py b/megatron/training/initialize.py
index 8e4877c8b5..ad68ce8cb7 100644
--- a/megatron/training/initialize.py
+++ b/megatron/training/initialize.py
@@ -22,6 +22,7 @@
 from megatron.core.fusions.fused_bias_dropout import bias_dropout_add_fused_train
 from megatron.core.fusions.fused_bias_gelu import bias_gelu
 from megatron.core.fusions.fused_bias_swiglu import bias_swiglu
+from megatron.core.utils import get_te_version, is_te_min_version
 
 logger = logging.getLogger(__name__)
 
@@ -211,12 +212,21 @@ def _initialize_tp_communicators():
     input_shape = [(args.seq_length * args.micro_batch_size) // args.context_parallel_size ,
                    args.hidden_size]
 
-    #We create a MPI process group, which is needed to bootstrap the pipelined
-    #tensor-model-parallel communication overlap
-    torch.distributed.new_group(backend='mpi')
-
-    te_module.base.initialize_ub(shape = input_shape, tp_size = args.tensor_model_parallel_size,
-                                 use_fp8 = (args.fp8 is not None) , ub_cfgs = ub_cfgs,)
+    if is_te_min_version("1.9.0"):
+        # The process group with the target bootstrap backend is created in Transformer Engine.
+        te_module.base.initialize_ub(shape = input_shape, tp_size = args.tensor_model_parallel_size,
+                                     use_fp8 = (args.fp8 is not None) , ub_cfgs = ub_cfgs,
+                                     bootstrap_backend = args.tp_comm_bootstrap_backend)
+    else:
+        if args.tp_comm_bootstrap_backend != 'mpi':
+            warnings.warn(
+                f"Transformer Engine v{get_te_version()} supports only MPI bootstrap backend."
+            )
+        # Create a MPI process group to help with TP communication overlap bootstrap.
+        torch.distributed.new_group(backend='mpi')
+
+        te_module.base.initialize_ub(shape = input_shape, tp_size = args.tensor_model_parallel_size,
+                                     use_fp8 = (args.fp8 is not None) , ub_cfgs = ub_cfgs)
 
 def _initialize_distributed(get_embedding_ranks, get_position_embedding_ranks):
     """Initialize torch.distributed and core model parallel."""
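Note (not part of the patch): a minimal usage sketch of the new setting from the Python API. The helper name make_tp_overlap_config is hypothetical; it simply mirrors the fallback logic added to initialize.py, reverting to the MPI bootstrap on Transformer Engine releases older than v1.9.0.

    # Hypothetical helper, for illustration only -- not introduced by this change.
    from megatron.core.model_parallel_config import ModelParallelConfig
    from megatron.core.utils import is_te_min_version

    def make_tp_overlap_config(tp_size: int, backend: str = 'nccl') -> ModelParallelConfig:
        """Build a config enabling TP communication overlap with the requested bootstrap backend."""
        if not is_te_min_version("1.9.0") and backend != 'mpi':
            # Older Transformer Engine can only bootstrap userbuffers over MPI.
            backend = 'mpi'
        return ModelParallelConfig(
            tensor_model_parallel_size=tp_size,  # must be > 1 when sequence_parallel is enabled
            sequence_parallel=True,              # TP comm overlap is used together with sequence parallelism
            tp_comm_overlap=True,
            tp_comm_bootstrap_backend=backend,
        )

From the command line the same choice is exposed through the new flag, e.g. --tp-comm-overlap --tp-comm-bootstrap-backend gloo alongside the existing tensor-parallel arguments; the default 'nccl' preserves a single-backend setup that does not require an MPI launcher.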