Merge branch 'hongxiaob/shared_expert' into 'main'
MoE Shared Expert support

Closes #134

See merge request ADLR/megatron-lm!1699
ko3n1g committed Sep 11, 2024
2 parents fe1640a + 1fa9464 commit fec11a7
Showing 14 changed files with 534 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .gitlab/stages/01.tests.yml
@@ -90,7 +90,7 @@ unit_tests:
parallel:
matrix:
- TAG: latest
- TAG: 655a663df2e9c3d8991e676e0163a5822da249a7
- TAG: 0bb840767d0643c2d0df7192d754ec7db3a18412
tags: [8xL40S]
variables:
GIT_STRATEGY: clone
21 changes: 16 additions & 5 deletions megatron/core/models/gpt/gpt_layer_specs.py
@@ -9,7 +9,8 @@
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.identity_op import IdentityOp
from megatron.core.transformer.mlp import MLP, MLPSubmodules
from megatron.core.transformer.moe.moe_layer import MoELayer
from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules

@@ -167,9 +168,19 @@ def _get_mlp_module_spec(

return ModuleSpec(
module=MoELayer,
submodules=(
MLPSubmodules(linear_fc1=linear_fc1, linear_fc2=linear_fc2)
if not moe_grouped_gemm or use_te_grouped_gemm
else None
submodules=MoESubmodules(
experts=(
MLPSubmodules(linear_fc1=linear_fc1, linear_fc2=linear_fc2)
if not moe_grouped_gemm or use_te_grouped_gemm
else None
),
shared_experts=ModuleSpec(
module=SharedExpertMLP,
params={"gate": False},
submodules=MLPSubmodules(
linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
),
),
),
)
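
For orientation, the net effect of this change is that the MoE layer spec no longer takes `MLPSubmodules` directly; it wraps the routed-expert submodules and an optional shared-expert spec in `MoESubmodules`. Below is a minimal hand-built sketch of the spec this branch produces for the non-TE, non-grouped-GEMM path (import paths assumed from the diff above):

```python
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.transformer.mlp import MLPSubmodules
from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.spec_utils import ModuleSpec

moe_spec = ModuleSpec(
    module=MoELayer,
    submodules=MoESubmodules(
        # Routed experts: the same MLPSubmodules as before, now nested one level deeper.
        experts=MLPSubmodules(linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear),
        # Shared experts: a gate-less SharedExpertMLP built from the same parallel linears.
        shared_experts=ModuleSpec(
            module=SharedExpertMLP,
            params={"gate": False},
            submodules=MLPSubmodules(
                linear_fc1=ColumnParallelLinear, linear_fc2=RowParallelLinear
            ),
        ),
    ),
)
```
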
15 changes: 14 additions & 1 deletion megatron/core/transformer/moe/README.md
@@ -61,6 +61,8 @@ Megatron-Core offers rich parallelism mappings, combining Expert Parallelism wit
| --moe-token-drop-policy | The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped. |
| --moe-layer-recompute | Enable activation checkpointing for moe_layer; use it when memory is not sufficient. |
| --moe-extended-tp | (Experimental) Alternative parallelization strategy for expert parallelism. Instead of distributing experts across *expert_model_parallel_size*, each expert is sharded along the extended tensor parallel domain (tensor_model_parallel_size * expert_model_parallel_size). It avoids the load-balancing problem in MoE training. Only available with `--moe-token-dispatcher-type allgather`. |
| --moe-shared-expert-intermediate-size | Set the total FFN hidden size of the shared experts. It should be equal to `num_shared_experts * ffn_size_of_each_shared_expert` if there are multiple shared experts. None means no shared expert. |
| --moe-shared-expert-overlap | (Experimental, may change) If set, the communications/computations in the shared experts overlap with those in the dispatcher (requires the `alltoall` dispatcher). Otherwise, the shared experts run after the routed experts. |
| --moe-use-upcycling | Load the dense model checkpoint, convert it into an MoE model at runtime and start training. The converted model will be saved to the path specified by `--save` before training begins. Upcycling is implemented on top of distributed checkpointing, so it supports parallel modes different from the dense model. |


@@ -118,8 +120,19 @@ Usage
- `--use-dist-ckpt` The main argument, it will attempt to save and load using distributed checkpointing.
- `--auto-detect-ckpt-format` With this, it can load both distributed checkpointing and legacy checkpointing.

### Upcycling
### Shared Experts
MCore v0.9 introduced the shared expert feature. We can enable this feature by setting `--moe-shared-expert-intermediate-size` to a suitable value.

The parallelism patterns of the shared experts follow the settings of the dense part, i.e., the attention module. The shared experts are not distributed but replicated across EP ranks.

We also provide an experimental feature that overlaps the communications and computations in the shared experts with those in the dispatcher.
It is enabled by setting `--moe-shared-expert-overlap` and using the `alltoall` dispatcher.
The overlapping relies on the environment setting `CUDA_DEVICE_MAX_CONNECTIONS=1`.
The `AllGather` and `ReduceScatter` communications in the shared experts are overlapped with `permute`/`unpermute` in the dispatcher.
The `MLP` computation part of the shared experts is overlapped with the `AlltoAll` communications in the dispatcher.
Both the forward and the backward pass can overlap, but overlapping in the backward pass requires PyTorch `>= 2.2.0`.
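
As an illustration, here is a minimal sketch of the config fields these flags map onto; the field names are taken from how `MoELayer` reads them in this merge request, while the sizes and layer counts are placeholders rather than a tuned configuration:

```python
import os

from megatron.core.transformer.transformer_config import TransformerConfig

# The overlap path relies on this environment setting (see note above).
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

config = TransformerConfig(
    num_layers=4,
    hidden_size=1024,
    num_attention_heads=8,
    num_moe_experts=8,
    moe_token_dispatcher_type="alltoall",      # overlap requires the alltoall dispatcher
    moe_shared_expert_intermediate_size=4096,  # e.g. 2 shared experts of ffn size 2048 each
    moe_shared_expert_overlap=True,            # experimental, may change
)
```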

### Upcycling
Use `--moe-use-upcycling` to enable the upcycling feature, which will load the dense model from the directory specified by `--load`, convert it into an MoE model at runtime and start training. The converted model will be saved to the path specified by `--save` before training begins. Upcycling is implemented on top of distributed checkpointing, so it supports parallel modes different from the dense model.

The MoE model structure is defined through script arguments. All MoE-related arguments (such as `--num-experts`) can be customized; however, other model structure arguments must be consistent with those of the dense model.
46 changes: 41 additions & 5 deletions megatron/core/transformer/moe/moe_layer.py
@@ -1,6 +1,8 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Union

import torch

@@ -10,13 +12,23 @@
from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP
from megatron.core.transformer.moe.legacy_a2a_token_dispatcher import MoEAlltoAllSEQTokenDispatcher
from megatron.core.transformer.moe.router import TopKRouter
from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
from megatron.core.transformer.moe.token_dispatcher import (
MoEAllGatherTokenDispatcher,
MoEAlltoAllTokenDispatcher,
)
from megatron.core.transformer.spec_utils import ModuleSpec
from megatron.core.transformer.transformer_config import TransformerConfig


@dataclass
class MoESubmodules:
"""MoE Layer Submodule spec"""

experts: Union[ModuleSpec, type] = None
shared_experts: Union[ModuleSpec, type] = None


class BaseMoELayer(MegatronModule, ABC):
"""Base class for a mixture of experts layer.
@@ -40,12 +52,16 @@ def __init__(self, config: TransformerConfig, layer_number: int = None):
parallel_state.get_expert_model_parallel_rank() * self.num_local_experts
)

self.use_shared_expert = self.config.moe_shared_expert_intermediate_size is not None
self.shared_expert_overlap = self.config.moe_shared_expert_overlap

self.local_expert_indices = [
local_expert_indices_offset + i for i in range(self.num_local_experts)
]
assert all(map(lambda x: x < self.config.num_moe_experts, self.local_expert_indices))
self.router = None
self.experts = None
self.shared_experts = None
self.token_dispatcher = None
self.layer_number = layer_number

@@ -72,15 +88,26 @@ def __init__(
):
self.submodules = submodules
super(MoELayer, self).__init__(config=config, layer_number=layer_number)
self.moe_layer_recompute = config.moe_layer_recompute

# Initialize router
self.router = TopKRouter(config=self.config)

# Initialize experts
if self.config.moe_grouped_gemm:
if isinstance(self.submodules, MLPSubmodules):
self.experts = TEGroupedMLP(self.num_local_experts, self.config, self.submodules)
if isinstance(self.submodules.experts, MLPSubmodules):
self.experts = TEGroupedMLP(
self.num_local_experts, self.config, self.submodules.experts
)
else:
self.experts = GroupedMLP(self.num_local_experts, self.config)
else:
assert isinstance(self.submodules, MLPSubmodules)
self.experts = SequentialMLP(self.num_local_experts, self.config, self.submodules)
assert isinstance(self.submodules.experts, MLPSubmodules)
self.experts = SequentialMLP(
self.num_local_experts, self.config, self.submodules.experts
)

# Initialize token dispatcher
if config.moe_token_dispatcher_type == "allgather":
self.token_dispatcher = MoEAllGatherTokenDispatcher(
self.num_local_experts, self.local_expert_indices, config=self.config
@@ -97,7 +124,12 @@ def __init__(
raise ValueError(
f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}"
)
self.moe_layer_recompute = config.moe_layer_recompute

# Initialize shared experts
if self.use_shared_expert:
self.shared_experts = SharedExpertMLP(self.config, self.submodules.shared_experts)
if self.shared_expert_overlap:
self.token_dispatcher.set_shared_experts(self.shared_experts)

def forward(self, hidden_states: torch.Tensor):
if (
Expand All @@ -118,6 +150,10 @@ def custom_forward(hidden_states):
)
expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert)
output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias)
if self.use_shared_expert and not self.shared_expert_overlap:
# if shared_expert_overlap is True, the expert calculation happens in
# the token_dispatcher to overlap communications and computations
output += self.shared_experts(hidden_states)
return output, mlp_bias

if self.moe_layer_recompute:
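
The behavioral core of the `forward` change is the last lines of `custom_forward`: when overlap is disabled, the shared experts consume the original `hidden_states` (not the dispatched tokens) and their output is summed onto the routed-expert output. A toy, standalone sketch of that combination, with plain `torch` linears standing in for the routed experts and `SharedExpertMLP` (the real modules need an initialized Megatron parallel state):

```python
import torch
import torch.nn as nn

hidden_size, seq_len, batch = 16, 4, 2
routed_path = nn.Linear(hidden_size, hidden_size)     # stand-in for dispatch + experts + unpermute
shared_experts = nn.Linear(hidden_size, hidden_size)  # stand-in for SharedExpertMLP

hidden_states = torch.randn(seq_len, batch, hidden_size)
output = routed_path(hidden_states)
# Shared experts see the pre-dispatch input and are simply added on top.
output = output + shared_experts(hidden_states)
print(output.shape)  # torch.Size([4, 2, 16])
```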