diff --git a/.gitlab/stages/01.tests.yml b/.gitlab/stages/01.tests.yml
index 094c5fd613..89cd9cfea3 100644
--- a/.gitlab/stages/01.tests.yml
+++ b/.gitlab/stages/01.tests.yml
@@ -90,7 +90,7 @@ unit_tests:
   parallel:
     matrix:
       - TAG: latest
-      - TAG: 655a663df2e9c3d8991e676e0163a5822da249a7
+      - TAG: 0bb840767d0643c2d0df7192d754ec7db3a18412
   tags: [8xL40S]
   variables:
     GIT_STRATEGY: clone
diff --git a/megatron/core/models/gpt/gpt_layer_specs.py b/megatron/core/models/gpt/gpt_layer_specs.py
index 892ed92259..d469f5e4ce 100755
--- a/megatron/core/models/gpt/gpt_layer_specs.py
+++ b/megatron/core/models/gpt/gpt_layer_specs.py
@@ -9,7 +9,8 @@ from megatron.core.transformer.enums import AttnMaskType
 from megatron.core.transformer.identity_op import IdentityOp
 from megatron.core.transformer.mlp import MLP, MLPSubmodules
-from megatron.core.transformer.moe.moe_layer import MoELayer
+from megatron.core.transformer.moe.moe_layer import MoELayer, MoESubmodules
+from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules
@@ -167,9 +168,19 @@ def _get_mlp_module_spec(
     return ModuleSpec(
         module=MoELayer,
-        submodules=(
-            MLPSubmodules(linear_fc1=linear_fc1, linear_fc2=linear_fc2)
-            if not moe_grouped_gemm or use_te_grouped_gemm
-            else None
+        submodules=MoESubmodules(
+            experts=(
+                MLPSubmodules(linear_fc1=linear_fc1, linear_fc2=linear_fc2)
+                if not moe_grouped_gemm or use_te_grouped_gemm
+                else None
+            ),
+            shared_experts=ModuleSpec(
+                module=SharedExpertMLP,
+                params={"gate": False},
+                submodules=MLPSubmodules(
+                    linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear,
+                    linear_fc2=TERowParallelLinear if use_te else RowParallelLinear,
+                ),
+            ),
         ),
     )
diff --git a/megatron/core/transformer/moe/README.md b/megatron/core/transformer/moe/README.md
index 10f43b1792..a8fa73bcaa 100644
--- a/megatron/core/transformer/moe/README.md
+++ b/megatron/core/transformer/moe/README.md
@@ -61,6 +61,8 @@ Megatron-Core offers rich parallelism mappings, combining Expert Parallelism wit
 | --moe-token-drop-policy | The policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped. |
 | --moe-layer-recompute | Enable activation checkpointing for moe_layer, should be used when memory is not sufficient. |
 | --moe-extended-tp | (Experimental) Alternative parallelization strategy for expert parallelism. Instead of distributing experts across *expert_model_parallel_size*, each expert is sharded along extendended tensor parallel domain (tensor_model_paralle_size * expert_model_parallel_size). It avoids the load balancing problem with MOE training. Only available with `--moe-token-dispatcher-type allgather`. |
+| --moe-shared-expert-intermediate-size | Set shared expert total ffn hidden size. It should be equal to `num_shared_experts * ffn_size_of_each_shared_expert` if there are multiple shared experts. None means no shared expert. |
+| --moe-shared-expert-overlap | (Experimental, may change) If this is set, the communications/computations in the shared experts and the dispatcher will overlap (the `alltoall` dispatcher is required). Otherwise, the shared experts run after the routed experts. |
 | --moe-use-upcycling | Load the dense model checkpoint, convert it into an MoE model at runtime and start training. The converted model will be saved to the path specified by `--save` before training begins. Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.|
@@ -118,8 +120,19 @@ Usage
 - `--use-dist-ckpt` The main argument, it will attempt to save and load using distributed checkpointing.
 - `--auto-detect-ckpt-format` With this, it can load both distributed checkpointing and legacy checkpointing.
 
-### Upcycling
+### Shared Experts
+MCore v0.9 introduced the shared expert feature. We can enable this feature by setting a suitable `--moe-shared-expert-intermediate-size`.
+
+The parallelism patterns of the shared experts follow the settings of the dense part, i.e., the attention module. The shared experts are not distributed but replicated across EP ranks.
+We also have an experimental feature that tries to overlap the communications and computations in the shared experts and the dispatcher.
+We can set `--moe-shared-expert-overlap` and use the `alltoall` dispatcher to enable it.
+The overlapping relies on the environment setting `CUDA_DEVICE_MAX_CONNECTIONS=1`.
+The `AllGather` and `ReduceScatter` communications in the shared experts are overlapped with `permute`/`unpermute` in the dispatcher.
+The `MLP` computation part in the shared experts is overlapped with the `AlltoAll` communications in the dispatcher.
+Both the forward and the backward pass can overlap. To get the overlap in the backward pass, the PyTorch version should be `>= 2.2.0`.
+
+### Upcycling
 Use `--moe-use-upcycling` to enable the upcycling feature, which will load the dense model from the directory specified by `--load`, convert it into an MoE model at runtime and start training. The converted model will be saved to the path specified by `--save` before training begins. Upcycling is implemented on the top of distributed checkpointing, so it supports parallel modes different from the dense model.
 
 The MoE model structure is defined through script arguments. All MoE-related arguments (such as `--num-experts`) can be customized; however, other model structure arguments must be consistent with those of the dense model.
diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py
index dea0bf658e..8b393abc77 100644
--- a/megatron/core/transformer/moe/moe_layer.py
+++ b/megatron/core/transformer/moe/moe_layer.py
@@ -1,6 +1,8 @@
 # Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
 
 from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Union
 
 import torch
 
@@ -10,13 +12,23 @@
 from megatron.core.transformer.moe.experts import GroupedMLP, SequentialMLP, TEGroupedMLP
 from megatron.core.transformer.moe.legacy_a2a_token_dispatcher import MoEAlltoAllSEQTokenDispatcher
 from megatron.core.transformer.moe.router import TopKRouter
+from megatron.core.transformer.moe.shared_experts import SharedExpertMLP
 from megatron.core.transformer.moe.token_dispatcher import (
     MoEAllGatherTokenDispatcher,
     MoEAlltoAllTokenDispatcher,
 )
+from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_config import TransformerConfig
 
 
+@dataclass
+class MoESubmodules:
+    """MoE Layer Submodule spec"""
+
+    experts: Union[ModuleSpec, type] = None
+    shared_experts: Union[ModuleSpec, type] = None
+
+
 class BaseMoELayer(MegatronModule, ABC):
     """Base class for a mixture of experts layer.
@@ -40,12 +52,16 @@ def __init__(self, config: TransformerConfig, layer_number: int = None): parallel_state.get_expert_model_parallel_rank() * self.num_local_experts ) + self.use_shared_expert = self.config.moe_shared_expert_intermediate_size is not None + self.shared_expert_overlap = self.config.moe_shared_expert_overlap + self.local_expert_indices = [ local_expert_indices_offset + i for i in range(self.num_local_experts) ] assert all(map(lambda x: x < self.config.num_moe_experts, self.local_expert_indices)) self.router = None self.experts = None + self.shared_experts = None self.token_dispatcher = None self.layer_number = layer_number @@ -72,15 +88,26 @@ def __init__( ): self.submodules = submodules super(MoELayer, self).__init__(config=config, layer_number=layer_number) + self.moe_layer_recompute = config.moe_layer_recompute + + # Initialize router self.router = TopKRouter(config=self.config) + + # Initialize experts if self.config.moe_grouped_gemm: - if isinstance(self.submodules, MLPSubmodules): - self.experts = TEGroupedMLP(self.num_local_experts, self.config, self.submodules) + if isinstance(self.submodules.experts, MLPSubmodules): + self.experts = TEGroupedMLP( + self.num_local_experts, self.config, self.submodules.experts + ) else: self.experts = GroupedMLP(self.num_local_experts, self.config) else: - assert isinstance(self.submodules, MLPSubmodules) - self.experts = SequentialMLP(self.num_local_experts, self.config, self.submodules) + assert isinstance(self.submodules.experts, MLPSubmodules) + self.experts = SequentialMLP( + self.num_local_experts, self.config, self.submodules.experts + ) + + # Initialize token dispatcher if config.moe_token_dispatcher_type == "allgather": self.token_dispatcher = MoEAllGatherTokenDispatcher( self.num_local_experts, self.local_expert_indices, config=self.config @@ -97,7 +124,12 @@ def __init__( raise ValueError( f"Unsupported token dispatcher type: {config.moe_token_dispatcher_type}" ) - self.moe_layer_recompute = config.moe_layer_recompute + + # Initialize shared experts + if self.use_shared_expert: + self.shared_experts = SharedExpertMLP(self.config, self.submodules.shared_experts) + if self.shared_expert_overlap: + self.token_dispatcher.set_shared_experts(self.shared_experts) def forward(self, hidden_states: torch.Tensor): if ( @@ -118,6 +150,10 @@ def custom_forward(hidden_states): ) expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert) output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias) + if self.use_shared_expert and not self.shared_expert_overlap: + # if shared_expert_overlap is True, the expert calculation happens in + # the token_dispatcher to overlap communications and computations + output += self.shared_experts(hidden_states) return output, mlp_bias if self.moe_layer_recompute: diff --git a/megatron/core/transformer/moe/shared_experts.py b/megatron/core/transformer/moe/shared_experts.py new file mode 100644 index 0000000000..c2d9c188e3 --- /dev/null +++ b/megatron/core/transformer/moe/shared_experts.py @@ -0,0 +1,262 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
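+"""Shared expert MLP used by the MoE layer.
+
+A minimal configuration sketch (illustrative values only; it is a trimmed-down version of
+the settings used in tests/unit_tests/transformer/moe/test_shared_experts.py):
+
+    config = TransformerConfig(
+        num_layers=1,
+        hidden_size=12,
+        num_attention_heads=4,
+        num_moe_experts=2,
+        moe_shared_expert_intermediate_size=32,
+        moe_shared_expert_overlap=True,
+        moe_token_dispatcher_type="alltoall",
+        activation_func=torch.nn.functional.silu,
+        gated_linear_unit=True,
+        add_bias_linear=False,
+    )
+
+Note that `moe_shared_expert_overlap` additionally relies on the environment setting
+`CUDA_DEVICE_MAX_CONNECTIONS=1` (see megatron/core/transformer/moe/README.md).
+"""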
+
+import warnings
+from copy import deepcopy
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+from megatron.core.dist_checkpointing.mapping import ShardedStateDict
+from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
+from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
+from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
+from megatron.core.tensor_parallel.mappings import (
+    copy_to_tensor_model_parallel_region,
+    gather_from_sequence_parallel_region,
+    reduce_from_tensor_model_parallel_region,
+    reduce_scatter_to_sequence_parallel_region,
+)
+from megatron.core.tensor_parallel.random import (
+    get_cuda_rng_tracker,
+    get_data_parallel_rng_tracker_name,
+)
+from megatron.core.transformer.mlp import MLP
+from megatron.core.transformer.spec_utils import ModuleSpec
+from megatron.core.transformer.transformer_config import TransformerConfig
+from megatron.core.utils import make_sharded_tensor_for_checkpoint
+
+
+class SharedExpertMLP(MLP):
+    """
+    MLP layer for Shared Experts.
+    """
+
+    # This stream is used when '--moe-shared-expert-overlap' is set.
+    # The shared experts are scheduled into this stream to be overlapped with the dispatcher.
+    stream = None
+
+    def __init__(self, config: TransformerConfig, spec: ModuleSpec):
+        config = deepcopy(config)
+        assert config.add_bias_linear == False, (
+            "bias is not supported in the shared experts, "
+            "please set '--disable-bias-linear' instead."
+        )
+
+        config.ffn_hidden_size = config.moe_shared_expert_intermediate_size
+        super().__init__(config=config, submodules=spec.submodules)
+
+        self.use_shared_expert_gate = spec.params.get("gate", False)
+        if self.use_shared_expert_gate:
+            self.gate_weight = torch.nn.Parameter(torch.empty((1, self.config.hidden_size)))
+            if config.perform_initialization:
+                if get_cuda_rng_tracker().is_initialized():
+                    with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
+                        config.init_method(self.gate_weight)
+                else:
+                    config.init_method(self.gate_weight)
+            self.gate_weight.data = self.gate_weight.data.to(dtype=config.params_dtype)
+            setattr(self.gate_weight, 'sequence_parallel', self.config.sequence_parallel)
+        else:
+            self.gate_weight = None
+
+        if self.config.moe_shared_expert_overlap:
+            # disable TP related AG/RS communications in the linear module
+            for linear in [self.linear_fc1, self.linear_fc2]:
+                if hasattr(linear, 'parallel_mode'):
+                    # TELinear
+                    linear.parallel_mode = None
+                else:
+                    # MCore legacy Linear
+                    linear.explicit_expert_comm = True
+
+            # The overlapped version is split into several separate functions that are called
+            # from inside the token dispatcher. These functions must be called in this order
+            # and none of them can be skipped:
+            #     pre_forward_comm(input)
+            #     linear_fc1_forward_and_act()
+            #     linear_fc2_forward()
+            #     post_forward_comm()
+            #     output = get_output()
+            #
+            # We use cached intermediate results to avoid messy arg passing in the dispatcher.
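+            #
+            # For reference, MoEAlltoAllTokenDispatcher interleaves these calls with its own
+            # communication roughly as follows (see token_dispatcher.py for the exact placement):
+            #     pre_forward_comm()            called before permutation 1 in token_permutation()
+            #     linear_fc1_forward_and_act()  called after the dispatching AlltoAll
+            #     linear_fc2_forward()          called after the combining AlltoAll in token_unpermutation()
+            #     post_forward_comm()           called right after linear_fc2_forward()
+            #     get_output()                  result is added to the MoE output after unpermutation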
+ self.cached_fc1_input = None + self.cached_fc2_input = None + self.cached_fc2_output = None + self.cached_output = None + self.gate_score = None + + if self.stream is None: + self.stream = torch.cuda.Stream() + + def forward(self, hidden_states): + """Forward function""" + output, _ = super().forward(hidden_states) + if self.use_shared_expert_gate: + logits = torch.nn.functional.linear(hidden_states, self.gate_weight) + gate_score = torch.nn.functional.sigmoid(logits) + output = output * gate_score + return output + + def sharded_state_dict( + self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None + ) -> ShardedStateDict: + """Gets sharded state dict.""" + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + if self.use_shared_expert_gate: + name = 'gate_weight' + state_dict = self.state_dict(prefix='', keep_vars=True) + sub_sd = { + f'{prefix}{name}': make_sharded_tensor_for_checkpoint( + state_dict[name], f'{prefix}{name}', prepend_offsets=sharded_offsets + ) + } + sharded_state_dict.update(sub_sd) + return sharded_state_dict + + def pre_forward_comm(self, input): + """ + All Gather for SP before forward. + This function is used to overlap shared experts with the dispatcher. + It is only useful when --moe-shared-expert-overlap is set and may be changed. + """ + assert self.config.moe_shared_expert_overlap + assert self.cached_output is None + self.stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self.stream): + if self.use_shared_expert_gate: + logits = torch.nn.functional.linear(input, self.gate_weight) + self.gate_score = torch.nn.functional.sigmoid(logits) + if self.config.sequence_parallel: + self.cached_fc1_input = gather_from_sequence_parallel_region( + input, tensor_parallel_output_grad=True + ) + else: + self.cached_fc1_input = copy_to_tensor_model_parallel_region(input) + set_tensor_grad_fn_sequence_sr(self.cached_fc1_input, torch.iinfo(torch.int).max) + + def linear_fc1_forward_and_act(self, overlapped_comm_output=None): + """ + Do Linear FC1 and activation function forward. + This function is used to overlap shared experts with the dispatcher. + It is only useful when --moe-shared-expert-overlap is set and may be changed. 
+ """ + assert self.config.moe_shared_expert_overlap + assert self.cached_fc1_input is not None + if overlapped_comm_output is not None: + set_tensor_grad_fn_sequence_sr(overlapped_comm_output, torch.iinfo(torch.int).max) + with torch.cuda.stream(self.stream): + # [s, b, 4 * h/p] + intermediate_parallel, bias_parallel = self.linear_fc1(self.cached_fc1_input) + self.cached_fc1_input = None + + if self.config.bias_activation_fusion: + if self.activation_func == F.gelu: + if self.config.gated_linear_unit: + intermediate_parallel = bias_geglu_impl( + intermediate_parallel, bias_parallel + ) + else: + assert self.config.add_bias_linear is True + intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) + elif self.activation_func == F.silu and self.config.gated_linear_unit: + intermediate_parallel = bias_swiglu_impl( + intermediate_parallel, + bias_parallel, + self.config.activation_func_fp8_input_store, + ) + else: + raise ValueError("Only support fusion of gelu and swiglu") + else: + if bias_parallel is not None: + intermediate_parallel = intermediate_parallel + bias_parallel + if self.config.gated_linear_unit: + + def glu(x): + x = torch.chunk(x, 2, dim=-1) + return self.config.activation_func(x[0]) * x[1] + + intermediate_parallel = glu(intermediate_parallel) + else: + intermediate_parallel = self.activation_func(intermediate_parallel) + + self.cached_fc2_input = intermediate_parallel + + def linear_fc2_forward(self, overlapped_comm_output=None): + """ + Do Linear FC2 forward. + This function is used to overlap shared experts with the dispatcher. + It is only useful when --moe-shared-expert-overlap is set and may be changed. + """ + assert self.config.moe_shared_expert_overlap + assert self.cached_fc2_input is not None + if overlapped_comm_output is not None: + set_tensor_grad_fn_sequence_sr(overlapped_comm_output, torch.iinfo(torch.int).max) + with torch.cuda.stream(self.stream): + # [s, b, h] + self.cached_fc2_output, _ = self.linear_fc2(self.cached_fc2_input) + self.cached_fc2_input = None + + def post_forward_comm(self): + """ + Reduce scatter for SP after forward. + This function is used to overlap shared experts with the dispatcher. + It is only useful when --moe-shared-expert-overlap is set and may be changed. + """ + assert self.config.moe_shared_expert_overlap + assert self.cached_fc2_output is not None + with torch.cuda.stream(self.stream): + if self.config.sequence_parallel: + self.cached_output = reduce_scatter_to_sequence_parallel_region( + self.cached_fc2_output + ) + else: + self.cached_output = reduce_from_tensor_model_parallel_region( + self.cached_fc2_output + ) + self.cached_fc2_output = None + set_tensor_grad_fn_sequence_sr(self.cached_output, torch.iinfo(torch.int).max) + + def get_output(self): + """ + Gets the module forward output. + This function is used to overlap shared experts with the dispatcher. + It is only useful when --moe-shared-expert-overlap is set and may be changed. 
+ """ + assert self.config.moe_shared_expert_overlap + assert self.cached_output is not None + with torch.cuda.stream(self.stream): + if self.use_shared_expert_gate: + assert self.gate_score is not None + output = self.cached_output * self.gate_score + self.gate_score = None + else: + output = self.cached_output + self.cached_output = None + torch.cuda.current_stream().wait_stream(self.stream) + return output + + +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) +TORCH_LAST = torch.__version__.split(".")[2] + + +def set_tensor_grad_fn_sequence_sr(tensor, value): + """ + Set sequence_sr for the grad_fn of a tensor to control the backward order. + For older PyTorch version, do nothing (backward order is not changed). + The bigger the value is, the earlier the grad_fn is scheduled. + """ + if ( + (TORCH_MAJOR > 2) + or (TORCH_MAJOR == 2 and TORCH_MINOR > 2) + or (TORCH_MAJOR == 2 and TORCH_MINOR == 2 and '+' not in TORCH_LAST) + ): + # In NVIDIA PyTorch container 24.01, the PyTorch version is 2.2.0a0+81ea7a4, + # which does not contian the set_sequence_nr commit. + if tensor is not None and tensor.grad_fn is not None: + tensor.grad_fn._set_sequence_nr(value) + else: + warnings.warn( + "WARNING : PyTorch is too old to set sequence_sr and the performance may not " + "optimal. Please use PyTorch >= 2.2.0 for better performance." + ) diff --git a/megatron/core/transformer/moe/token_dispatcher.py b/megatron/core/transformer/moe/token_dispatcher.py index 84f3d450ad..e23ea4ea0f 100644 --- a/megatron/core/transformer/moe/token_dispatcher.py +++ b/megatron/core/transformer/moe/token_dispatcher.py @@ -18,6 +18,7 @@ sort_chunks_by_idxs, unpermute, ) +from megatron.core.transformer.moe.shared_experts import SharedExpertMLP from megatron.core.transformer.transformer_config import TransformerConfig """ We use the following notation throughout this file: @@ -41,6 +42,7 @@ def __init__(self, config: TransformerConfig) -> None: Initialize the MoE Token Dispatcher. """ self.config = config + self.shared_experts: Optional[SharedExpertMLP] = None @abstractmethod def token_permutation(self, tokens: torch.Tensor, indices: torch.Tensor): @@ -71,6 +73,10 @@ def token_unpermutation( """ raise NotImplementedError("Restore function not implemented.") + def set_shared_experts(self, shared_experts): + """Set shared expert to the dispatcher.""" + self.shared_experts = shared_experts + class MoEAllGatherTokenDispatcher(MoETokenDispatcher): """ @@ -361,6 +367,8 @@ def __init__( # and "no_sync". self.cuda_sync_point = "no_sync" + self.shared_experts = None + def preprocess(self, indices: torch.Tensor) -> torch.Tensor: """ Preprocess token indices for AlltoAll communication and token permutation. 
This method @@ -491,6 +499,9 @@ def token_permutation( hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) tokens_per_expert = self.preprocess(indices) + if self.shared_experts is not None: + self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape)) + # Permutation 1: input to AlltoAll input self.hidden_shape_before_permute = hidden_states.shape if self.cuda_sync_point == "before_permutation_1": @@ -511,6 +522,8 @@ def token_permutation( self.output_splits, self.input_splits, ) + if self.shared_experts is not None: + self.shared_experts.linear_fc1_forward_and_act(global_input_tokens) if parallel_state.get_tensor_model_parallel_world_size() > 1: global_input_tokens = gather_from_sequence_parallel_region( @@ -574,6 +587,9 @@ def token_unpermutation( self.input_splits, self.output_splits, ) + if self.shared_experts is not None: + self.shared_experts.linear_fc2_forward(permutated_local_input_tokens) + self.shared_experts.post_forward_comm() # Unpermutation 1: Unsort input tokens to restore the original order. output = unpermute( @@ -586,4 +602,9 @@ def token_unpermutation( # Reshape the output tensor output = output.view(self.hidden_shape) + + # Add shared experts output + if self.shared_experts is not None: + shared_expert_output = self.shared_experts.get_output() + output += shared_expert_output return output, None diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 6990df9685..f16a0117a3 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -236,6 +236,16 @@ class TransformerConfig(ModelParallelConfig): #################### # MoE related #################### + moe_shared_expert_intermediate_size: int = None + """Shared expert total ffn hidden size. + It should be equal to 'num_shared_experts * ffn_size_of_each_shared_expert' if + there are multiple shared experts. + None means no shared expert.""" + + moe_shared_expert_overlap: bool = False + """Enable overlapping between shared expert computations and dispatcher communications. + Without this, the shared epxerts execute after the routed experts.""" + moe_router_load_balancing_type: str = "aux_loss" """Determines the load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss used in GShard and SwitchTransformer, "sinkhorn" corresponds to the balancing @@ -353,6 +363,20 @@ def __post_init__(self): if self.num_moe_experts is not None and self.num_moe_experts <= 0: raise ValueError('num_moe_experts must be non-negative.') + if self.moe_shared_expert_intermediate_size is not None: + if self.moe_shared_expert_intermediate_size <= 0: + raise ValueError( + f'moe_shared_expert_intermediate_size must be ' + f'num_shared_experts * ffn_size_of_each_shared_expert, ' + f'but got {self.moe_shared_expert_intermediate_size}' + ) + if self.moe_shared_expert_overlap and self.moe_token_dispatcher_type not in [ + "alltoall" + ]: + raise ValueError( + f'moe_shared_expert_overlap only works with alltoall token dispatcher.' 
+ ) + if self.moe_expert_capacity_factor is not None: if self.moe_token_dispatcher_type not in ["alltoall", "alltoall_seq"]: raise ValueError( diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 0217d71e44..3dcfe4f2b2 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1863,6 +1863,14 @@ def _add_moe_args(parser): help='Degree of expert model parallelism.') group.add_argument('--num-experts', type=int, default=None, help='Number of Experts in MoE (None means no MoE)') + group.add_argument('--moe-shared-expert-intermediate-size', type=int, default=None, + help='Shared expert total ffn hidden size. ' + 'It should be equal to "num_shared_experts * ffn_size_of_each_shared_expert" if there are multiple shared experts. ' + 'None means no shared expert.') + group.add_argument('--moe-shared-expert-overlap', action='store_true', + help='Enable overlapping between shared expert computations and dispatcher communications. ' + 'Without this, the shared epxerts execute after the routed experts. ' + 'Only effective when moe-shared-expert-intermediate-size is set.') group.add_argument('--moe-router-load-balancing-type', type=str, choices=['aux_loss', 'sinkhorn', 'none'], default='aux_loss', diff --git a/megatron/training/training.py b/megatron/training/training.py index bcca2cbe89..7d60f41f5c 100644 --- a/megatron/training/training.py +++ b/megatron/training/training.py @@ -104,6 +104,11 @@ def num_floating_point_operations(args, batch_size): # MoE. num_experts_routed_to = 1 if args.num_experts is None else args.moe_router_topk gated_linear_multiplier = 3 / 2 if args.swiglu else 1 + shared_expert_ffn_hidden_size = ( + 0 + if args.moe_shared_expert_intermediate_size is None + else args.moe_shared_expert_intermediate_size + ) # The 12x term below comes from the following factors; for more details, see # "APPENDIX: FLOATING-POINT OPERATIONS" in https://arxiv.org/abs/2104.04473. @@ -137,6 +142,8 @@ def num_floating_point_operations(args, batch_size): * num_experts_routed_to * gated_linear_multiplier ) + # Shared Experts. + + ((shared_expert_ffn_hidden_size / args.hidden_size) * gated_linear_multiplier) # Logit. 
+ (args.padded_vocab_size / (2 * args.num_layers * args.hidden_size)) ) diff --git a/tests/unit_tests/dist_checkpointing/models/test_sequential_mlp.py b/tests/unit_tests/dist_checkpointing/models/test_sequential_mlp.py index 111e982a35..d42b73b8af 100644 --- a/tests/unit_tests/dist_checkpointing/models/test_sequential_mlp.py +++ b/tests/unit_tests/dist_checkpointing/models/test_sequential_mlp.py @@ -49,11 +49,15 @@ def initialize_expert_layer(seed, glu=True, moe_grouped_gemm=False, **config_kwa ) if moe_grouped_gemm: model = TEGroupedMLP( - num_local_experts, transformer_config, transformer_layer_spec.submodules.mlp.submodules + num_local_experts, + transformer_config, + transformer_layer_spec.submodules.mlp.submodules.experts, ) else: model = SequentialMLP( - num_local_experts, transformer_config, transformer_layer_spec.submodules.mlp.submodules + num_local_experts, + transformer_config, + transformer_layer_spec.submodules.mlp.submodules.experts, ) return model diff --git a/tests/unit_tests/transformer/moe/test_grouped_mlp.py b/tests/unit_tests/transformer/moe/test_grouped_mlp.py index 757be59232..dea68d580f 100644 --- a/tests/unit_tests/transformer/moe/test_grouped_mlp.py +++ b/tests/unit_tests/transformer/moe/test_grouped_mlp.py @@ -85,7 +85,10 @@ def setup_method(self, method, use_cpu_initialization=False, swiglu=True): ## Grouped GEMM _set_random_seed(seed_=123, data_parallel_random_init=False) tf_config.moe_grouped_gemm = True - self.grouped_mlp = MoELayer(tf_config) + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + self.num_experts, moe_grouped_gemm=True + ) + self.grouped_mlp = MoELayer(tf_config, transformer_layer_spec.submodules.mlp.submodules) self.grouped_mlp = Float16Module(self.grouped_mlp, self.args).module print("done intializing for grouped gemm") diff --git a/tests/unit_tests/transformer/moe/test_routers.py b/tests/unit_tests/transformer/moe/test_routers.py index b1d07d054a..c1633834b6 100644 --- a/tests/unit_tests/transformer/moe/test_routers.py +++ b/tests/unit_tests/transformer/moe/test_routers.py @@ -45,6 +45,7 @@ def test_constructor(self): assert num_weights == 12 * 4, num_weights @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal @pytest.mark.parametrize("moe_router_pre_softmax", [(True), (False)]) def test_router_forward(self, moe_router_pre_softmax): with torch.no_grad(): @@ -56,6 +57,7 @@ def test_router_forward(self, moe_router_pre_softmax): scores, indices = self.router(hidden_states) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal def test_aux_loss(self): self.sequential_mlp = self.sequential_mlp.cuda() diff --git a/tests/unit_tests/transformer/moe/test_shared_experts.py b/tests/unit_tests/transformer/moe/test_shared_experts.py new file mode 100644 index 0000000000..0cacf30836 --- /dev/null +++ b/tests/unit_tests/transformer/moe/test_shared_experts.py @@ -0,0 +1,126 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
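+#
+# Note on the parameter-count asserts below: the 1152 term is the shared expert.
+# With hidden_size=12, moe_shared_expert_intermediate_size=32, gated_linear_unit=True and
+# add_bias_linear=False, linear_fc1 holds (2 * 32) * 12 = 768 weights and linear_fc2 holds
+# 12 * 32 = 384, i.e. 768 + 384 = 1152 parameters in total.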
+ +import pytest +import torch + +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed +from megatron.core.transformer.moe.moe_layer import MoELayer +from megatron.core.transformer.transformer_config import TransformerConfig +from tests.unit_tests.test_utilities import Utils + + +class TestSharedExperts: + + def setup_method(self, method): + pass + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal + def test_gpu_forward(self): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + print("done intializing") + num_moe_experts = 2 + transformer_config = TransformerConfig( + num_layers=1, + hidden_size=12, + num_attention_heads=4, + num_moe_experts=num_moe_experts, + moe_shared_expert_intermediate_size=32, + use_cpu_initialization=True, + activation_func=torch.nn.functional.silu, + gated_linear_unit=True, + bias_activation_fusion=True, + moe_router_load_balancing_type="sinkhorn", + moe_router_topk=1, + add_bias_linear=False, + ) + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=num_moe_experts, moe_grouped_gemm=False + ) + self.moe_layer = MoELayer( + transformer_config, transformer_layer_spec.submodules.mlp.submodules + ) + + assert isinstance(self.moe_layer, MoELayer) + + num_weights = sum([p.numel() for p in self.moe_layer.parameters()]) + assert num_weights == 3480 + 1152 + assert self.moe_layer.shared_experts is not None + assert self.moe_layer.shared_experts.stream is None + assert self.moe_layer.token_dispatcher.shared_experts is None + + moe_layer = self.moe_layer + moe_layer.cuda() + # [sequence length, batch size, hidden size] + hidden_states = torch.ones((32, 2, moe_layer.config.hidden_size)) + hidden_states = hidden_states.cuda() + output, _ = moe_layer(hidden_states) + assert output.shape[0] == 32 + assert output.shape[1] == 2 + assert output.shape[2] == moe_layer.config.hidden_size + assert output.dtype == torch.float32 + assert output.device.type == 'cuda' + + +class TestSharedExpertsOverlap: + + def setup_method(self, method): + pass + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal + def test_gpu_forward(self): + Utils.initialize_model_parallel(1, 1) + model_parallel_cuda_manual_seed(123) + print("done intializing") + num_moe_experts = 2 + transformer_config = TransformerConfig( + num_layers=1, + hidden_size=12, + num_attention_heads=4, + num_moe_experts=num_moe_experts, + moe_shared_expert_intermediate_size=32, + moe_shared_expert_overlap=True, + moe_token_dispatcher_type="alltoall", + use_cpu_initialization=True, + activation_func=torch.nn.functional.silu, + gated_linear_unit=True, + bias_activation_fusion=True, + moe_router_load_balancing_type="sinkhorn", + moe_router_topk=1, + add_bias_linear=False, + ) + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=num_moe_experts, moe_grouped_gemm=False + ) + self.moe_layer = MoELayer( + transformer_config, transformer_layer_spec.submodules.mlp.submodules + ) + + assert isinstance(self.moe_layer, MoELayer) + + num_weights = sum([p.numel() for p in self.moe_layer.parameters()]) + assert num_weights == 3480 + 1152 + assert self.moe_layer.shared_experts is not None + 
assert self.moe_layer.shared_experts.stream is not None + assert self.moe_layer.token_dispatcher.shared_experts is not None + + moe_layer = self.moe_layer + moe_layer.cuda() + # [sequence length, batch size, hidden size] + hidden_states = torch.ones((32, 2, moe_layer.config.hidden_size)) + hidden_states = hidden_states.cuda() + output, _ = moe_layer(hidden_states) + assert output.shape[0] == 32 + assert output.shape[1] == 2 + assert output.shape[2] == moe_layer.config.hidden_size + assert output.dtype == torch.float32 + assert output.device.type == 'cuda' diff --git a/tests/unit_tests/transformer/moe/test_token_dispatcher.py b/tests/unit_tests/transformer/moe/test_token_dispatcher.py index ff6ceb43b9..e85f8512b4 100644 --- a/tests/unit_tests/transformer/moe/test_token_dispatcher.py +++ b/tests/unit_tests/transformer/moe/test_token_dispatcher.py @@ -232,6 +232,7 @@ def teardown_method(self, method): Utils.destroy_model_parallel() @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal @pytest.mark.parametrize("tp_size,ep_size", [(8, 1), (1, 8), (2, 4), (1, 1)]) def test_forward_backward(self, tp_size, ep_size): container = MoEModelTestContainer( @@ -247,6 +248,7 @@ def test_forward_backward(self, tp_size, ep_size): container.dispatcher_dropless_test() @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + @pytest.mark.internal @pytest.mark.parametrize("tp_size,ep_size", [(2, 4)]) def test_extend_tp_forward_backward(self, tp_size, ep_size): container = MoEModelTestContainer(