
Commit

ADLR/megatron-lm!2101 - Refactor MoE specs: move all submodules of MoELayer into the spec

Co-authored-by: Zijie Yan <[email protected]>
2 people authored and ko3n1g committed Dec 9, 2024
1 parent 9665f2d commit aa2a45d
Showing 16 changed files with 228 additions and 147 deletions.
121 changes: 57 additions & 64 deletions megatron/core/models/gpt/gpt_layer_specs.py
@@ -1,16 +1,16 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import warnings
from typing import Optional

from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec
from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules
from megatron.core.transformer.dot_product_attention import DotProductAttention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.multi_latent_attention import (
MLASelfAttention,
MLASelfAttentionSubmodules,
@@ -26,12 +26,10 @@

try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelGroupedLinear,
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
TERowParallelGroupedLinear,
TERowParallelLinear,
)

@@ -47,8 +45,6 @@
HAVE_APEX = True
LNImpl = FusedLayerNorm
except ImportError:
import warnings

from megatron.core.transformer.torch_norm import WrappedTorchNorm

warnings.warn('Apex is not installed. Falling back to Torch Norm')
@@ -60,7 +56,8 @@ def get_gpt_layer_with_transformer_engine_spec(
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
@@ -69,13 +66,24 @@ def get_gpt_layer_with_transformer_engine_spec(
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Flag to decide the linear layer spec for MoE. Defaults to None.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
Returns:
ModuleSpec: Module specification with TE modules
"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_with_transformer_engine_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)

mlp = _get_mlp_module_spec(
use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=fp8
use_te=True,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)

if multi_latent_attention:
@@ -138,6 +146,8 @@ def get_gpt_layer_local_spec(
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
multi_latent_attention: Optional[bool] = False,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Use this spec for an implementation using only modules in Megatron-Core.
@@ -146,13 +156,24 @@ def get_gpt_layer_local_spec(
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Deprecated. For temporary Nemo compatibility.
moe_use_legacy_grouped_gemm (bool, optional): Force use the legacy GroupedMLP.
Defaults to False.
Returns:
ModuleSpec: Module specification with Megatron-Core modules
"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "get_gpt_layer_local_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)

mlp = _get_mlp_module_spec(
use_te=False, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
use_te=False,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)

if multi_latent_attention:
@@ -213,63 +234,33 @@ def _get_mlp_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
fp8: Optional[str] = None,
fp8: Optional[str] = None, # pylint: disable=unused-arguments
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Helper function to get module spec for MLP"""
if num_experts is not None:
moe_spec = _get_moe_module_spec(
use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=fp8
"""Helper function to get module spec for MLP/MoE"""
if fp8 is not None:
warnings.warn(
'The fp8 argument in "_get_mlp_module_spec" has been deprecated'
' and will be removed soon. Please update your code accordingly.'
)
return moe_spec

return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
)


def _get_moe_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
fp8: Optional[str] = None,
) -> ModuleSpec:
"""Helper function to get module spec for MoE"""
if num_experts is None:
return None
if use_te and moe_grouped_gemm:
linear_fc1 = TEColumnParallelGroupedLinear
linear_fc2 = TERowParallelGroupedLinear
elif use_te and fp8:
linear_fc1 = TEColumnParallelLinear
linear_fc2 = TERowParallelLinear
else:
linear_fc1 = ColumnParallelLinear
linear_fc2 = RowParallelLinear

use_te_grouped_gemm = use_te and TEColumnParallelGroupedLinear is not None

return ModuleSpec(
module=MoELayer,
submodules=MoESubmodules(
experts=(
MLPSubmodules(linear_fc1=linear_fc1, linear_fc2=linear_fc2)
if not moe_grouped_gemm or use_te_grouped_gemm
else None
),
shared_experts=ModuleSpec(
module=SharedExpertMLP,
params={"gate": False},
submodules=MLPSubmodules(
linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
# Dense MLP w/ or w/o TE modules.
return ModuleSpec(
module=MLP,
submodules=MLPSubmodules(
linear_fc1=TELayerNormColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
),
)
)
else:
# Mixture of experts with modules in megatron core.
return get_moe_module_spec(
use_te=use_te,
num_experts=num_experts,
moe_grouped_gemm=moe_grouped_gemm,
moe_use_legacy_grouped_gemm=moe_use_legacy_grouped_gemm,
)


def get_gpt_decoder_block_spec(
@@ -288,14 +279,15 @@ def get_gpt_decoder_block_spec(
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
fp8=config.fp8,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
if use_transformer_engine
else get_gpt_layer_local_spec(
num_experts=None,
moe_grouped_gemm=False,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
)
moe_layer_spec = (
@@ -304,14 +296,15 @@
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
fp8=config.fp8,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
if use_transformer_engine
else get_gpt_layer_local_spec(
num_experts=config.num_moe_experts,
moe_grouped_gemm=config.moe_grouped_gemm,
qk_layernorm=config.qk_layernorm,
multi_latent_attention=config.multi_latent_attention,
moe_use_legacy_grouped_gemm=config.moe_use_legacy_grouped_gemm,
)
)

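A minimal usage sketch (not part of this commit) of the refactored entry point. The argument values below are illustrative assumptions; the parameter names match the signature shown in the diff above, and the deprecated fp8 argument is simply omitted.

```python
# Illustrative caller of the refactored spec helper; the values are assumptions.
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_spec,
)

# Expert construction is now controlled by moe_grouped_gemm and
# moe_use_legacy_grouped_gemm and resolved inside get_moe_module_spec();
# the deprecated fp8 argument no longer needs to be passed.
layer_spec = get_gpt_layer_with_transformer_engine_spec(
    num_experts=8,                      # None would yield a dense-MLP layer spec
    moe_grouped_gemm=True,              # request grouped-GEMM experts
    qk_layernorm=False,
    multi_latent_attention=False,
    moe_use_legacy_grouped_gemm=False,  # keep TEGroupedMLP when TE >= 1.7.0 is available
)
```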
81 changes: 81 additions & 0 deletions megatron/core/models/gpt/moe_module_specs.py
@@ -0,0 +1,81 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import warnings
from typing import Optional

from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP
from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.utils import get_te_version, is_te_min_version

try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelGroupedLinear,
TEColumnParallelLinear,
TERowParallelGroupedLinear,
TERowParallelLinear,
)

HAVE_TE = True
except ImportError:
HAVE_TE = False


def get_moe_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
moe_use_legacy_grouped_gemm: Optional[bool] = False,
) -> ModuleSpec:
"""Helper function to get module spec for MoE"""
assert num_experts is not None

mlp = MLPSubmodules(
linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
)

# experts spec
if moe_grouped_gemm:
## use GroupedMLP
if use_te and TEColumnParallelGroupedLinear is not None and not moe_use_legacy_grouped_gemm:
## use TEGroupedLinear
expert_module = TEGroupedMLP
expert_submodule = MLPSubmodules(
linear_fc1=TEColumnParallelGroupedLinear, linear_fc2=TERowParallelGroupedLinear
)
else:
## use legacy GroupedMLP
expert_module = GroupedMLP
expert_submodule = None
warnings.warn(
'The legacy GroupedMLP will be deprecated in Megatron-Core v0.12.0. '
'Please update the TransformerEngine to version>=1.7.0 and use TEGroupedMLP.'
)
else:
## use SequentialMLP
expert_module = SequentialMLP
if use_te and not is_te_min_version("1.7.0.dev0"):
warnings.warn(
"Only transformer-engine>=1.7.0 supports MoE experts, "
f"but your version is {get_te_version()}. Use local linear implementation instead."
)
expert_submodule = MLPSubmodules(
linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
)
else:
expert_submodule = mlp

experts = ModuleSpec(module=expert_module, submodules=expert_submodule)

# shared experts spec
shared_experts = ModuleSpec(module=SharedExpertMLP, params={"gate": False}, submodules=mlp)

# MoE module spec
moe_module_spec = ModuleSpec(
module=MoELayer, submodules=MoESubmodules(experts=experts, shared_experts=shared_experts)
)
return moe_module_spec
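
For reference, a short sketch of calling the new helper directly (not part of the commit; the expert count below is an arbitrary example):

```python
# Illustrative direct use of the new MoE spec helper; 8 experts is an arbitrary choice.
from megatron.core.models.gpt.moe_module_specs import get_moe_module_spec

moe_spec = get_moe_module_spec(
    use_te=True,
    num_experts=8,
    moe_grouped_gemm=True,
    moe_use_legacy_grouped_gemm=False,
)
# moe_spec.module is MoELayer, and moe_spec.submodules carries the experts and
# shared_experts ModuleSpecs that MoELayer now instantiates via build_module().
```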
23 changes: 5 additions & 18 deletions megatron/core/transformer/moe/moe_layer.py
@@ -9,15 +9,13 @@
from megatron.core import parallel_state, tensor_parallel
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.module import MegatronModule
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP
from megatron.core.transformer.moe.legacy_a2a_token_dispatcher import MoEAlltoAllSEQTokenDispatcher
from megatron.core.transformer.moe.router import TopKRouter
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.moe.token_dispatcher import (
MoEAllGatherTokenDispatcher,
MoEAlltoAllTokenDispatcher,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig


@@ -89,20 +87,6 @@ def __init__(
# Initialize router
self.router = TopKRouter(config=self.config)

# Initialize experts
if self.config.moe_grouped_gemm:
if isinstance(self.submodules.experts, MLPSubmodules):
self.experts = TEGroupedMLP(
self.num_local_experts, self.config, self.submodules.experts
)
else:
self.experts = GroupedMLP(self.num_local_experts, self.config)
else:
assert isinstance(self.submodules.experts, MLPSubmodules)
self.experts = SequentialMLP(
self.num_local_experts, self.config, self.submodules.experts
)

# Initialize token dispatcher
if config.moe_token_dispatcher_type == "allgather":
self.token_dispatcher = MoEAllGatherTokenDispatcher(
@@ -121,9 +105,12 @@
f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}"
)

# Initialize experts
self.experts = build_module(self.submodules.experts, self.num_local_experts, self.config)

# Initialize shared experts
if self.use_shared_expert:
self.shared_experts = SharedExpertMLP(self.config, self.submodules.shared_experts)
self.shared_experts = build_module(self.submodules.shared_experts, config=self.config)
if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(self.shared_experts)

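To illustrate what the switch to build_module buys, here is a simplified toy version of the spec-driven construction pattern. It is an approximation written for this note, not the actual megatron.core.transformer.spec_utils implementation:

```python
# Toy approximation of the ModuleSpec/build_module pattern; the real
# build_module in megatron.core.transformer.spec_utils handles more cases.
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Optional


@dataclass
class ToySpec:
    module: Callable                     # class (or factory) to instantiate
    params: Dict[str, Any] = field(default_factory=dict)
    submodules: Optional[Any] = None     # nested specs / submodule dataclass


def toy_build_module(spec: ToySpec, *args, **kwargs):
    """Instantiate spec.module, forwarding submodules and spec params as kwargs."""
    if spec.submodules is not None:
        kwargs.setdefault("submodules", spec.submodules)
    return spec.module(*args, **spec.params, **kwargs)
```

With this pattern, MoELayer no longer hard-codes GroupedMLP/TEGroupedMLP/SequentialMLP: it calls build_module(self.submodules.experts, self.num_local_experts, self.config), and the choice of expert class lives entirely in the spec produced by get_moe_module_spec.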
9 changes: 4 additions & 5 deletions megatron/core/transformer/moe/shared_experts.py
@@ -17,8 +17,7 @@
reduce_from_tensor_model_parallel_region,
reduce_scatter_to_sequence_parallel_region,
)
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.utils import is_torch_min_version, make_sharded_tensor_for_checkpoint

@@ -32,15 +31,15 @@ class SharedExpertMLP(MLP):
# The shared experts are scheduled into this stream to be overlapped with the dispatcher.
stream = None

def __init__(self, config: TransformerConfig, spec: ModuleSpec):
def __init__(self, config: TransformerConfig, submodules: MLPSubmodules, gate: bool):
config = deepcopy(config)
assert config.add_bias_linear == False, "bias is not supported in the shared experts, "
"please set '--disable-bias-linear' instead."

config.ffn_hidden_size = config.moe_shared_expert_intermediate_size
super().__init__(config=config, submodules=spec.submodules)
super().__init__(config=config, submodules=submodules)

self.use_shared_expert_gate = spec.params.get("gate", False)
self.use_shared_expert_gate = gate
if self.use_shared_expert_gate:
# TODO: Add support for GPU initialization, which requires updating the golden values.
self.gate_weight = torch.nn.Parameter(torch.empty((1, self.config.hidden_size)))
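The constructor change above pairs with the spec built in moe_module_specs.py: the gate flag now travels through ModuleSpec.params and arrives as an ordinary keyword argument. Below is a hedged sketch of building the shared experts under the new signature; the helper function and the pre-built TransformerConfig passed to it are assumptions for illustration, while the import paths are the ones used in the diff.

```python
# Sketch of constructing SharedExpertMLP under the new signature. The helper
# name and the pre-built TransformerConfig passed to it are assumptions.
from megatron.core.extensions.transformer_engine import (
    TEColumnParallelLinear,
    TERowParallelLinear,
)
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
from megatron.core.transformer.transformer_config import TransformerConfig


def build_shared_experts(config: TransformerConfig):
    """config is assumed to have moe_shared_expert_intermediate_size set."""
    spec = ModuleSpec(
        module=SharedExpertMLP,
        params={"gate": False},  # becomes the `gate` constructor argument
        submodules=MLPSubmodules(
            linear_fc1=TEColumnParallelLinear,
            linear_fc2=TERowParallelLinear,
        ),
    )
    # build_module forwards params and submodules, so this resolves to roughly
    # SharedExpertMLP(config=config, submodules=..., gate=False).
    return build_module(spec, config=config)
```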