From aa2a45dd44516925ba5c0579eb262caf48a81a1b Mon Sep 17 00:00:00 2001 From: Hongxiao Bai Date: Mon, 9 Dec 2024 05:29:57 -0800 Subject: [PATCH] ADLR/megatron-lm!2101 - Refactor MoE specs: move all submodules of MoELayer into the spec Co-authored-by: Zijie Yan --- megatron/core/models/gpt/gpt_layer_specs.py | 121 +++++++++--------- megatron/core/models/gpt/moe_module_specs.py | 81 ++++++++++++ megatron/core/transformer/moe/moe_layer.py | 23 +--- .../core/transformer/moe/shared_experts.py | 9 +- .../core/transformer/transformer_config.py | 4 + megatron/training/arguments.py | 2 + pretrain_gpt.py | 4 +- .../golden_values_dev.json | 58 ++++----- .../models/test_moe_experts.py | 20 ++- .../transformer/moe/test_grouped_mlp.py | 13 +- .../transformer/moe/test_moe_layer.py | 12 +- .../transformer/moe/test_routers.py | 4 +- .../transformer/moe/test_sequential_mlp.py | 4 +- .../transformer/moe/test_shared_experts.py | 6 +- .../transformer/moe/test_token_dispatcher.py | 4 +- .../transformer/moe/test_upcycling.py | 10 +- 16 files changed, 228 insertions(+), 147 deletions(-) create mode 100755 megatron/core/models/gpt/moe_module_specs.py diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py index 749be324ed..d0e48c190c 100755 --- a/megatron/core/models/gpt/gpt_layer_specs.py +++ b/megatron/core/models/gpt/gpt_layer_specs.py @@ -1,16 +1,16 @@ # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +import warnings from typing import Optional from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules from megatron.core.transformer.dot_product_attention import DotProductAttention from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP, MLPSubmodules -from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules -from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.multi_latent_attention import ( MLASelfAttention, MLASelfAttentionSubmodules, @@ -26,12 +26,10 @@ try: from megatron.core.extensions.transformer_engine import ( - TEColumnParallelGroupedLinear, TEColumnParallelLinear, TEDotProductAttention, TELayerNormColumnParallelLinear, TENorm, - TERowParallelGroupedLinear, TERowParallelLinear, ) @@ -47,8 +45,6 @@ HAVE_APEX = True LNImpl = FusedLayerNorm except ImportError: - import warnings - from megatron.core.transformer.torch_norm import WrappedTorchNorm warnings.warn('Apex is not installed. Falling back to Torch Norm') @@ -60,7 +56,8 @@ def get_gpt_layer_with_transformer_engine_spec( moe_grouped_gemm: Optional[bool] = False, qk_layernorm: Optional[bool] = False, multi_latent_attention: Optional[bool] = False, - fp8: Optional[str] = None, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, ) -> ModuleSpec: """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). @@ -69,13 +66,24 @@ def get_gpt_layer_with_transformer_engine_spec( num_experts (int, optional): Number of experts. Defaults to None. moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False. 
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False. - fp8 (str, optional): Flag to decide the linear layer spec for MoE. Defaults to None. + fp8 (str, optional): Deprecated. For temporary Nemo compatibility. + moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. + Defaults to False. Returns: ModuleSpec: Module specification with TE modules """ + if fp8 is not None: + warnings.warn( + 'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated' + ' and will be removed soon. Please update your code accordingly.' + ) + mlp = _get_mlp_module_spec( - use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=fp8 + use_te=True, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, ) if multi_latent_attention: @@ -138,6 +146,8 @@ def get_gpt_layer_local_spec( moe_grouped_gemm: Optional[bool] = False, qk_layernorm: Optional[bool] = False, multi_latent_attention: Optional[bool] = False, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, ) -> ModuleSpec: """Use this spec for an implementation using only modules in Megatron-Core. @@ -146,13 +156,24 @@ def get_gpt_layer_local_spec( num_experts (int, optional): Number of experts. Defaults to None. moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False. qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False. + fp8 (str, optional): Deprecated. For temporary Nemo compatibility. + moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP. + Defaults to False. Returns: ModuleSpec: Module specification with Megatron-Core modules """ + if fp8 is not None: + warnings.warn( + 'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated' + ' and will be removed soon. Please update your code accordingly.' + ) mlp = _get_mlp_module_spec( - use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm + use_te=False, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, ) if multi_latent_attention: @@ -213,63 +234,33 @@ def _get_mlp_module_spec( use_te: Optional[bool] = True, num_experts: Optional[int] = None, moe_grouped_gemm: Optional[bool] = False, - fp8: Optional[str] = None, + fp8: Optional[str] = None, # pylint: disable=unused-arguments + moe_use_legacy_grouped_gemm: Optional[bool] = False, ) -> ModuleSpec: - """Helper function to get module spec for MLP""" - if num_experts is not None: - moe_spec = _get_moe_module_spec( - use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=fp8 + """Helper function to get module spec for MLP/MoE""" + if fp8 is not None: + warnings.warn( + 'The fp8 argument in "_get_mlp_module_spec" has been deprecated' + ' and will be removed soon. Please update your code accordingly.' 
) - return moe_spec - - return ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear, - linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, - ), - ) - -def _get_moe_module_spec( - use_te: Optional[bool] = True, - num_experts: Optional[int] = None, - moe_grouped_gemm: Optional[bool] = False, - fp8: Optional[str] = None, -) -> ModuleSpec: - """Helper function to get module spec for MoE""" if num_experts is None: - return None - if use_te and moe_grouped_gemm: - linear_fc1 = TEColumnParallelGroupedLinear - linear_fc2 = TERowParallelGroupedLinear - elif use_te and fp8: - linear_fc1 = TEColumnParallelLinear - linear_fc2 = TERowParallelLinear - else: - linear_fc1 = ColumnParallelLinear - linear_fc2 = RowParallelLinear - - use_te_grouped_gemm = use_te and TEColumnParallelGroupedLinear is not None - - return ModuleSpec( - module=MoELayer, - submodules=MoESubmodules( - experts=( - MLPSubmodules(linear_fc1=linear_fc1, linear_fc2=linear_fc2) - if not moe_grouped_gemm or use_te_grouped_gemm - else None - ), - shared_experts=ModuleSpec( - module=SharedExpertMLP, - params={"gate": False}, - submodules=MLPSubmodules( - linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, - linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, - ), + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, ), - ), - ) + ) + else: + # Mixture of experts with modules in megatron core. + return get_moe_module_spec( + use_te=use_te, + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm, + ) def get_gpt_decoder_block_spec( @@ -288,7 +279,7 @@ def get_gpt_decoder_block_spec( moe_grouped_gemm=False, qk_layernorm=config.qk_layernorm, multi_latent_attention=config.multi_latent_attention, - fp8=config.fp8, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, ) if use_transformer_engine else get_gpt_layer_local_spec( @@ -296,6 +287,7 @@ def get_gpt_decoder_block_spec( moe_grouped_gemm=False, qk_layernorm=config.qk_layernorm, multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, ) ) moe_layer_spec = ( @@ -304,7 +296,7 @@ def get_gpt_decoder_block_spec( moe_grouped_gemm=config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm, multi_latent_attention=config.multi_latent_attention, - fp8=config.fp8, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, ) if use_transformer_engine else get_gpt_layer_local_spec( @@ -312,6 +304,7 @@ def get_gpt_decoder_block_spec( moe_grouped_gemm=config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm, multi_latent_attention=config.multi_latent_attention, + moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm, ) ) diff --git a/megatron/core/models/gpt/moe_module_specs.py b/megatron/core/models/gpt/moe_module_specs.py new file mode 100755 index 0000000000..513eeddc7e --- /dev/null +++ b/megatron/core/models/gpt/moe_module_specs.py @@ -0,0 +1,81 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
+ +import warnings +from typing import Optional + +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.mlp import MLPSubmodules +from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP +from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules +from megatron.core.transformer.moe.shared_experts import SharedExpertMLP +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.utils import get_te_version, is_te_min_version + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelGroupedLinear, + TEColumnParallelLinear, + TERowParallelGroupedLinear, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + + +def get_moe_module_spec( + use_te: Optional[bool] = True, + num_experts: Optional[int] = None, + moe_grouped_gemm: Optional[bool] = False, + moe_use_legacy_grouped_gemm: Optional[bool] = False, +) -> ModuleSpec: + """Helper function to get module spec for MoE""" + assert num_experts is not None + + mlp = MLPSubmodules( + linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, + ) + + # experts spec + if moe_grouped_gemm: + ## use GroupedMLP + if use_te and TEColumnParallelGroupedLinear is not None and not moe_use_legacy_grouped_gemm: + ## use TEGroupedLinear + expert_module = TEGroupedMLP + expert_submodule = MLPSubmodules( + linear_fc1=TEColumnParallelGroupedLinear, linear_fc2=TERowParallelGroupedLinear + ) + else: + ## use legacy GroupedMLP + expert_module = GroupedMLP + expert_submodule = None + warnings.warn( + 'The legacy GroupedMLP will be deprecated in Megatron-Core v0.12.0. ' + 'Please update the TransformerEngine to version>=1.7.0 and use TEGroupedMLP.' + ) + else: + ## use SequentialMLP + expert_module = SequentialMLP + if use_te and not is_te_min_version("1.7.0.dev0"): + warnings.warn( + "Only transformer-engine>=1.7.0 supports MoE experts, " + f"but your version is {get_te_version()}. Use local linear implementation instead." 
+ ) + expert_submodule = MLPSubmodules( + linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear + ) + else: + expert_submodule = mlp + + experts = ModuleSpec(module=expert_module, submodules=expert_submodule) + + # shared experts spec + shared_experts = ModuleSpec(module=SharedExpertMLP, params={"gate": False}, submodules=mlp) + + # MoE module spec + moe_module_spec = ModuleSpec( + module=MoELayer, submodules=MoESubmodules(experts=experts, shared_experts=shared_experts) + ) + return moe_module_spec diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index faefce4cf0..ea0b0b11e5 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -9,15 +9,13 @@ from megatron.core import parallel_state, tensor_parallel from megatron.core.transformer.mlp import MLPSubmodules from megatron.core.transformer.module import MegatronModule -from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP from megatron.core.transformer.moe.legacy_a2a_token_dispatcher import MoEAlltoAllSEQTokenDispatcher from megatron.core.transformer.moe.router import TopKRouter -from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.moe.token_dispatcher import ( MoEAllGatherTokenDispatcher, MoEAlltoAllTokenDispatcher, ) -from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.spec_utils import ModuleSpec, build_module from megatron.core.transformer.transformer_config import TransformerConfig @@ -89,20 +87,6 @@ def __init__( # Initialize router self.router = TopKRouter(config=self.config) - # Initialize experts - if self.config.moe_grouped_gemm: - if isinstance(self.submodules.experts, MLPSubmodules): - self.experts = TEGroupedMLP( - self.num_local_experts, self.config, self.submodules.experts - ) - else: - self.experts = GroupedMLP(self.num_local_experts, self.config) - else: - assert isinstance(self.submodules.experts, MLPSubmodules) - self.experts = SequentialMLP( - self.num_local_experts, self.config, self.submodules.experts - ) - # Initialize token dispatcher if config.moe_token_dispatcher_type == "allgather": self.token_dispatcher = MoEAllGatherTokenDispatcher( @@ -121,9 +105,12 @@ def __init__( f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}" ) + # Initialize experts + self.experts = build_module(self.submodules.experts, self.num_local_experts, self.config) + # Initialize shared experts if self.use_shared_expert: - self.shared_experts = SharedExpertMLP(self.config, self.submodules.shared_experts) + self.shared_experts = build_module(self.submodules.shared_experts, config=self.config) if self.shared_expert_overlap: self.token_dispatcher.set_shared_experts(self.shared_experts) diff --git a/megatron/core/transformer/moe/shared_experts.py b/megatron/core/transformer/moe/shared_experts.py index 1d4b2a628f..7d1eaef705 100644 --- a/megatron/core/transformer/moe/shared_experts.py +++ b/megatron/core/transformer/moe/shared_experts.py @@ -17,8 +17,7 @@ reduce_from_tensor_model_parallel_region, reduce_scatter_to_sequence_parallel_region, ) -from megatron.core.transformer.mlp import MLP -from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.mlp import MLP, MLPSubmodules from megatron.core.transformer.transformer_config import TransformerConfig from megatron.core.utils import is_torch_min_version, make_sharded_tensor_for_checkpoint @@ -32,15 +31,15 @@ class 
SharedExpertMLP(MLP): # The shared experts are scheduled into this stream to be overlapped with the dispatcher. stream = None - def __init__(self, config: TransformerConfig, spec: ModuleSpec): + def __init__(self, config: TransformerConfig, submodules: MLPSubmodules, gate: bool): config = deepcopy(config) assert config.add_bias_linear == False, "bias is not supported in the shared experts, " "please set '--disable-bias-linear' instead." config.ffn_hidden_size = config.moe_shared_expert_intermediate_size - super().__init__(config=config, submodules=spec.submodules) + super().__init__(config=config, submodules=submodules) - self.use_shared_expert_gate = spec.params.get("gate", False) + self.use_shared_expert_gate = gate if self.use_shared_expert_gate: # TODO: Add support for GPU initialization, which requires updating the golden values. self.gate_weight = torch.nn.Parameter(torch.empty((1, self.config.hidden_size))) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index cc56fd0978..855abbd59d 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -283,6 +283,10 @@ class TransformerConfig(ModelParallelConfig): GEMM feature introduced since CUTLASS 2.8 (https://github.com/fanshiqing/grouped_gemm). """ + moe_use_legacy_grouped_gemm: bool = False + """Use legacy GroupedMLP rather than TEGroupedMLP. + Note: The legacy one will be deprecated soon.""" + moe_aux_loss_coeff: float = 0 # 1e-2 would be a good start value for load balance loss. """Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended.""" diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 5d3f73f0f6..6e602add2c 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -2073,6 +2073,8 @@ def _add_moe_args(parser): help='Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k.') group.add_argument('--moe-grouped-gemm', action='store_true', help='When there are multiple experts per rank, launch multiple local GEMM kernels in multiple streams to improve the utilization and performance with GroupedLinear in TransformerEngine.') + group.add_argument('--moe-use-legacy-grouped-gemm', action='store_true', + help='Use legacy GroupedMLP rather than TEGroupedMLP. 
Note: The legacy one will be deprecated soon.') group.add_argument('--moe-aux-loss-coeff', type=float, default=0.0, help='Scaling coefficient for the aux loss: a starting value of 1e-2 is recommended.') group.add_argument('--moe-z-loss-coeff', type=float, default=None, diff --git a/pretrain_gpt.py b/pretrain_gpt.py index 71c4767b5d..4d5bf9a767 100644 --- a/pretrain_gpt.py +++ b/pretrain_gpt.py @@ -89,11 +89,11 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat if use_te: transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( args.num_experts, args.moe_grouped_gemm, - args.qk_layernorm, args.multi_latent_attention, args.fp8) + args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm) else: transformer_layer_spec = get_gpt_layer_local_spec( args.num_experts, args.moe_grouped_gemm, - args.qk_layernorm, args.multi_latent_attention) + args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm) build_model_context = nullcontext build_model_context_args = {} diff --git a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_te_8experts2parallel_dist_optimizer_dgx_a100_1N8G/golden_values_dev.json b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_te_8experts2parallel_dist_optimizer_dgx_a100_1N8G/golden_values_dev.json index a09763fbe5..6ba3300b83 100644 --- a/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_te_8experts2parallel_dist_optimizer_dgx_a100_1N8G/golden_values_dev.json +++ b/tests/functional_tests/test_cases/gpt/gpt3_mr_mcore_te_tp2_pp1_te_8experts2parallel_dist_optimizer_dgx_a100_1N8G/golden_values_dev.json @@ -5,15 +5,15 @@ "step_interval": 5, "values": [ 10.79987, - 10.85947, - 10.86478, - 10.80039, - 10.70971, - 10.63893, - 10.19526, - 10.31102, - 10.22247, - 9.91425 + 10.85907, + 10.86575, + 10.79932, + 10.70961, + 10.63871, + 10.19492, + 10.31016, + 10.22301, + 9.91473 ] }, "num-zeros": { @@ -21,16 +21,16 @@ "end_step": 50, "step_interval": 5, "values": [ - 30798.0, - 37696.0, - 37844.0, - 36275.0, - 33140.0, - 35137.0, - 30638.0, - 35309.0, - 36677.0, - 37604.0 + 30795.0, + 37447.0, + 37837.0, + 35948.0, + 33382.0, + 34774.0, + 30403.0, + 35340.0, + 36357.0, + 37792.0 ] }, "iteration-time": { @@ -38,16 +38,16 @@ "end_step": 50, "step_interval": 5, "values": [ - 12.59746, - 0.61072, - 0.61063, - 0.61049, - 0.61015, - 0.60932, - 0.61233, - 0.61024, - 0.61226, - 0.61621 + 10.77572, + 0.42536, + 0.42839, + 0.42977, + 0.42283, + 0.42333, + 0.43199, + 0.42998, + 0.43124, + 0.43207 ] } } \ No newline at end of file diff --git a/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py b/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py index e5e3ac98bd..54a60fc62a 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py +++ b/tests/unit_tests/dist_checkpointing/models/test_moe_experts.py @@ -15,7 +15,10 @@ FullyParallelLoadStrategyWrapper, FullyParallelSaveStrategyWrapper, ) -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP from megatron.core.transformer.transformer_config import TransformerConfig @@ -43,22 +46,25 @@ def initialize_expert_layer(seed, glu=True, 
expert_type='sequential', fp8=False, ) default_config_kwargs.update(**config_kwargs) transformer_config = TransformerConfig(**default_config_kwargs) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - num_experts=num_moe_experts, moe_grouped_gemm=(expert_type != 'sequential'), fp8=fp8 - ) if expert_type == 'grouped': model = GroupedMLP(num_local_experts, transformer_config) elif expert_type == 'te_grouped': + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=num_moe_experts, moe_grouped_gemm=True + ) model = TEGroupedMLP( num_local_experts, transformer_config, - transformer_layer_spec.submodules.mlp.submodules.experts, + transformer_layer_spec.submodules.mlp.submodules.experts.submodules, ) elif expert_type == 'sequential': + transformer_layer_spec = get_gpt_layer_local_spec( + num_experts=num_moe_experts, moe_grouped_gemm=False + ) model = SequentialMLP( num_local_experts, transformer_config, - transformer_layer_spec.submodules.mlp.submodules.experts, + transformer_layer_spec.submodules.mlp.submodules.experts.submodules, ) else: raise ValueError('expert_type can only be one of ["sequential", "grouped", "te_grouped"]') @@ -86,6 +92,7 @@ def setup_method(self, method): def teardown_method(self, method): Utils.destroy_model_parallel() + @pytest.mark.internal @pytest.mark.parametrize( "use_fpsl,src_tp_pp_ep_etp,dest_tp_pp_ep_etp,use_glu", [ @@ -200,6 +207,7 @@ def test_parallel_reconfiguration_e2e( diffs = diff(state_dict_A, state_dict_B) assert not any(map(bool, diffs)), diffs + @pytest.mark.internal @pytest.mark.parametrize( "src_tp_pp_exp,dest_tp_pp_exp,use_glu", [ diff --git a/tests/unit_tests/transformer/moe/test_grouped_mlp.py b/tests/unit_tests/transformer/moe/test_grouped_mlp.py index 2c27549325..c7c4935976 100644 --- a/tests/unit_tests/transformer/moe/test_grouped_mlp.py +++ b/tests/unit_tests/transformer/moe/test_grouped_mlp.py @@ -4,7 +4,10 @@ import torch import torch.nn.functional as F -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, +) from megatron.core.transformer.moe import grouped_gemm_util as gg from megatron.core.transformer.moe.experts import TEGroupedMLP from megatron.core.transformer.moe.moe_layer import MoELayer @@ -66,9 +69,7 @@ def setup_method(self, method, use_cpu_initialization=False, swiglu=True): ## Vanilla sequential GEMM # Set random seed for reproducability _set_random_seed(seed_=123, data_parallel_random_init=False) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - self.num_experts, moe_grouped_gemm=False - ) + transformer_layer_spec = get_gpt_layer_local_spec(self.num_experts, moe_grouped_gemm=False) self.sequential_mlp = MoELayer(tf_config, transformer_layer_spec.submodules.mlp.submodules) self.args = parse_args(ignore_unknown_args=True) @@ -254,9 +255,7 @@ def setup_method(self, method, use_cpu_initialization=False, swiglu=True): ## Vanilla sequential GEMM # Set random seed for reproducability _set_random_seed(seed_=123, data_parallel_random_init=False) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( - self.num_experts, moe_grouped_gemm=False - ) + transformer_layer_spec = get_gpt_layer_local_spec(self.num_experts, moe_grouped_gemm=False) self.sequential_mlp = MoELayer(tf_config, transformer_layer_spec.submodules.mlp.submodules) self.args = parse_args(ignore_unknown_args=True) diff 
--git a/tests/unit_tests/transformer/moe/test_moe_layer.py b/tests/unit_tests/transformer/moe/test_moe_layer.py index d303a3f3e9..59afadfd20 100644 --- a/tests/unit_tests/transformer/moe/test_moe_layer.py +++ b/tests/unit_tests/transformer/moe/test_moe_layer.py @@ -13,6 +13,7 @@ from megatron.core.transformer.moe.router import Router from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_config import TransformerConfig +from megatron.core.utils import is_te_min_version from megatron.training.initialize import _set_random_seed from tests.unit_tests.test_utilities import Utils @@ -21,6 +22,10 @@ class TestMoELayerInit: def setup_method(self, method): pass + @pytest.mark.skipif( + not is_te_min_version("1.7.0.dev0"), + reason="Expert with TE Linear is only supported in TE 1.7.0 and later.", + ) @pytest.mark.parametrize("moe_token_dispatcher_type", ["allgather", "alltoall"]) @pytest.mark.parametrize("num_moe_experts", [1, 2]) @pytest.mark.parametrize("grouped_gemm", [True, False]) @@ -49,7 +54,8 @@ def test_te_moe_layer(self, num_moe_experts, moe_token_dispatcher_type, grouped_ @pytest.mark.parametrize("moe_token_dispatcher_type", ["allgather", "alltoall"]) @pytest.mark.parametrize("num_moe_experts", [1, 2]) - def test_legacy_moe_layer(self, num_moe_experts, moe_token_dispatcher_type): + @pytest.mark.parametrize("grouped_gemm", [True, False]) + def test_legacy_moe_layer(self, num_moe_experts, moe_token_dispatcher_type, grouped_gemm): Utils.initialize_model_parallel(1, 1) _set_random_seed(seed_=123, data_parallel_random_init=False) num_moe_experts = 4 @@ -59,13 +65,15 @@ def test_legacy_moe_layer(self, num_moe_experts, moe_token_dispatcher_type): num_attention_heads=4, num_moe_experts=num_moe_experts, use_cpu_initialization=True, + moe_token_dispatcher_type=moe_token_dispatcher_type, moe_router_load_balancing_type="aux_loss", moe_router_topk=2, moe_aux_loss_coeff=0.01, + moe_grouped_gemm=grouped_gemm, add_bias_linear=False, ) transformer_layer_spec = get_gpt_layer_local_spec( - num_experts=num_moe_experts, moe_grouped_gemm=False + num_experts=num_moe_experts, moe_grouped_gemm=grouped_gemm ) moe_layer = MoELayer( self.transformer_config, transformer_layer_spec.submodules.mlp.submodules diff --git a/tests/unit_tests/transformer/moe/test_routers.py b/tests/unit_tests/transformer/moe/test_routers.py index 65796ff599..b146560090 100644 --- a/tests/unit_tests/transformer/moe/test_routers.py +++ b/tests/unit_tests/transformer/moe/test_routers.py @@ -3,7 +3,7 @@ import pytest import torch -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.moe.router import Router from megatron.core.transformer.transformer_config import TransformerConfig @@ -27,7 +27,7 @@ def setup_method(self, method): moe_router_topk=2, moe_aux_loss_coeff=0, ) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + transformer_layer_spec = get_gpt_layer_local_spec( num_experts=num_moe_experts, moe_grouped_gemm=False ) self.sequential_mlp = MoELayer( diff --git a/tests/unit_tests/transformer/moe/test_sequential_mlp.py b/tests/unit_tests/transformer/moe/test_sequential_mlp.py index 2a005555d5..dc350e092b 100644 --- a/tests/unit_tests/transformer/moe/test_sequential_mlp.py +++ b/tests/unit_tests/transformer/moe/test_sequential_mlp.py @@ -5,7 +5,7 
@@ import torch from megatron.core.extensions.transformer_engine import TEColumnParallelLinear, TERowParallelLinear -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.mlp import MLPSubmodules @@ -35,7 +35,7 @@ def setup_method(self, method): moe_router_load_balancing_type="sinkhorn", moe_router_topk=1, ) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + transformer_layer_spec = get_gpt_layer_local_spec( num_experts=num_moe_experts, moe_grouped_gemm=False ) self.sequential_mlp = MoELayer( diff --git a/tests/unit_tests/transformer/moe/test_shared_experts.py b/tests/unit_tests/transformer/moe/test_shared_experts.py index 0cacf30836..f721c48293 100644 --- a/tests/unit_tests/transformer/moe/test_shared_experts.py +++ b/tests/unit_tests/transformer/moe/test_shared_experts.py @@ -3,7 +3,7 @@ import pytest import torch -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.transformer_config import TransformerConfig @@ -39,7 +39,7 @@ def test_gpu_forward(self): moe_router_topk=1, add_bias_linear=False, ) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + transformer_layer_spec = get_gpt_layer_local_spec( num_experts=num_moe_experts, moe_grouped_gemm=False ) self.moe_layer = MoELayer( @@ -98,7 +98,7 @@ def test_gpu_forward(self): moe_router_topk=1, add_bias_linear=False, ) - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + transformer_layer_spec = get_gpt_layer_local_spec( num_experts=num_moe_experts, moe_grouped_gemm=False ) self.moe_layer = MoELayer( diff --git a/tests/unit_tests/transformer/moe/test_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_token_dispatcher.py index 895cb291aa..f8463042b7 100644 --- a/tests/unit_tests/transformer/moe/test_token_dispatcher.py +++ b/tests/unit_tests/transformer/moe/test_token_dispatcher.py @@ -6,7 +6,7 @@ import torch from megatron.core import parallel_state -from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.moe.moe_utils import permute, unpermute from megatron.core.transformer.transformer_config import TransformerConfig @@ -75,7 +75,7 @@ def __init__( self.moe_layer = self.new_moe_layer() def new_moe_layer(self): - transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + transformer_layer_spec = get_gpt_layer_local_spec( num_experts=self.config.num_moe_experts, moe_grouped_gemm=self.config.moe_grouped_gemm ) moe_layer = MoELayer( diff --git a/tests/unit_tests/transformer/moe/test_upcycling.py b/tests/unit_tests/transformer/moe/test_upcycling.py index fc53d57ad1..5b5610eb33 100644 --- a/tests/unit_tests/transformer/moe/test_upcycling.py +++ b/tests/unit_tests/transformer/moe/test_upcycling.py @@ -7,9 +7,7 @@ from megatron.core import mpu from 
megatron.core.enums import ModelType -from megatron.core.models.gpt.gpt_layer_specs import ( - get_gpt_layer_with_transformer_engine_spec as gpt_te_spec, -) +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec from megatron.core.models.gpt.gpt_model import GPTModel from megatron.core.num_microbatches_calculator import destroy_num_microbatches_calculator from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed @@ -32,7 +30,9 @@ _SEED = 42 -def model_provider(pre_process=True, post_process=True, layer_spec_fn=gpt_te_spec, **config_kwargs): +def model_provider( + pre_process=True, post_process=True, layer_spec_fn=get_gpt_layer_local_spec, **config_kwargs +): model_parallel_cuda_manual_seed(_SEED) args = get_args() @@ -40,7 +40,7 @@ def model_provider(pre_process=True, post_process=True, layer_spec_fn=gpt_te_spe model = GPTModel( config=config, - transformer_layer_spec=gpt_te_spec( + transformer_layer_spec=layer_spec_fn( args.num_experts, args.moe_grouped_gemm, args.qk_layernorm ), vocab_size=args.vocal_size,
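
Reviewer note: a minimal usage sketch of the refactored spec entry points (assumes Megatron-Core with this change applied and TransformerEngine available; the expert count and flag values below are illustrative, not taken from this MR):

    from megatron.core.models.gpt.gpt_layer_specs import (
        get_gpt_layer_local_spec,
        get_gpt_layer_with_transformer_engine_spec,
    )
    from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec

    # TE-backed layer spec: the deprecated fp8 argument is simply omitted.
    te_layer_spec = get_gpt_layer_with_transformer_engine_spec(
        num_experts=8,
        moe_grouped_gemm=True,
        moe_use_legacy_grouped_gemm=False,  # True forces the legacy GroupedMLP
    )

    # Local (Megatron-Core only) layer spec accepts the same keyword.
    local_layer_spec = get_gpt_layer_local_spec(num_experts=8, moe_grouped_gemm=False)

    # The MoE spec can also be built directly; the experts and shared-experts
    # submodules of MoELayer now live inside this spec and are instantiated
    # via build_module in MoELayer.__init__.
    moe_spec = get_moe_module_spec(
        use_te=True, num_experts=8, moe_grouped_gemm=True, moe_use_legacy_grouped_gemm=False
    )
    print(moe_spec.submodules.experts.module)         # TEGroupedMLP (GroupedMLP if legacy)
    print(moe_spec.submodules.shared_experts.module)  # SharedExpertMLP

Passing fp8 to these helpers still works for now but only triggers the deprecation warning; the expert implementation is selected solely by moe_grouped_gemm and moe_use_legacy_grouped_gemm.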