Merge branch 'denliu/fp8_moe' into 'main'
FP8 support for MoE with conservative recipe

Closes #43

See merge request ADLR/megatron-lm!1089
ko3n1g committed Sep 6, 2024
2 parents a2b6ee4 + 8f331e8 commit cc16182
Showing 6 changed files with 262 additions and 42 deletions.
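The commit threads a new fp8 argument from the training script through get_gpt_layer_with_transformer_engine_spec and _get_mlp_module_spec, so that when FP8 is enabled the MoE expert MLPs are built from Transformer Engine linear layers rather than the local Megatron ones. Below is a minimal sketch, not part of the commit, of requesting the FP8-compatible MoE spec; the value "hybrid" is an assumption about a typical FP8 recipe string, and the accepted values come from Megatron-LM's argument parser rather than this diff.

```python
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_with_transformer_engine_spec,
)

# Build a TE-based layer spec for an MoE model with FP8-capable expert MLPs.
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
    num_experts=8,           # hypothetical MoE size
    moe_grouped_gemm=False,  # grouped GEMM has no FP8 path in this commit
    qk_layernorm=False,
    fp8="hybrid",            # any non-None value routes experts through TE linear layers
)
```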
49 changes: 41 additions & 8 deletions megatron/core/extensions/transformer_engine.py
@@ -17,6 +17,7 @@
from megatron.core.parallel_state import (
get_context_parallel_global_ranks,
get_context_parallel_group,
get_tensor_and_expert_parallel_world_size,
get_tensor_model_parallel_group,
)
from megatron.core.tensor_parallel import get_cuda_rng_tracker, get_expert_parallel_rng_tracker_name
@@ -111,6 +112,7 @@ def __init__(
skip_bias_add: bool,
skip_weight_param_allocation: bool,
tp_comm_buffer_name: str = None,
is_expert: bool = False,
):
self.config = config

@@ -143,24 +145,56 @@ def __init__(
if hasattr(self.config, "tp_comm_overlap_rs")
else self.config.tp_comm_split_rs or self.config.tp_comm_atomic_rs
)
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_overlap_ag"] = False
extra_kwargs["ub_overlap_rs"] = False
else:
extra_kwargs["ub_split_ag"] = self.config.tp_comm_split_ag
extra_kwargs["ub_atomic_gemm_ag"] = self.config.tp_comm_atomic_ag
extra_kwargs["ub_split_rs"] = self.config.tp_comm_split_rs
extra_kwargs["ub_atomic_gemm_rs"] = self.config.tp_comm_atomic_rs
# Disable ub overlap for experts.
if is_expert:
extra_kwargs["ub_split_ag"] = False
extra_kwargs["ub_atomic_gemm_ag"] = False
extra_kwargs["ub_split_rs"] = False
extra_kwargs["ub_atomic_gemm_rs"] = False
if _te_version > packaging.version.Version("1.0.0"):
assert (
tp_comm_buffer_name is not None
), "Buffer name should be set to configure communication overlap settings"
extra_kwargs["ub_name"] = tp_comm_buffer_name

self.expert_parallel = self.config.expert_model_parallel_size > 1
if is_expert and self.expert_parallel:
rng_tracker_name = get_expert_parallel_rng_tracker_name()
else:
rng_tracker_name = None
if _te_version >= packaging.version.Version("1.7.0.dev"):
extra_kwargs["rng_tracker_name"] = rng_tracker_name

# Disable communications in TE when using SP or EP by making TE agnostic of model parallel.
tp_size = self.config.tensor_model_parallel_size
tp_group = get_tensor_model_parallel_group(check_initialized=False)
if is_expert and (self.config.sequence_parallel or self.expert_parallel):
if self.config.moe_extended_tp:
tp_size = get_tensor_and_expert_parallel_world_size()
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
elif parallel_mode == "row":
input_size = divide(input_size, tp_size)
parallel_mode = None
tp_size = 1
tp_group = None

super().__init__(
in_features=input_size,
out_features=output_size,
sequence_parallel=self.config.sequence_parallel,
fuse_wgrad_accumulation=self.config.gradient_accumulation_fusion,
tp_group=get_tensor_model_parallel_group(check_initialized=False),
tp_size=self.config.tensor_model_parallel_size,
tp_group=tp_group,
tp_size=tp_size,
get_rng_state_tracker=(
get_cuda_rng_tracker if get_cuda_rng_tracker().is_initialized() else None
),
@@ -171,6 +205,9 @@ def __init__(
**extra_kwargs,
)

for param in self.parameters():
setattr(param, 'allreduce', not (is_expert and self.expert_parallel))

def forward(self, x):
"""Forward."""
_is_first_microbatch = (
@@ -337,9 +374,6 @@ def __init__(
if gather_output:
raise ValueError('Transformer Engine linear layers do not support gather_output = True')

if is_expert:
raise ValueError('Transformer Engine linear layers do not yet support MoE')

super().__init__(
input_size=input_size,
output_size=output_size,
@@ -348,6 +382,7 @@
init_method=condition_init_method(config, init_method),
bias=bias,
skip_bias_add=skip_bias_add,
is_expert=is_expert,
skip_weight_param_allocation=skip_weight_param_allocation,
tp_comm_buffer_name=tp_comm_buffer_name,
)
@@ -384,9 +419,6 @@ def __init__(
"Transformer Engine linear layers do not support input_is_parallel = False"
)

if is_expert:
raise ValueError('Transformer Engine linear layers do not yet support MoE')

super().__init__(
input_size=input_size,
output_size=output_size,
@@ -396,6 +428,7 @@
bias=bias,
skip_bias_add=skip_bias_add,
skip_weight_param_allocation=False, # We don't currently use this for row parallel layers # pylint: disable=line-too-long
is_expert=is_expert,
tp_comm_buffer_name=tp_comm_buffer_name,
)

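For expert layers under sequence or expert parallelism, the changes above sidestep Transformer Engine's own tensor-parallel communication: the weight shape is sharded manually by dividing the output (column-parallel) or input (row-parallel) size by the tensor-parallel size, and TE is then handed parallel_mode=None, tp_size=1, tp_group=None so it performs no collectives of its own. A simplified standalone sketch of that shape arithmetic follows, with made-up sizes and ignoring the moe_extended_tp path.

```python
# Illustrative only: mimic the "make TE agnostic of model parallel" trick for experts.
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0
    return numerator // denominator

input_size, output_size = 4096, 16384   # hypothetical hidden / ffn sizes
parallel_mode, tp_size = "column", 4    # hypothetical tensor-parallel degree

if parallel_mode == "column":
    output_size = divide(output_size, tp_size)  # each rank keeps 1/tp of the columns
elif parallel_mode == "row":
    input_size = divide(input_size, tp_size)    # each rank keeps 1/tp of the rows

# TE now sees a plain, non-parallel linear layer and issues no TP collectives;
# the MoE token dispatch/combine handles the communication instead.
parallel_mode, tp_size, tp_group = None, 1, None
print(input_size, output_size)  # 4096 4096
```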
9 changes: 8 additions & 1 deletion megatron/core/models/gpt/gpt_layer_specs.py
@@ -16,6 +16,7 @@
try:
from megatron.core.extensions.transformer_engine import (
TEColumnParallelGroupedLinear,
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TENorm,
@@ -47,6 +48,7 @@ def get_gpt_layer_with_transformer_engine_spec(
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
qk_layernorm: Optional[bool] = False,
fp8: Optional[str] = None,
) -> ModuleSpec:
"""Use this spec to use lower-level Transformer Engine modules (required for fp8 training).
@@ -55,12 +57,13 @@
num_experts (int, optional): Number of experts. Defaults to None.
moe_grouped_gemm (bool, optional): To use Grouped GEMM. Defaults to False.
qk_layernorm (bool, optional): To use layernorm for queries/keys. Defaults to False.
fp8 (str, optional): Flag to decide the linear layer spec for MoE. Defaults to None.
Returns:
ModuleSpec: Module specification with TE modules
"""
mlp = _get_mlp_module_spec(
use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm
use_te=True, num_experts=num_experts, moe_grouped_gemm=moe_grouped_gemm, fp8=fp8
)
return ModuleSpec(
module=TransformerLayer,
@@ -136,6 +139,7 @@ def _get_mlp_module_spec(
use_te: Optional[bool] = True,
num_experts: Optional[int] = None,
moe_grouped_gemm: Optional[bool] = False,
fp8: Optional[str] = None,
) -> ModuleSpec:
"""Helper function to get module spec for MLP/MoE"""
if num_experts is None:
@@ -152,6 +156,9 @@
if use_te and moe_grouped_gemm:
linear_fc1 = TEColumnParallelGroupedLinear
linear_fc2 = TERowParallelGroupedLinear
elif use_te and fp8:
linear_fc1 = TEColumnParallelLinear
linear_fc2 = TERowParallelLinear
else:
linear_fc1 = ColumnParallelLinear
linear_fc2 = RowParallelLinear
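The new branch in _get_mlp_module_spec gives grouped GEMM priority, falls back to Transformer Engine's per-expert linear layers when FP8 is requested (FP8 GEMMs need TE kernels, and the grouped path has no FP8 support in this commit), and otherwise keeps the local Megatron layers. A standalone sketch of that selection, returning class names as strings purely for illustration:

```python
from typing import Optional, Tuple

def select_expert_linear(
    use_te: bool, moe_grouped_gemm: bool, fp8: Optional[str]
) -> Tuple[str, str]:
    """Condensed restatement of the expert linear-layer choice; not a drop-in helper."""
    if use_te and moe_grouped_gemm:
        return "TEColumnParallelGroupedLinear", "TERowParallelGroupedLinear"
    if use_te and fp8:
        # FP8 expert GEMMs require TE kernels; grouped GEMM cannot run FP8 yet.
        return "TEColumnParallelLinear", "TERowParallelLinear"
    return "ColumnParallelLinear", "RowParallelLinear"

assert select_expert_linear(True, False, "hybrid") == (
    "TEColumnParallelLinear",
    "TERowParallelLinear",
)
```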
85 changes: 53 additions & 32 deletions megatron/core/transformer/moe/experts.py
@@ -2,6 +2,7 @@

from copy import deepcopy
from functools import partial
from math import ceil
from typing import Optional, Tuple

import torch
@@ -34,10 +35,9 @@


class GroupedMLP(MegatronModule):
"""An efficient implementation of the Experts layer using CUTLASS GroupedGEMM.
"""An efficient implementation of the Experts layer using GroupedGEMM.
This class is designed to execute multiple experts in parallel, thereby maximizing
computational efficiency.
Executes multiple experts in parallel to maximize computational efficiency.
"""

def __init__(self, num_local_experts: int, config: TransformerConfig):
@@ -47,8 +47,7 @@ def __init__(self, num_local_experts: int, config: TransformerConfig):
gg.assert_grouped_gemm_is_available()
assert (
config.add_bias_linear == False
), "bias in the expert layer is not supported in Grouped GEMM yet, please set \
'--disable-bias-linear' instead."
), "bias not supported in Grouped GEMM yet, please set '--disable-bias-linear' instead."

self.expert_parallel = config.expert_model_parallel_size > 1
if self.config.gated_linear_unit:
@@ -163,7 +162,7 @@ def remove_extra_states_check(self, incompatible_keys):

self.register_load_state_dict_post_hook(remove_extra_states_check)

def forward(self, permuted_local_hidden_states, tokens_per_expert):
def forward(self, permuted_local_hidden_states: torch.Tensor, tokens_per_expert: torch.Tensor):
"""Forward step of the GroupedMLP."""
if permuted_local_hidden_states.nelement() != 0:
# Reshape the weights for the grouped GEMMs.
@@ -181,8 +180,7 @@ def forward(self, permuted_local_hidden_states, tokens_per_expert):
# No token is allocated for local experts.
assert torch.count_nonzero(tokens_per_expert) == 0

# Make sure parameters still have gradients when no tokens are routed to this set of
# experts.
# Make sure params of experts still have gradients even given zero tokens.
w1 = self.weight1.view(self.config.hidden_size, -1)
w2 = self.weight2.view(-1, self.config.hidden_size)
h = torch.matmul(permuted_local_hidden_states, w1)
@@ -347,8 +345,7 @@ def sh_ten_merge_fn(sub_state_dict, tp_axis: int, with_glu: bool):
class TEGroupedMLP(MegatronModule):
"""An efficient implementation of the Experts layer using TE's GroupedLinear.
This class is designed to execute multiple experts in parallel, thereby maximizing
computational efficiency.
Executes multiple experts in parallel to maximize computational efficiency.
"""

def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLPSubmodules):
@@ -357,8 +354,7 @@ def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLP
self.num_local_experts = num_local_experts
self.input_size = self.config.hidden_size

# If this is a gated linear unit we double the output width, see
# https://arxiv.org/pdf/2002.05202.pdf
# Double the output width with gated linear unit, see https://arxiv.org/pdf/2002.05202.pdf
ffn_hidden_size = self.config.ffn_hidden_size
if self.config.gated_linear_unit:
ffn_hidden_size *= 2
@@ -505,29 +501,54 @@ def __init__(self, num_local_experts, config: TransformerConfig, submodules: MLP
expert = MLP(self.config, submodules, is_expert=True)
self.local_experts.append(expert)

def forward(self, permuted_local_hidden_states, tokens_per_expert):
def _pad_tensor_for_fp8(self, hidden):
"""Padding tensor shape to multiples of 16."""
actual_num_tokens = hidden.shape[0]
divisor = 16
padded_num_tokens = ceil(actual_num_tokens / divisor) * divisor - actual_num_tokens
if padded_num_tokens > 0:
pad_tensor = torch.zeros(
padded_num_tokens, hidden.shape[1], dtype=hidden.dtype, device=hidden.device
)
hidden = torch.cat((hidden, pad_tensor), dim=0)
return hidden

def forward(self, permuted_local_hidden_states: torch.Tensor, tokens_per_expert: torch.Tensor):
"""Forward step of the SequentialMLP."""
output_local = torch.zeros_like(permuted_local_hidden_states)
output_bias_local = None
if self.add_bias:
output_bias_local = torch.zeros_like(permuted_local_hidden_states)

cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
# Insert zero at the beginning for offset index's convenience
zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
for expert_num, expert in enumerate(self.local_experts):
start = cumsum_num_tokens[expert_num]
end = cumsum_num_tokens[expert_num + 1]
hidden = permuted_local_hidden_states[start:end]
output, output_bias = expert(hidden)

output_local[start:end] = output
if self.num_local_experts == 1:
if self.config.fp8:
hidden = self._pad_tensor_for_fp8(permuted_local_hidden_states)
output, output_bias = self.local_experts[0](hidden)
output = output[: permuted_local_hidden_states.shape[0]]
else:
output, output_bias = self.local_experts[0](permuted_local_hidden_states)

return output, output_bias
else:
tokens_per_expert = tokens_per_expert.tolist()
tokens_list = torch.split(permuted_local_hidden_states, tokens_per_expert)

output_local_list = []
output_bias_list = []

for expert, tokens in zip(self.local_experts, tokens_list):
if self.config.fp8:
hidden = self._pad_tensor_for_fp8(tokens)
output, output_bias = expert(hidden)
output = output[: tokens.shape[0]]
else:
output, output_bias = expert(tokens)
output_local_list.append(output)
if self.add_bias:
output_bias_list.append(output_bias.expand_as(output))

output_local = torch.cat(output_local_list, dim=0)
if self.add_bias:
output_bias = output_bias.expand_as(output)
output_bias_local[start:end, :] = output_bias
output_bias_local = torch.cat(output_bias_list, dim=0)
else:
output_bias_local = None

return output_local, output_bias_local
return output_local, output_bias_local

def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None):
"""Maps local expert to global experts."""
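SequentialMLP now pads each expert's token slab so its first dimension is a multiple of 16 before the FP8 forward pass, then slices the output back to the real token count, since FP8 GEMMs require that alignment. Below is a standalone sketch of the padding helper and the pad-run-slice pattern, with a hypothetical hidden size of 4096.

```python
import math
import torch

def pad_tokens_for_fp8(hidden: torch.Tensor, divisor: int = 16) -> torch.Tensor:
    """Zero-pad the token dimension up to the next multiple of `divisor` (16 for FP8 GEMMs)."""
    num_tokens = hidden.shape[0]
    pad = math.ceil(num_tokens / divisor) * divisor - num_tokens
    if pad > 0:
        hidden = torch.cat((hidden, hidden.new_zeros(pad, hidden.shape[1])), dim=0)
    return hidden

x = torch.randn(37, 4096)        # 37 tokens routed to one expert
padded = pad_tokens_for_fp8(x)
assert padded.shape[0] == 48     # padded up to the next multiple of 16
out = padded                     # the expert forward would run here
out = out[: x.shape[0]]          # slice back to the original 37 tokens
```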
14 changes: 14 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -1,9 +1,11 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from importlib.metadata import version
from typing import Callable, Optional, Tuple

import torch.nn.functional as F
from pkg_resources import packaging

from ..model_parallel_config import ModelParallelConfig
from ..utils import init_method_normal, scaled_init_method_normal
@@ -475,3 +477,15 @@ def __post_init__(self):
f'ffn_hidden_size: {self.ffn_hidden_size} must be divisible by '
f'extended_tp_size {extended_tp_size}'
)

if self.num_moe_experts and self.fp8:
# TE versions below 1.7.0 raise an error when an expert receives zero tokens
te_version = packaging.version.Version(version("transformer-engine"))
if te_version < packaging.version.Version("1.7.0.dev0"):
raise ValueError(
"Only transformer-engine>=1.7.0 supports MoE FP8 training, "
f"but your version is {te_version}."
)

if self.moe_grouped_gemm:
raise ValueError("Grouped GEMM of MoE not support fp8 for now.")
2 changes: 1 addition & 1 deletion pretrain_gpt.py
@@ -73,7 +73,7 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
transformer_layer_spec = import_module(args.spec)
else:
if use_te:
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm, args.fp8)
else:
transformer_layer_spec = get_gpt_layer_local_spec(args.num_experts, args.moe_grouped_gemm, args.qk_layernorm)
