diff --git a/paddleformers/nn/moe_deepep/__init__.py b/paddleformers/nn/moe_deepep/__init__.py
new file mode 100644
index 00000000000..9cf00634143
--- /dev/null
+++ b/paddleformers/nn/moe_deepep/__init__.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+from contextlib import suppress
+from typing import TYPE_CHECKING
+
+from ...utils.lazy_import import _LazyModule
+
+import_structure = {
+    "modular_moe_layer": ["ModularMoELayer"],
+    "moe_communication": ["MoECommunicationInterface", "AllToAllMoECommunication", "DeepEPMoECommunication"],
+    "moe_expert": ["MoEExpertInterface", "StandardMLPExpert"],
+    "moe_gate": ["MoEGateMixin", "StandardMoEGate"],
+    "moe_factory": ["QuickAccessMoEFactory"],
+}
+
+if TYPE_CHECKING:
+    from .modular_moe_layer import ModularMoELayer
+    from .moe_communication import (
+        AllToAllMoECommunication,
+        DeepEPMoECommunication,
+        MoECommunicationInterface,
+    )
+    from .moe_expert import MoEExpertInterface, StandardMLPExpert
+    from .moe_factory import QuickAccessMoEFactory
+    from .moe_gate import MoEGateMixin, StandardMoEGate
+else:
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        import_structure,
+        module_spec=__spec__,
+    )
diff --git a/paddleformers/nn/moe_deepep/modular_moe_layer.py b/paddleformers/nn/moe_deepep/modular_moe_layer.py
new file mode 100644
index 00000000000..10de1b3b87f
--- /dev/null
+++ b/paddleformers/nn/moe_deepep/modular_moe_layer.py
@@ -0,0 +1,386 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
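The package `__init__.py` above registers its public names in `import_structure` and hands them to `_LazyModule`, so importing `paddleformers.nn.moe_deepep` stays cheap and a submodule is only loaded when one of its symbols is first accessed. The sketch below illustrates the general idea of such a lazy module; it is a simplified stand-in, not the actual `_LazyModule` from `paddleformers.utils.lazy_import`.

```python
import importlib
import types


class LazyModuleSketch(types.ModuleType):
    """Illustrative stand-in for a lazy module: defer submodule imports to first attribute access."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported symbol to the submodule that defines it,
        # e.g. "ModularMoELayer" -> "modular_moe_layer".
        self._symbol_to_module = {sym: mod for mod, syms in import_structure.items() for sym in syms}

    def __getattr__(self, symbol):
        if symbol not in self._symbol_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {symbol!r}")
        submodule = importlib.import_module("." + self._symbol_to_module[symbol], self.__name__)
        value = getattr(submodule, symbol)
        setattr(self, symbol, value)  # cache so later lookups bypass __getattr__
        return value
```

With this pattern, `from paddleformers.nn.moe_deepep import ModularMoELayer` triggers exactly one submodule import the first time the name is resolved.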
+ +from __future__ import annotations + +import logging +from copy import deepcopy +from typing import Any, Dict, Optional + +import paddle +import paddle.distributed as dist +from paddle import nn +from paddle.distributed import fleet +from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp + +from ...transformers.configuration_utils import PretrainedConfig +from ...transformers.token_dispatcher import MoEFlexTokenDispatcher +from .moe_communication import AllToAllMoECommunication, DeepEPMoECommunication +from .moe_expert import StandardMLPExpert +from .moe_gate import StandardMoEGate +from .moe_loss_instance import get_global_loss_registry + +logger = logging.getLogger(__name__) +global_loss_registry = get_global_loss_registry() + + +class ModularMoELayer(nn.Layer): + def __init__( + self, + hidden_size: int, + moe_intermediate_size: int, + num_experts: int, + num_shared_experts: int, + num_experts_per_tok: int, + norm_topk_prob: int, + expert_activation: str, + moe_config: Dict, + model_type: str, + expert_class, + pretrained_config: Optional[PretrainedConfig] = None, + ): + + super().__init__() + self.hidden_size = hidden_size + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.num_shared_experts = num_shared_experts + self.moe_intermediate_size = moe_intermediate_size + self.expert_activation = expert_activation + self.norm_topk_prob = norm_topk_prob + self.model_type = model_type + self.expert_class = expert_class + + self.sequence_parallel = pretrained_config.get("sequence_parallel", False) + self.tensor_parallel_degree = pretrained_config.get("tensor_parallel_degree", 1) + self.seq_length = pretrained_config.get("seq_length", pretrained_config.get("max_seq_len", 1024)) + self.fuse_up_gate = pretrained_config.get("fuse_attention_ffn", False) + self.ep_communication_type = pretrained_config.get("ep_communication_type", "deepep") + try: + moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group() + except Exception: + moe_group = None + self.expert_parallel_degree = dist.get_world_size(moe_group) if moe_group is not None else 1 + + self.custom_gate = moe_config.get("custom_gate", None) + self.custom_communication = moe_config.get("custom_communication", None) + self.gate_activation = moe_config.get("gate_activation", "softmax") + self.aux_loss_weight = moe_config.get("aux_loss_weight", 0.01) + self.z_loss_weight = moe_config.get("z_loss_weight", 0.0) + self.topk_method = ( + moe_config.get("train_topk_method", "greedy") + if self.training + else moe_config.get("inference_topk_method", "greedy") + ) + self.drop_tokens = moe_config.get("drop_tokens", False) + self.use_flexible_loss = moe_config.get( + "use_flexible_loss", False + ) # TODO: use customized loss system, not implemented yet + self.expert_dropout = moe_config.get("expert_dropout", 0.0) + self.loss_configs = moe_config.get("loss_configs", None) + self.loss_combiner_name = moe_config.get("loss_combiner_name", "weighted_sum") + + self._init_expert_parallel() + if self.custom_gate is not None: + self.gate = self.custom_gate + else: + self.gate = StandardMoEGate( + num_experts=self.num_experts, + expert_hidden_size=self.hidden_size, + drop_tokens=self.drop_tokens, + topk_method=self.topk_method, + num_experts_per_tok=self.num_experts_per_tok, + norm_topk_prob=self.norm_topk_prob, + moe_config=moe_config, + seq_length=self.seq_length, + ) + + if self.expert_class is None: + self.expert_class = StandardMLPExpert + + routed_expert_pretrained_config 
= deepcopy(pretrained_config) + shared_expert_pretrained_config = deepcopy(pretrained_config) + if self.expert_parallel_degree <= 1 and self.sequence_parallel and self.tensor_parallel_degree > 1: + routed_expert_pretrained_config.sequence_parallel = False + shared_expert_pretrained_config.sequence_parallel = False + elif self.expert_parallel_degree > 1 and self.tensor_parallel_degree >= 1: + routed_expert_pretrained_config.tensor_parallel_degree = 1 + + expert_args = {} + expert_args["config"] = routed_expert_pretrained_config + expert_args["intermediate_size"] = self.moe_intermediate_size + # Add more arguments for different models + if self.model_type == "qwen3_moe": + pass + elif self.model_type == "glm4_moe": + pass + self.experts = nn.LayerList([self.expert_class(**expert_args) for _ in range(self.num_experts)]) + + if self.expert_parallel_degree > 1: + self.token_dispatcher = MoEFlexTokenDispatcher( + self.num_experts_per_device, self.num_experts_per_tok, self.num_experts, self.moe_group + ) + else: + self.token_dispatcher = None + + shared_expert_args = {} + shared_expert_args["config"] = shared_expert_pretrained_config + shared_expert_args["intermediate_size"] = self.moe_intermediate_size * self.num_shared_experts + if self.num_shared_experts > 0: + self.shared_experts = self.expert_class(**shared_expert_args) + else: + self.shared_experts = None + + if self.custom_communication is not None: + self.communication = self.custom_communication + else: + if self.ep_communication_type == "deepep": + self.communication = DeepEPMoECommunication() + elif self.ep_communication_type == "alltoall": + self.communication = AllToAllMoECommunication() + else: + raise ValueError( + f"Unsupported communication type: {self.ep_communication_type}, please choose from ['deepep', 'alltoall']" + ) + + if hasattr(dist, "fleet") and dist.is_initialized() and self.expert_parallel_degree > 1: + self.is_mp_moe = False + self.is_ep_moe = True + for p in self.experts.parameters(): + setattr(p, "is_moe_param", True) + setattr(p, "color", {"color": "moe_expert", "group": self.moe_grad_group}) + p.no_sync = not self.is_mp_moe + p.expert = not self.is_mp_moe + logger.info(f"expert no-sync={p.no_sync}-{p.name}") + if self.is_mp_moe or self.is_ep_moe: + p.is_distributed = True + + def _init_expert_parallel(self): + def _parse_moe_expert_parallel(num_experts: int, expert_parallel_degree: int) -> int: + """ + Args: + num_experts: Total number of experts + expert_parallel_degree: Expert parallel groups + + Returns: + moe_num_experts_per_device: Number of experts per device + """ + assert ( + num_experts >= expert_parallel_degree + ), f"expert num_experts={num_experts} >= moe_world_size={expert_parallel_degree}" + assert ( + num_experts % expert_parallel_degree == 0 + ), f"expert num_experts={num_experts} % moe_world_size={expert_parallel_degree} == 0" + + moe_num_experts_per_device = num_experts // expert_parallel_degree + return moe_num_experts_per_device + + try: + dist.fleet.get_hybrid_communicate_group() + is_fleet_init = True + except AttributeError: + is_fleet_init = False + + if is_fleet_init and self.expert_parallel_degree > 1: + self.moe_group = dist.fleet.get_hybrid_communicate_group().get_expert_parallel_group() + self.moe_grad_group = dist.fleet.get_hybrid_communicate_group().get_moe_sharding_parallel_group() + self.moe_rank = dist.get_rank(self.moe_group) + self.moe_rank = 0 if self.moe_rank < 0 else self.moe_rank + new_expert_parallel_degree = dist.get_world_size(self.moe_group) + assert ( + 
self.expert_parallel_degree == new_expert_parallel_degree + ), f"self.expert_parallel_degree={self.expert_parallel_degree} != moe_world_size={new_expert_parallel_degree}" + self.expert_parallel_degree = 1 if new_expert_parallel_degree < 0 else new_expert_parallel_degree + self.num_experts_per_device = _parse_moe_expert_parallel(self.num_experts, self.expert_parallel_degree) + else: + self.moe_group = None + self.moe_rank = 0 + self.expert_parallel_degree = 1 + self.num_experts_per_device = self.num_experts + + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + """ + Args: + hidden_states: Shape: [batch_size, seq_len, hidden_size] + + Returns: + output: Shape: [batch_size, seq_len, hidden_size] + """ + if self.expert_parallel_degree <= 1 and self.sequence_parallel: + hidden_states = GatherOp.apply(hidden_states) + orig_shape = hidden_states.shape + residuals = hidden_states + capacity, topk_weights, topk_indices, gates_masked, mask, priorities, aux_loss, z_loss = self.gate( + hidden_states + ) + + if self.expert_parallel_degree > 1: + output = self._forward_with_ep_parallel( + hidden_states, topk_indices, topk_weights, gates_masked, mask, priorities + ) + else: + if len(hidden_states.shape) == 3: + batch_size, seq_len, d_model = hidden_states.shape + reshaped_input = hidden_states.reshape([-1, d_model]) + else: + reshaped_input = hidden_states + output = self._forward_traditional_moe(reshaped_input, topk_indices, topk_weights) + + output = output.reshape(orig_shape) + + if self.shared_experts is not None: + shared_output = self.shared_experts(residuals) + output = output + shared_output + + if self.expert_parallel_degree <= 1 and self.sequence_parallel: + output = ScatterOp.apply(output) + + return output, aux_loss + + def _forward_traditional_moe( + self, hidden_states: paddle.Tensor, selected_experts: paddle.Tensor, topk_weights: paddle.Tensor + ) -> paddle.Tensor: + """ + Forward without expert parallelism + + Args: + hidden_states: Input hidden states, shape: [batch_size*seq_len, hidden_size] + selected_experts: TopK experts indices, shape: [seq_len, num_experts_per_tok] + topk_weights: TopK weights, shape: [seq_len, num_experts_per_tok] + + Returns: + output: Output hidden states, shape: [seq_len, hidden_size] + """ + + _, d_model = hidden_states.shape + final_hidden_states = paddle.zeros_like(hidden_states, dtype=hidden_states.dtype) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = paddle.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).transpose([2, 1, 0]) + tokens_per_expert = expert_mask.reshape([expert_mask.shape[0], -1]).sum(axis=-1) + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + top_x, idx = paddle.where(expert_mask[expert_idx]) + # Index the correct hidden states and compute the expert hidden state for + # the current expert. 
We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (the selected top-k experts)
+            if tokens_per_expert[expert_idx] <= 0.1:
+                continue
+            current_state = hidden_states[idx, None].reshape([-1, d_model])
+            current_hidden_states = expert_layer(current_state) * topk_weights[idx, top_x].unsqueeze(-1)
+            final_hidden_states.index_add_(
+                index=idx.reshape([-1]), axis=0, value=current_hidden_states.to(hidden_states.dtype)
+            )
+
+        return final_hidden_states.cast(hidden_states.dtype)
+
+    def _forward_with_ep_parallel(
+        self,
+        hidden_states: paddle.Tensor,
+        topk_indices: paddle.Tensor,
+        topk_weights: paddle.Tensor,
+        gates_masked: paddle.Tensor,
+        mask: paddle.Tensor,
+        priorities: paddle.Tensor,
+    ) -> paddle.Tensor:
+        """
+        Forward with expert parallelism
+
+        Args:
+            hidden_states: Input hidden states, shape: [seq_len, hidden_size]
+            topk_indices: TopK experts indices, shape: [seq_len, num_experts_per_token]
+            topk_weights: TopK weights, shape: [seq_len, num_experts_per_token]
+            gates_masked: Masked gate values, shape: [seq_len, num_experts]
+            mask: One-hot encoding of the selected experts for each token, shape: [seq_len, num_experts]
+            priorities: Token priorities, shape: [seq_len, num_experts]
+
+        Returns:
+            output: Output hidden states, shape: [seq_len, hidden_size]
+        """
+        output = self.communication.forward(
+            hidden_states,
+            topk_indices,
+            topk_weights,
+            gates_masked,
+            mask,
+            priorities,
+            self.expert_parallel_degree,
+            self.moe_group,
+            self.experts,
+            self.moe_rank,
+            self.num_experts_per_device,
+            self.num_experts,
+            self.num_experts_per_tok,
+            self.token_dispatcher,
+        )
+        return output
+
+    def get_auxiliary_loss(self) -> paddle.Tensor:
+        return self.gate.get_auxiliary_loss()
+
+    def get_z_loss(self) -> paddle.Tensor:
+        return self.gate.get_z_loss()
+
+    def get_all_losses(self) -> Dict[str, paddle.Tensor]:
+        if hasattr(self.gate, "get_all_losses"):
+            return self.gate.get_all_losses()
+        else:
+            return {"auxiliary": self.get_auxiliary_loss(), "z_loss": self.get_z_loss()}
+
+    def get_total_loss(self) -> paddle.Tensor:
+        if hasattr(self.gate, "get_total_loss"):
+            return self.gate.get_total_loss()
+        else:
+            return self.get_auxiliary_loss() + self.get_z_loss()
+
+    def remove_loss_function(self, name: str):
+        if not self.use_flexible_loss:
+            logger.warning("`use_flexible_loss` is not enabled; cannot remove custom losses")
+            return
+
+        if hasattr(self.gate, "remove_loss_config"):
+            self.gate.remove_loss_config(name)
+        else:
+            logger.warning("Gate does not implement `remove_loss_config`; cannot remove custom losses")
+
+    def update_loss_weights(self, weights: Dict[str, float]):
+        if not self.use_flexible_loss:
+            logger.warning("`use_flexible_loss` is not enabled; cannot update loss weights")
+            return
+
+        if hasattr(self.gate, "update_loss_weights"):
+            self.gate.update_loss_weights(weights)
+        else:
+            logger.warning("Gate does not implement `update_loss_weights`; cannot update loss weights")
+
+    def set_loss_combiner(self, combiner_name: str):
+        if not self.use_flexible_loss:
+            logger.warning("`use_flexible_loss` is not enabled; cannot set loss combiner")
+            return
+
+        if hasattr(self.gate, "set_loss_combiner"):
+            self.gate.set_loss_combiner(combiner_name)
+        else:
+            logger.warning("Gate does not implement `set_loss_combiner`; cannot set loss combiner")
+
+    def get_expert_info(self) -> Dict[str, Any]:
+        return {
+            "num_experts": self.num_experts,
+            "num_experts_per_device": self.num_experts_per_device,
+            "expert_parallel_degree": self.expert_parallel_degree,
+            "moe_rank": self.moe_rank,
"is_parallel_enabled": self.expert_parallel_degree > 1, + "use_flexible_loss": self.use_flexible_loss, + } diff --git a/paddleformers/nn/moe_deepep/moe_communication.py b/paddleformers/nn/moe_deepep/moe_communication.py new file mode 100644 index 00000000000..995f0da2ebb --- /dev/null +++ b/paddleformers/nn/moe_deepep/moe_communication.py @@ -0,0 +1,276 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Any, List, Tuple + +import numpy as np +import paddle +import paddle.distributed as dist +from paddle import Tensor, nn +from paddle.distributed.communication.group import Group + + +class MoECommunicationInterface(ABC): + @abstractmethod + def forward( + self, + hidden_states: paddle.Tensor, + topk_indices: paddle.Tensor, + topk_weights: paddle.Tensor, + gates_masked: paddle.Tensor, + mask: paddle.Tensor, + priorities: paddle.Tensor, + expert_parallel_degree: int, + moe_group: Group, + experts: nn.LayerList, + moe_rank: int, + num_experts_per_device: int, + num_experts: int, + topk: int, + token_dispatcher, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """ + Args: + hidden_states: Input hidden states, shape: [batch_size*seq_len, hidden_size] or [batch_size, seq_len, hidden_size] + topk_indices: Indices of selected experts for each token, shape: [num_tokens, num_experts_per_token] + topk_weights: Weights of selected experts for each token, shape: [num_tokens, num_experts_per_token] + gates_masked: Masked gates. For each token(row), the selected experts are remainded with their normalized gate values, others are 0. Shape: [num_tokens, num_experts] + mask: Mask. For each token(row), the selected experts are marked with 1, others are 0. Shape: [num_tokens, num_experts] + priorities: Token priorities, shape: [num_tokens, num_experts] + expert_parallel_degree: Expert parallel degree + moe_group: MoE group + experts: Experts list + moe_rank: Current rank id in the MoE group + num_experts_per_device: Number of experts per device + num_experts: Total number of experts + topk: Number of experts per token + token_dispatcher: Token dispatcher + + Returns: + output: Output tensor + aux_loss: Auxiliary loss + z_loss: Z loss + """ + pass + + +class AllToAllMoECommunication(nn.Layer, MoECommunicationInterface): + """ + All-to-All EP + """ + + def forward( + self, + hidden_states: paddle.Tensor, + topk_indices: paddle.Tensor, + topk_weights: paddle.Tensor, + gates_masked: paddle.Tensor, + mask: paddle.Tensor, + priorities: paddle.Tensor, + expert_parallel_degree: int, + moe_group: Group, + experts: nn.LayerList, + moe_rank: int, + num_experts_per_device: int, + num_experts: int, + topk: int, + token_dispatcher, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + + if expert_parallel_degree <= 1: + return hidden_states + + # 1. 
Reshape topk_indices to a single list of all expert assignments + # Shape: [T * K] + flat_expert_indices = paddle.flatten(topk_indices) + + tokens_per_expert = paddle.bincount(x=flat_expert_indices, minlength=num_experts) + tokens_per_expert = tokens_per_expert.detach() + + idxs = topk_indices.reshape([topk_indices.shape[0] * topk_indices.shape[1]]).argsort() + sorted_tokens = hidden_states[idxs // topk_indices.shape[1]] + sorted_tokens_shape = sorted_tokens.shape + + tokens_per_ep_rank = tokens_per_expert.reshape([expert_parallel_degree, -1]).sum(axis=1) + tokens_per_expert_group = _AllToAll.apply([tokens_per_expert.shape[0]], tokens_per_expert, group=moe_group) + + tokens_per_expert_group_sum = tokens_per_expert_group.reshape([expert_parallel_degree, -1]) + output_splits = tokens_per_expert_group_sum.sum(axis=1).cpu().tolist() + input_split_sizes = tokens_per_ep_rank.cpu().tolist() + gathered_tokens = _AllToAll.apply( + [tokens_per_expert_group.sum(axis=0).cpu().item(), sorted_tokens.shape[1]], + sorted_tokens, + out_split_sizes=output_splits, + in_split_sizes=input_split_sizes, + group=moe_group, + ) + + tokens_per_expert_post_gather = tokens_per_expert_group.reshape( + [expert_parallel_degree, num_experts_per_device] + ).sum(axis=0) + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s : s + k] = i % num_experts_per_device + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = experts[i + moe_rank * num_experts_per_device] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + outs = paddle.concat(outputs, axis=0) if len(outputs) > 0 else paddle.to_tensor(0, dtype=sorted_tokens.dtype) + + new_x = paddle.empty_like(outs) + new_x[gatherd_idxs] = outs + + gathered_tokens = _AllToAll.apply( + sorted_tokens_shape, + new_x, + out_split_sizes=input_split_sizes, + in_split_sizes=output_splits, + group=moe_group, + ) + outs = gathered_tokens + + new_x = paddle.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.reshape(topk_indices.shape + [-1]) + .astype(topk_weights.dtype) + .multiply_(topk_weights.unsqueeze(-1)) + .sum(axis=1) + .astype(new_x.dtype) + ) + + return final_out + + +class DeepEPMoECommunication(nn.Layer, MoECommunicationInterface): + """ + DeepEP EP + """ + + def expert_forward(self, dispatched_input, tokens_per_expert, experts, moe_rank, num_experts_per_device): + outputs = [] + tokens_per_expert = ( + tokens_per_expert.tolist() if not isinstance(tokens_per_expert, list) else tokens_per_expert + ) + chunks = paddle.split(dispatched_input, num_or_sections=tokens_per_expert, axis=0) + for i, chunk in enumerate(chunks): + chunk = chunk.contiguous() + current_expert_idx = i + moe_rank * num_experts_per_device + expert = experts[current_expert_idx] + outputs += [expert(chunk)] + + return paddle.concat(outputs, axis=0) + + def forward( + self, + hidden_states: paddle.Tensor, + topk_indices: paddle.Tensor, + topk_weights: paddle.Tensor, + gates_masked: paddle.Tensor, + mask: paddle.Tensor, + priorities: paddle.Tensor, + expert_parallel_degree: int, + moe_group: Group, + experts: nn.LayerList, + moe_rank: int, 
+ num_experts_per_device: int, + num_experts: int, + topk: int, + token_dispatcher, + ) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: + if expert_parallel_degree <= 1: + return hidden_states + (dispatched_input, tokens_per_expert) = token_dispatcher.token_permutation( + hidden_states, + gates_masked, + mask, + ) + expert_output = self.expert_forward( + dispatched_input, tokens_per_expert, experts, moe_rank, num_experts_per_device + ) + output, _ = token_dispatcher.token_unpermutation(expert_output, None) + return output + + +class _AllToAll(paddle.autograd.PyLayer): + @staticmethod + def forward( + ctx: Any, + output_shape: List, + input: Tensor, + out_split_sizes: List = None, + in_split_sizes: List = None, + group: Group = None, + ) -> Tensor: + """ + All-to-all communication in the group. + Args: + ctx (Any): Context object. + output_shape (List): Output shape. + input (Tensor): Input tensor. + out_split_sizes (List): Output split sizes. + in_split_sizes (List): Input split sizes. + group (Group): The group object. + Returns: + Tensor: Output tensor. + """ + + ctx.group = group + ctx.input_shape = input.shape + ctx.out_split_sizes = out_split_sizes + ctx.in_split_sizes = in_split_sizes + + # return input + if dist.get_world_size(group) <= 1: + return input + + output = paddle.empty(output_shape, dtype=input.dtype) + task = dist.alltoall_single( + output, + input, + out_split_sizes=out_split_sizes, + in_split_sizes=in_split_sizes, + sync_op=False, + group=group, + ) + task.wait() + + return output + + @staticmethod + def backward(ctx: Any, *grad_output: Tensor) -> Tuple[Tensor]: + """ + Aggregates gradient information from all input tensors into a single tensor. + Args: + ctx (Any): The context object used to store information that needs to be passed. + *grad_output (Tensor): A list of input tensors whose gradients are to be aggregated. + Returns: + Tuple[Tensor]: A tuple containing a tensor that holds the gradients of all input tensors. + """ + # return grad_output + return _AllToAll.apply(ctx.input_shape, *grad_output, ctx.in_split_sizes, ctx.out_split_sizes, ctx.group) diff --git a/paddleformers/nn/moe_deepep/moe_expert.py b/paddleformers/nn/moe_deepep/moe_expert.py new file mode 100644 index 00000000000..02bcbd19f3c --- /dev/null +++ b/paddleformers/nn/moe_deepep/moe_expert.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
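Before the first `_AllToAll` exchange, `AllToAllMoECommunication.forward` has to regroup token copies by expert id; the per-expert counts it derives are exactly the split sizes passed to `alltoall_single` on both sides. The following toy, single-process walk-through of that index bookkeeping uses illustrative numbers and involves no communication:

```python
import numpy as np

# 4 tokens routed to 4 experts with top-2 gating.
topk_indices = np.array([[0, 2],
                         [1, 2],
                         [0, 3],
                         [2, 3]])
num_experts, topk = 4, topk_indices.shape[1]

flat_expert_indices = topk_indices.reshape(-1)                          # [T * K]
tokens_per_expert = np.bincount(flat_expert_indices, minlength=num_experts)
# -> [2, 1, 3, 2]: expert 0 receives 2 token copies, expert 1 receives 1, ...

idxs = flat_expert_indices.argsort()   # slot order, grouped by expert id
token_rows = idxs // topk              # which token supplies each slot
# token_rows -> [0, 2, 1, 0, 1, 3, 2, 3]: copies for expert 0 first, then experts 1, 2, 3.
# `sorted_tokens = hidden_states[idxs // topk]` gathers rows in exactly this order,
# and `tokens_per_expert` provides the in/out split sizes exchanged across ranks.
print(tokens_per_expert, token_rows)
```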
+ +from abc import ABC, abstractmethod + +import paddle + +from ...nn.mlp import MLP +from ...transformers.configuration_utils import PretrainedConfig + + +class MoEExpertInterface(ABC): + @abstractmethod + def forward(self, hidden_states: paddle.Tensor) -> paddle.Tensor: + """ + Args: + hidden_states: Input hidden states + + Returns: + output: Output hidden states + """ + pass + + +class StandardMLPExpert(MLP): + def __init__( + self, + config: PretrainedConfig, + moe_intermediate_size: int, + ): + super().__init__(config=config, intermediate_size=moe_intermediate_size) diff --git a/paddleformers/nn/moe_deepep/moe_factory.py b/paddleformers/nn/moe_deepep/moe_factory.py new file mode 100644 index 00000000000..be3163b486f --- /dev/null +++ b/paddleformers/nn/moe_deepep/moe_factory.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ...transformers.configuration_utils import PretrainedConfig +from .modular_moe_layer import ModularMoELayer + + +class QuickAccessMoEFactory: + @staticmethod + def create_from_model_name( + pretrained_config: PretrainedConfig, + expert_class, + gate_activation: str, + expert_activation: str, + train_topk_method: str, + inference_topk_method: str, + drop_tokens: bool, + ) -> ModularMoELayer: + model_type = getattr(pretrained_config, "model_type", None) + if model_type is None: + raise ValueError("Cannot determine model type from pretrained_config") + + moe_config = { + "gate_activation": gate_activation, + "expert_activation": expert_activation, + "train_topk_method": train_topk_method, + "inference_topk_method": inference_topk_method, + "drop_tokens": drop_tokens + # TODO: support aux_loss_weight, z_loss_weight, expert_dropout, use_flexible_loss, loss_configs + } + + return ModularMoELayer( + hidden_size=pretrained_config.hidden_size, + moe_intermediate_size=pretrained_config.moe_intermediate_size, + num_experts=pretrained_config.get( + "num_experts", pretrained_config.get("n_routed_experts", pretrained_config.get("moe_num_experts", -1)) + ), + num_shared_experts=pretrained_config.get( + "n_shared_experts", pretrained_config.get("moe_num_shared_experts", 0) + ), + num_experts_per_tok=pretrained_config.get("num_experts_per_tok", pretrained_config.get("moe_k", -1)), + norm_topk_prob=pretrained_config.get("norm_topk_prob", True), + expert_activation=pretrained_config.get("hidden_act", pretrained_config.get("expert_activation", "silu")), + moe_config=moe_config, + model_type=model_type, + expert_class=expert_class, + pretrained_config=pretrained_config, + ) + + +__all__ = ["QuickAccessMoEFactory"] diff --git a/paddleformers/nn/moe_deepep/moe_gate.py b/paddleformers/nn/moe_deepep/moe_gate.py new file mode 100644 index 00000000000..d77cb19bd12 --- /dev/null +++ b/paddleformers/nn/moe_deepep/moe_gate.py @@ -0,0 +1,521 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) Microsoft Corporation. +# Copyright (c) Facebook, Inc. and its affiliates. 
All rights reserved. +# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from typing import Dict, Tuple + +import paddle +import paddle.distributed as dist +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.distributed.fleet.utils.sequence_parallel_utils import AllGatherOp + +from ...utils.log import logger + + +class MoEGateMixin: + def gate_score_func(self, logits: paddle.Tensor) -> paddle.Tensor: + with paddle.amp.auto_cast(False): + # [..., hidden_dim] -> [..., num_experts] + scoring_func = getattr(self, "scoring_func", None) + if scoring_func == "softmax": + scores = F.softmax(logits, axis=-1) + elif scoring_func == "sigmoid": + scores = F.sigmoid(logits) + elif scoring_func == "tanh": + scores = F.tanh(logits) + elif scoring_func == "relu": + scores = F.relu(logits) + elif scoring_func == "gelu": + scores = F.gelu(logits) + elif scoring_func == "leaky_relu": + scores = F.leaky_relu(logits) + else: + logger.warning_once( + f"insupportable scoring function for MoE gating: {scoring_func}, use softmax instead" + ) + scores = F.softmax(logits, axis=-1) + return scores + + def gumbel_rsample(self, logits: paddle.Tensor) -> paddle.Tensor: + gumbel = paddle.distribution.gumbel.Gumbel(0, 1) + return gumbel.rsample(logits.shape) + + def uniform_sample(self, logits: paddle.Tensor) -> paddle.Tensor: + uniform = paddle.distribution.uniform.Uniform(0, 1) + return uniform.sample(logits.shape) + + @paddle.no_grad() + def _one_hot_to_float(self, x, num_classes): + if x.dtype not in (paddle.int32, paddle.int64): + x = paddle.cast(x, paddle.int64) + return F.one_hot(x, num_classes=num_classes).cast(paddle.get_default_dtype()) + + @paddle.no_grad() + def _one_hot_to_int64(self, x, num_classes): + if x.dtype not in (paddle.int32, paddle.int64): + x = paddle.cast(x, paddle.int64) + return F.one_hot(x, num_classes=num_classes).cast(paddle.int64) + + @paddle.no_grad() + def _capacity( + self, + gates: paddle.Tensor, + capacity_factor: float, + max_capacity: int, + min_capacity: int, + ) -> paddle.Tensor: + """Calculate the capacity for each expert based on the gates and capacity factor. + + Args: + gates (paddle.Tensor): A tensor of shape [num_tokens, num_experts] representing the probability distribution + over experts for each token. + capacity_factor (float): A scalar float value representing the capacity factor for each expert. + min_capacity (int): A scalar integer value representing the minimum capacity for each expert. + + Returns: + int: A tensor value representing the calculated capacity for each expert. 
+ """ + assert gates.ndim == 2, f"gates should be 2D, but got {gates.ndim}, {gates.shape}" + # gates has shape of SE + num_tokens = gates.shape[0] + num_experts = gates.shape[1] + capacity = int((num_tokens // num_experts) * capacity_factor) + if capacity < min_capacity: + capacity = min_capacity + if capacity > max_capacity: + capacity = max_capacity + assert capacity > 0, f"requires capacity > 0, capacity_factor: {capacity_factor}, input_shape: {gates.shape}" + + return capacity + + def _cal_aux_loss(self, gates, mask): + """ + Calculate auxiliary loss + + Args: + gates (paddle.Tensor): Represents the output probability of each expert. The shape is [batch_size, num_experts] + mask (paddle.Tensor): Represents whether each sample belongs to a certain expert. The shape is [batch_size, num_experts] + + Returns: + paddle.Tensor: The value of auxiliary loss. + + """ + # TODO: @DrownFish19 update aux_loss for Qwen2MoE and DeepSeekV2&V3 + me = paddle.mean(gates, axis=0) + ce = paddle.mean(mask.cast("float32"), axis=0) + if self.global_aux_loss: + me_list, ce_list = [], [] + dist.all_gather(me_list, me, group=self.group) + dist.all_gather(ce_list, ce, group=self.group) + + me_list[self.rank] = me + ce_list[self.rank] = ce + me = paddle.stack(me_list).mean(0) + ce = paddle.stack(ce_list).mean(0) + aux_loss = paddle.sum(me * ce) * float(self.num_experts) + return aux_loss + + def _cal_seq_aux_loss(self, probs, top_k, routing_map, seq_length): + max_seq_len = seq_length + + sub_max_seq_len = max_seq_len + if hasattr(self, "moe_subbatch_token_num") and self.moe_subbatch_token_num > 0: + sub_max_seq_len = self.moe_subbatch_token_num * self.tensor_parallel_degree + + # all_probs and routing_map should be computed using the runtime local sequence length on each worker. + if self.tensor_parallel_degree > 1: + assert self.sequence_parallel and max_seq_len % self.tensor_parallel_degree == 0 + local_seq_len = sub_max_seq_len // self.tensor_parallel_degree + # [B*S, E] + all_probs = AllGatherOp.apply(probs) + # [B, S, E] + all_probs = all_probs.reshape([-1, sub_max_seq_len, self.num_experts]) + batch_size = all_probs.shape[0] + # [B, S, E] + routing_map = routing_map.reshape([batch_size, local_seq_len, -1]) + else: + # [B, S, E] + all_probs = probs + batch_size, local_seq_len, _ = probs.shape + routing_map = routing_map.reshape([batch_size, local_seq_len, -1]) + + seq_axis = 1 + # Both cost_coeff and seq_aux_loss must be computed with the global sequence length visible to all workers. + # [B, E] + cost_coeff = routing_map.sum(axis=seq_axis, dtype="float32") / paddle.to_tensor( + max_seq_len * top_k / self.num_experts, dtype="float32" + ) + # [B, E] -> [B] -> [] + seq_aux_loss = (cost_coeff * all_probs.sum(axis=seq_axis) / max_seq_len).sum(axis=1).mean() + + return seq_aux_loss + + def _cal_z_loss(self, logits) -> paddle.Tensor: + """ + Calculate the z loss. + + Args: + logits (paddle.Tensor): Model output. The shape is [batch_size, num_experts]. + + Returns: + paddle.Tensor: The z loss value. + """ + l_zloss = paddle.logsumexp(logits, axis=1).square().mean() + return l_zloss + + def _cal_orthogonal_loss(self) -> paddle.Tensor: + """Gate weight orthogonal loss. 
+
+        Returns:
+            paddle.Tensor: orthogonal loss
+        """
+        weight = F.normalize(self.weight, axis=0)
+        orthogonal_loss = paddle.mean(paddle.square(paddle.matmul(weight.T, weight) - paddle.eye(self.num_experts)))
+        return orthogonal_loss
+
+    def _priority(self, topk_idx: paddle.Tensor, capacity: int) -> paddle.Tensor:
+        """Position-based capacity mask.
+        A token keeps its slot for an expert only while the cumulative count of tokens assigned to that expert is within capacity.
+
+        This method is used in the Hunyuan model.
+        Args:
+            topk_idx (paddle.Tensor): [batch_size * seq_len, topk]
+
+        Returns:
+            paddle.Tensor: 0/1 keep mask of shape [num_tokens, num_experts]
+        """
+        _, k = topk_idx.shape
+        # Shape: [seq_len * k]
+        chosen_expert = topk_idx.reshape([-1])
+        # Shape: [seq_len * k, num_experts].
+        token_priority = F.one_hot(chosen_expert, self.num_experts).cast(paddle.int32)
+        token_priority = paddle.logical_and(token_priority > 0, token_priority.cumsum(axis=0) <= capacity)
+        # Shape: [seq_len, num_experts].
+        token_priority = token_priority.reshape([-1, k, self.num_experts]).sum(axis=1)
+
+        return (token_priority > 0.0).astype("float32")
+
+    def _probs_drop_policy(
+        self,
+        scores: paddle.Tensor,
+        capacity: int,
+    ) -> paddle.Tensor:
+        """
+        Implements the Probability-based (Probs) drop policy to enforce expert capacity.
+
+        A token is assigned (mask value 1.0) to an expert if:
+        1. It chose that expert (score > 0). (Implicitly handled by input scores).
+        2. Its score for that expert is among the top 'capacity' scores for that expert.
+
+        Args:
+            scores (paddle.Tensor): [num_tokens, num_total_experts].
+                                    This should already contain zeros for non-selected
+                                    experts (i.e., the result of top-K gating).
+            capacity (int): The maximum number of tokens any single expert can handle.
+                            (Caps how many tokens each expert keeps.)
+
+        Returns:
+            paddle.Tensor: [num_tokens, num_total_experts] boolean mask (converted to float).
+                           1.0 = Assigned and within capacity. 0.0 = Dropped or unassigned.
+        """
+        num_tokens, num_experts = scores.shape
+
+        # --- Step 1: Find the 'capacity' best tokens for *each* expert ---
+
+        # Use paddle.topk along axis=0 (the token dimension) to find the indices
+        # of the tokens that have the highest scores for each expert (column).
+        # Since 'scores' has shape [Tokens, Experts], axis=0 returns the token indices.
+
+        # topk_token_indices has shape [capacity, num_total_experts]
+        # It tells us WHICH tokens (row indices) are prioritized by capacity.
+
+        # We use min(num_tokens, capacity) just in case there are fewer tokens than capacity.
+        k_to_use = min(num_tokens, capacity)
+
+        # We only care about the indices of the selected tokens
+        _, topk_token_indices = paddle.topk(
+            scores, k=k_to_use, axis=0, sorted=True  # Sorted=True is usually faster, but we only use the indices.
+        )
+
+        # --- Step 2: Create the final assignment mask using scatter ---
+
+        # Initialize the mask to all zeros (tokens are initially dropped/unassigned).
+        # We use boolean type for efficient scattering, then convert to float later.
+        final_mask = paddle.zeros([num_tokens, num_experts], dtype=paddle.bool)
+
+        # 2a. Create the column indices for the assignment.
+        # We need a tensor of shape [k_to_use, num_experts] where each row is [0, 1, 2, ..., num_experts-1].
+        col_indices = paddle.arange(num_experts).unsqueeze(0).expand_as(topk_token_indices)
+
+        # 2b. Flatten the row (token) and column (expert) indices for advanced indexing.
+        token_indices_flat = topk_token_indices.flatten()
+        col_indices_flat = col_indices.flatten()
+
+        # 2c. Use advanced indexing to set the mask positions to True.
+ # This sets mask[token_index, expert_index] = True for all prioritized tokens. + final_mask[token_indices_flat, col_indices_flat] = True + + # --- Step 3: Ensure only originally selected tokens are kept --- + + # Since paddle.topk can pick up tokens with score 0 if num_tokens < capacity, + # we must ensure that we only keep tokens that had a positive score initially. + # This step implicitly cleans up any spurious assignments made by topk on zero scores. + + token_priority_mask = final_mask.float() * (scores > 0).float() + + return token_priority_mask + + def _topk_greedy(self, scores: paddle.Tensor, k: int) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + """ + topk_weight, topk_idx = paddle.topk(scores, k=k, axis=-1, sorted=True) + + return topk_weight, topk_idx + + def _topk_group_limited_greedy( + self, scores: paddle.Tensor, k: int, n_group: int, topk_group: int + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts in each group + n_groups (int): the number of groups for all experts + topk_group (int): the number of groups selected + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + + Note: the group size is normal greater than the number of k + """ + bsz_seq_len, n_experts = scores.shape + assert n_experts % n_group == 0, "n_experts must be divisible by n_groups" + + group_scores = scores.reshape([0, n_group, -1]).max(axis=-1) # [n, n_group] + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [n, top_k_group] + group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0), axis=-1) # fmt:skip + score_mask = ( + group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1]) + ) # [n, e] + tmp_scores = scores * score_mask # [n, e] + topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=True) + + return topk_weight, topk_idx + + def _topk_noaux_tc( + self, scores: paddle.Tensor, k: int, n_group: int, topk_group: int + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """_summary_ + + Args: + scores (paddle.Tensor): [bsz*seq_len, n_experts] + k (int): select the top k experts in each group + n_groups (int): the number of groups for all experts + topk_group (int): the number of groups selected + + Returns: + Tuple[paddle.Tensor, paddle.Tensor]: topk_weight, topk_idx + topk_weight: [bsz*seq_len, k] + topk_idx: [bsz*seq_len, k] + + Note: the group size is normal greater than the number of k + """ + bsz_seq_len, n_experts = scores.shape + assert n_experts % n_group == 0, "n_experts must be divisible by n_groups" + + assert self.e_score_correction_bias is not None, "e_score_correction_bias is None" + scores_for_choice = scores.reshape([bsz_seq_len, -1]) + self.e_score_correction_bias.detach().unsqueeze(0) + group_scores = ( + scores_for_choice.reshape([bsz_seq_len, self.n_group, -1]).topk(2, axis=-1)[0].sum(axis=-1) + ) # fmt:skip [n, n_group] + group_idx = paddle.topk(group_scores, k=topk_group, axis=-1, sorted=True)[1] # [n, top_k_group] + group_mask = paddle.zeros_like(group_scores).put_along_axis(group_idx, paddle.to_tensor(1.0, dtype="float32"), axis=-1) # 
fmt:skip + score_mask = ( + group_mask.unsqueeze(-1).expand([bsz_seq_len, n_group, n_experts // n_group]).reshape([bsz_seq_len, -1]) + ) # [n, e] + tmp_scores = scores_for_choice * score_mask # [n, e] + topk_weight, topk_idx = paddle.topk(tmp_scores, k=k, axis=-1, sorted=True) + topk_weight = scores.take_along_axis(topk_idx, axis=1) if not self.training else topk_weight + + return topk_weight, topk_idx + + +# Modified from PretrainedMoEGate +class StandardMoEGate(nn.Layer, MoEGateMixin): + def __init__( + self, + num_experts: int, + expert_hidden_size: int, + drop_tokens: bool, + topk_method: str, + num_experts_per_tok: int, + norm_topk_prob: bool, + moe_config: Dict, + seq_length: int, + ): + super(StandardMoEGate, self).__init__() + + self.num_experts = num_experts + self.expert_hidden_size = expert_hidden_size + self.drop_tokens = drop_tokens + self.topk_method = topk_method + self.num_experts_per_tok = num_experts_per_tok + self.norm_topk_prob = norm_topk_prob + # force keep in float32 when using amp + self._cast_to_low_precision = False + self.seq_length = seq_length + + self.scoring_func = moe_config.get("scoring_func", "softmax") + self.capacity_factor = moe_config.get("capacity_factor", 1.0) + self.eval_capacity_factor = moe_config.get("eval_capacity_factor", 1.0) + self.min_capacity = moe_config.get("min_capacity", 1) + self.max_capacity = moe_config.get("max_capacity", pow(2, 32)) + self.group = moe_config.get("group", None) + self.global_aux_loss = moe_config.get("global_aux_loss", False) + self.use_rts = moe_config.get("use_rts", True) + self.top2_2nd_expert_sampling = moe_config.get("top2_2nd_expert_sampling", True) + self.drop_policy = moe_config.get("drop_policy", "probs") + self.n_group = moe_config.get("n_group", 1) # for group_limited_greedy + self.topk_group = moe_config.get("topk_group", 1) # for group_limited_greedy + self.routed_scaling_factor = moe_config.get("routed_scaling_factor", 1.0) + self.seq_aux = moe_config.get("seq_aux", False) + + if self.global_aux_loss: + assert self.group is not None, "group is required when global_aux_loss is True" + self.rank = dist.get_rank(self.group) + + self.weight = paddle.create_parameter( + shape=[self.expert_hidden_size, self.num_experts], + dtype="float32", + default_initializer=paddle.nn.initializer.Uniform(), + ) + + def forward( + self, + gates: paddle.Tensor, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + return self.topkgating(gates) + + def topkgating( + self, + gates: paddle.Tensor, + ) -> Tuple[int, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + """Implements TopKGating on logits.""" + + if len(gates.shape) == 3: + batch_size, seq_len, d_model = gates.shape + gates = gates.reshape([-1, d_model]) + elif len(gates.shape) == 2: + batch_size_seq_len, d_model = gates.shape + + gates_ori = gates + + logits = F.linear(gates, self.weight) + + gates = self.gate_score_func(logits=logits) + + l_zloss = self._cal_z_loss(gates) + + if self.topk_method == "greedy": + top_gate, top_idx = self._topk_greedy(gates, k=self.num_experts_per_tok) + elif self.topk_method == "group_limited_greedy": + top_gate, top_idx = self._topk_group_limited_greedy( + gates, k=self.num_experts_per_tok, n_group=self.n_group, topk_group=self.topk_group + ) + elif self.topk_method == "noaux_tc": + top_gate, top_idx = self._topk_noaux_tc( + gates, k=self.num_experts_per_tok, n_group=self.n_group, topk_group=self.topk_group + ) + else: + raise NotImplementedError(f"Invalid 
topk_method: {self.topk_method}") + + # norm gate to sum 1 + if self.num_experts_per_tok > 1 and self.norm_topk_prob: + denominator = top_gate.sum(axis=-1, keepdim=True) + 1e-20 + top_gate = top_gate / denominator + top_gate = top_gate * self.routed_scaling_factor + + mask = paddle.zeros_like(gates).put_along_axis(top_idx, paddle.to_tensor(1.0, dtype=gates.dtype), axis=1) + + if self.seq_aux: + l_aux = self._cal_seq_aux_loss(gates_ori, self.num_experts_per_tok, mask, self.seq_length) + else: + l_aux = self._cal_aux_loss(gates, mask) + + exp_counts = paddle.sum(mask.cast(paddle.int64), axis=0) + + if self.drop_tokens: + # Calculate configured capacity and remove locations outside capacity from mask + capacity = self._capacity( + gates, + self.capacity_factor * self.num_experts_per_tok, + self.max_capacity, + self.min_capacity, + ) + + # update mask and locations by capacity + if self.drop_policy == "probs": + topk_masked_gates = paddle.zeros_like(gates).put_along_axis(top_idx, top_gate, axis=1) + token_priority = self._probs_drop_policy(topk_masked_gates, capacity) + + elif self.drop_policy == "position": + token_priority = self._priority(top_idx, capacity) + else: + raise ValueError(f"Invalid drop_policy: {self.drop_policy}") + else: + # Do not drop tokens - set capacity according to current expert assignments + local_capacity = paddle.max(exp_counts) + if self.group is not None: + dist.all_reduce(local_capacity, op=dist.ReduceOp.MAX, group=self.group) + capacity = int(local_capacity) + token_priority = self._priority(top_idx, capacity) + + # normalize gates + gates_masked = gates * mask + + # if self.training: + gates_s = paddle.sum(gates_masked, axis=-1, keepdim=True) + denom_s = paddle.clip(gates_s, min=paddle.finfo(gates_masked.dtype).eps) + if self.norm_topk_prob: + gates_masked = gates_masked / denom_s + gates_masked = gates_masked.to(gates.dtype) + gates_masked *= self.routed_scaling_factor + + return ( + capacity, # new capacity + top_gate, # weights of selected experts for each token [num_tokens, num_experts_per_token] + top_idx, # indices of selected experts for each token [num_tokens, num_experts_per_token] + gates_masked.to( + paddle.float32 + ), # masked gates. for each token, the selected experts are remainded with their original values, others are 0 [num_tokens, num_experts] + mask, # mask. for each token, the selected experts are marked with 1s [num_tokens, num_experts] + token_priority.take_along_axis(top_idx, axis=-1), # token priority + l_aux, + l_zloss, + ) diff --git a/paddleformers/nn/moe_deepep/moe_loss.py b/paddleformers/nn/moe_deepep/moe_loss.py new file mode 100644 index 00000000000..e1c4a5edf29 --- /dev/null +++ b/paddleformers/nn/moe_deepep/moe_loss.py @@ -0,0 +1,226 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import logging +from dataclasses import dataclass +from enum import Enum +from typing import Any, Dict, List, Optional, Protocol + +import paddle + +logger = logging.getLogger(__name__) + + +class LossType(Enum): + AUXILIARY = "auxiliary" + Z_LOSS = "z_loss" + ENTROPY = "entropy" + SPARSITY = "sparsity" + DIVERSITY = "diversity" + CUSTOM = "custom" + + +@dataclass +class LossConfig: + + name: str + loss_type: LossType + weight: float = 0.0 + enabled: bool = True + params: Dict[str, Any] = None + + def __post_init__(self): + if self.params is None: + self.params = {} + + +class LossFunction(Protocol): + def __call__( + self, + routing_weights: paddle.Tensor, + selected_experts: paddle.Tensor, + gate_logits: Optional[paddle.Tensor] = None, + **kwargs + ) -> paddle.Tensor: + pass + + +class LossCombiner(Protocol): + def __call__(self, losses: Dict[str, paddle.Tensor], configs: Dict[str, LossConfig]) -> paddle.Tensor: + pass + + +class LossRegistry: + def __init__(self): + self._loss_functions: Dict[str, LossFunction] = {} + self._loss_combiners: Dict[str, LossCombiner] = {} + self._register_default_losses() + self._register_default_combiners() + + def _register_default_losses(self): + self.register_loss("auxiliary", self._auxiliary_loss) + self.register_loss("z_loss", self._z_loss) + self.register_loss("entropy", self._entropy_loss) + self.register_loss("sparsity", self._sparsity_loss) + self.register_loss("diversity", self._diversity_loss) + + def _register_default_combiners(self): + self.register_combiner("weighted_sum", self._weighted_sum_combiner) + self.register_combiner("adaptive_sum", self._adaptive_sum_combiner) + self.register_combiner("geometric_mean", self._geometric_mean_combiner) + + def register_loss(self, name: str, loss_func: LossFunction): + self._loss_functions[name] = loss_func + logger.info(f"Registering loss function: {name}") + + def register_combiner(self, name: str, combiner: LossCombiner): + self._loss_combiners[name] = combiner + logger.info(f"Registering loss combiner: {name}") + + def get_loss(self, name: str) -> Optional[LossFunction]: + return self._loss_functions.get(name) + + def get_combiner(self, name: str) -> Optional[LossCombiner]: + return self._loss_combiners.get(name) + + def list_losses(self) -> List[str]: + return list(self._loss_functions.keys()) + + def list_combiners(self) -> List[str]: + return list(self._loss_combiners.keys()) + + def _auxiliary_loss( + self, + routing_weights: paddle.Tensor, + selected_experts: paddle.Tensor, + gate_logits: Optional[paddle.Tensor] = None, + **kwargs + ) -> paddle.Tensor: + num_experts = kwargs.get("num_experts", selected_experts.max().item() + 1) + expert_usage = paddle.zeros([num_experts], dtype=routing_weights.dtype) + + for i in range(selected_experts.shape[0]): + for j in range(selected_experts.shape[1]): + expert_idx = selected_experts[i, j].item() + expert_usage[expert_idx] += routing_weights[i, j] + + expert_usage = expert_usage / selected_experts.shape[0] + aux_loss = paddle.sum(expert_usage * paddle.log(expert_usage + 1e-8)) + return aux_loss + + def _z_loss( + self, + routing_weights: paddle.Tensor, + selected_experts: paddle.Tensor, + gate_logits: Optional[paddle.Tensor] = None, + **kwargs + ) -> paddle.Tensor: + if gate_logits is None: + return paddle.to_tensor(0.0) + return paddle.sum(gate_logits**2) + + def _entropy_loss( + self, + routing_weights: paddle.Tensor, + selected_experts: paddle.Tensor, + gate_logits: Optional[paddle.Tensor] = None, + **kwargs + ) 
-> paddle.Tensor: + """Entropy loss - encourage the diversity of routing weights""" + return -paddle.sum(routing_weights * paddle.log(routing_weights + 1e-8)) + + def _sparsity_loss( + self, + routing_weights: paddle.Tensor, + selected_experts: paddle.Tensor, + gate_logits: Optional[paddle.Tensor] = None, + **kwargs + ) -> paddle.Tensor: + """Sparsety loss - encourage the sparsity of expert selection""" + num_experts = kwargs.get("num_experts", selected_experts.max().item() + 1) + expert_usage = paddle.zeros([num_experts]) + + for i in range(selected_experts.shape[0]): + for j in range(selected_experts.shape[1]): + expert_idx = selected_experts[i, j].item() + expert_usage[expert_idx] += 1 + + return paddle.sum(paddle.abs(expert_usage)) + + def _diversity_loss( + self, + routing_weights: paddle.Tensor, + selected_experts: paddle.Tensor, + gate_logits: Optional[paddle.Tensor] = None, + **kwargs + ) -> paddle.Tensor: + """Diversity loss - encourage the diversity of expert selection""" + num_experts = kwargs.get("num_experts", selected_experts.max().item() + 1) + expert_counts = paddle.zeros([num_experts]) + + for i in range(selected_experts.shape[0]): + for j in range(selected_experts.shape[1]): + expert_idx = selected_experts[i, j].item() + expert_counts[expert_idx] += 1 + + uniform_dist = paddle.ones_like(expert_counts) / expert_counts.shape[0] + diversity_loss = paddle.nn.functional.kl_div( + paddle.log(expert_counts + 1e-8), paddle.log(uniform_dist + 1e-8), reduction="sum" + ) + return diversity_loss + + # 默认损失组合器实现 + def _weighted_sum_combiner( + self, losses: Dict[str, paddle.Tensor], configs: Dict[str, LossConfig] + ) -> paddle.Tensor: + combined_loss = paddle.to_tensor(0.0) + for name, loss_value in losses.items(): + config = configs.get(name) + if config and config.enabled: + combined_loss += config.weight * loss_value + return combined_loss + + def _adaptive_sum_combiner( + self, losses: Dict[str, paddle.Tensor], configs: Dict[str, LossConfig] + ) -> paddle.Tensor: + combined_loss = paddle.to_tensor(0.0) + enabled_losses = [ + loss for name, loss in losses.items() if configs.get(name, LossConfig("", LossType.CUSTOM)).enabled + ] + + if len(enabled_losses) > 1: + loss_std = paddle.std(paddle.stack(enabled_losses)) + else: + loss_std = paddle.to_tensor(1.0) + + adaptation_factor = 0.1 + for name, loss_value in losses.items(): + config = configs.get(name) + if config and config.enabled: + adaptive_weight = config.weight * (1 + adaptation_factor * loss_std) + combined_loss += adaptive_weight * loss_value + + return combined_loss + + def _geometric_mean_combiner( + self, losses: Dict[str, paddle.Tensor], configs: Dict[str, LossConfig] + ) -> paddle.Tensor: + combined_loss = paddle.to_tensor(1.0) + for name, loss_value in losses.items(): + config = configs.get(name) + if config and config.enabled and config.weight > 0: + combined_loss *= (loss_value + 1e-8) ** config.weight + return combined_loss diff --git a/paddleformers/nn/moe_deepep/moe_loss_instance.py b/paddleformers/nn/moe_deepep/moe_loss_instance.py new file mode 100644 index 00000000000..22cb58924af --- /dev/null +++ b/paddleformers/nn/moe_deepep/moe_loss_instance.py @@ -0,0 +1,64 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
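The registry above is meant to let models plug in extra routing losses without touching the gate. A usage sketch follows, assuming the module path from this PR; `my_balance_loss` is a made-up custom loss used purely for illustration:

```python
import paddle

from paddleformers.nn.moe_deepep.moe_loss import LossConfig, LossRegistry, LossType


def my_balance_loss(routing_weights, selected_experts, gate_logits=None, **kwargs):
    # Hypothetical custom loss: penalize high variance in the routing weights.
    return paddle.var(routing_weights)


registry = LossRegistry()
registry.register_loss("my_balance", my_balance_loss)

routing_weights = paddle.rand([8, 2])             # 8 tokens, top-2 routing weights
selected_experts = paddle.randint(0, 4, [8, 2])   # toy expert indices

losses = {
    "auxiliary": registry.get_loss("auxiliary")(routing_weights, selected_experts, num_experts=4),
    "my_balance": registry.get_loss("my_balance")(routing_weights, selected_experts),
}
configs = {
    "auxiliary": LossConfig("auxiliary", LossType.AUXILIARY, weight=0.01),
    "my_balance": LossConfig("my_balance", LossType.CUSTOM, weight=0.001),
}
total = registry.get_combiner("weighted_sum")(losses, configs)  # weighted sum of enabled losses
```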
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional + +import paddle + +from .moe_loss import LossConfig, LossRegistry + + +def get_global_loss_registry(): + if not hasattr(get_global_loss_registry, "_instance"): + get_global_loss_registry._instance = LossRegistry() + get_global_loss_registry._instance.register_loss("custom_diversity_loss1", custom_diversity_loss) + get_global_loss_registry._instance.register_combiner( + "custom_weighted_sum_combiner1", custom_weighted_sum_combiner + ) + return get_global_loss_registry._instance + + +def custom_diversity_loss( + routing_weights: paddle.Tensor, + selected_experts: paddle.Tensor, + gate_logits: Optional[paddle.Tensor] = None, + **kwargs +) -> paddle.Tensor: + num_experts = kwargs.get("num_experts", 8) + expert_counts = paddle.zeros([num_experts]) + + for i in range(selected_experts.shape[0]): + for j in range(selected_experts.shape[1]): + expert_idx = selected_experts[i, j].item() + expert_counts[expert_idx] += 1 + + uniform_dist = paddle.ones_like(expert_counts) / expert_counts.shape[0] + expert_probs = expert_counts / (expert_counts.sum() + 1e-8) + + diversity_loss = paddle.nn.functional.kl_div( + paddle.log(expert_probs + 1e-8), paddle.log(uniform_dist + 1e-8), reduction="sum" + ) + + return diversity_loss + + +def custom_weighted_sum_combiner( + losses: Dict[str, paddle.Tensor], configs: Dict[str, LossConfig] +) -> paddle.Tensor: + combined_loss = paddle.to_tensor(0.0) + for name, loss_value in losses.items(): + config = configs.get(name) + if config and config.enabled: + combined_loss += config.weight * loss_value + return combined_loss diff --git a/paddleformers/transformers/configuration_utils.py b/paddleformers/transformers/configuration_utils.py index c014acde835..954db4fa26d 100644 --- a/paddleformers/transformers/configuration_utils.py +++ b/paddleformers/transformers/configuration_utils.py @@ -303,6 +303,7 @@ class LlmMetaConfig: moe_attributes = [ ("moe_subbatch_token_num", int, 0, "The number of tokens in each subbatch for MoE model processing."), ("using_fake_gate", bool, False, "Whether to fake gate."), + ("ep_communication_type", str, "deepep", 'Communication type used by MoE module: "deepep" or "alltoall".'), ] mtp_attributes = [ @@ -509,7 +510,8 @@ class PretrainedConfig: `"single_label_classification"` or `"multi_label_classification"`. moe_subbatch_token_num (`int`, *optional*, defaults to 0): The number of tokens in a subbatch for MoE. - + ep_communication_type (`str`, *optional*, defaults to `deepep`): + Communication type for expert parallel. Can be one of `deepep`, `alltoall`.
> Parameters for general components _attn_implementation (`str`, defaults to `eager`) @@ -654,6 +656,7 @@ def __init__(self, **kwargs): self.kto_config = kwargs.pop("kto_config", None) self.moe_subbatch_token_num = kwargs.pop("moe_subbatch_token_num", 0) + self.ep_communication_type = kwargs.pop("ep_communication_type", "deepep") self.using_fake_gate = kwargs.pop("using_fake_gate", False) # Tokenizer arguments TODO: eventually tokenizer and models should share the same config diff --git a/paddleformers/transformers/qwen3_moe/modeling.py b/paddleformers/transformers/qwen3_moe/modeling.py index b34759d0aaf..da4433b63ea 100644 --- a/paddleformers/transformers/qwen3_moe/modeling.py +++ b/paddleformers/transformers/qwen3_moe/modeling.py @@ -20,8 +20,10 @@ from typing import List, Optional, Tuple, Union import paddle +import paddle.distributed as dist import paddle.nn.functional as F from paddle import Tensor, nn +from paddle.distributed import fleet from paddle.distributed.fleet.utils import recompute from paddle.distributed.fleet.utils.sequence_parallel_utils import GatherOp, ScatterOp @@ -31,6 +33,7 @@ from ...nn.linear import Linear as GeneralLinear from ...nn.lm_head import LMHead as GeneralLMHead from ...nn.mlp import MLP +from ...nn.moe_deepep.moe_factory import QuickAccessMoEFactory from ...nn.norm import Norm as GeneralNorm from ...nn.pp_model import GeneralModelForCausalLMPipe from ...utils.log import logger @@ -321,10 +324,27 @@ def __init__(self, config: Qwen3MoeConfig, layer_idx: int): self.self_attn = Qwen3MoeAttention(config, layer_idx) + try: + moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group() + except Exception: + moe_group = None + expert_parallel_degree = dist.get_world_size(moe_group) if moe_group is not None else 1 if (layer_idx not in config.mlp_only_layers) and ( config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 ): - self.mlp = Qwen3MoeSparseMoeBlock(config) + self.mlp = ( + QuickAccessMoEFactory.create_from_model_name( + pretrained_config=config, + expert_class=Qwen3MoeMLP, + gate_activation="softmax", + expert_activation="silu", + train_topk_method="greedy", + inference_topk_method="greedy", + drop_tokens=False, + ) + if expert_parallel_degree > 1 + else Qwen3MoeSparseMoeBlock(config) + ) else: # num_experts == 0 or this layer is not sparse layer self.mlp = Qwen3MoeMLP(config) @@ -513,20 +533,46 @@ def make_base_actions(): for k in LAYER_ROWWISE } ) + try: + moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group() + except Exception: + moe_group = None + expert_parallel_degree = dist.get_world_size(moe_group) if moe_group is not None else 1 + # TODO: merge disable_ffn_model_parallel and expert_parallel_degree + if expert_parallel_degree <= 1: + # # if disable_ffn_model_parallel is True, disable expert layer tp plan + # if not config.disable_ffn_model_parallel: + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial( + fn, is_column=True + ) + for e in range(config.num_experts) + for k in EXPERT_LAYER_COLWISE + } + ) + actions.update( + { + f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial( + fn, is_column=False + ) + for e in range(config.num_experts) + for k in EXPERT_LAYER_ROWWISE + } + ) actions.update( { - f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial(fn, is_column=True) - for e in range(config.num_experts) - for k in EXPERT_LAYER_COLWISE + f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.{k}": partial(fn, 
is_column=False) + for k in EXPERT_LAYER_ROWWISE } ) actions.update( { - f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.experts.{e}.{k}": partial(fn, is_column=False) - for e in range(config.num_experts) - for k in EXPERT_LAYER_ROWWISE + f"{cls.base_model_prefix}.layers.{layer_idx}.mlp.{k}": partial(fn, is_column=True) + for k in EXPERT_LAYER_COLWISE } ) + # bias if config.attention_bias: actions.update( @@ -535,7 +581,6 @@ def make_base_actions(): for b in BIAS_KEYS } ) - return actions mappings = make_base_actions()
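
A minimal usage sketch of the loss-registry API added in moe_loss.py and moe_loss_instance.py. The import paths and the sample "peakiness" loss are illustrative assumptions for review purposes, not part of this PR:

    import paddle

    # Assumed import paths; these submodules are added by this PR but are not re-exported in the package __init__.
    from paddleformers.nn.moe_deepep.moe_loss import LossConfig, LossType
    from paddleformers.nn.moe_deepep.moe_loss_instance import get_global_loss_registry

    registry = get_global_loss_registry()

    # Hypothetical custom loss: penalize overly peaked routing distributions.
    def peakiness_loss(routing_weights, selected_experts, gate_logits=None, **kwargs):
        return paddle.mean(paddle.max(routing_weights, axis=-1))

    registry.register_loss("peakiness", peakiness_loss)

    # Toy inputs: 16 tokens routed top-2 over 8 experts.
    routing_weights = paddle.rand([16, 2])
    selected_experts = paddle.randint(0, 8, [16, 2])
    gate_logits = paddle.rand([16, 8])

    configs = {
        "z_loss": LossConfig("z_loss", LossType.Z_LOSS, weight=1e-3),
        "peakiness": LossConfig("peakiness", LossType.CUSTOM, weight=1e-2),
    }
    # Evaluate each registered loss, then combine with the built-in weighted-sum combiner.
    losses = {
        name: registry.get_loss(name)(routing_weights, selected_experts, gate_logits, num_experts=8)
        for name in configs
    }
    combined = registry.get_combiner("weighted_sum")(losses, configs)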
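And a sketch of how the new ep_communication_type option is expected to be consumed. Per the Qwen3MoeDecoderLayer change above, the ModularMoELayer path (via QuickAccessMoEFactory) is only taken when an expert-parallel group with degree > 1 is available; otherwise the layer keeps the existing Qwen3MoeSparseMoeBlock. The import path and field values below are illustrative assumptions:

    # Assumed export; adjust to the actual import path in paddleformers.
    from paddleformers.transformers import Qwen3MoeConfig

    cfg = Qwen3MoeConfig(
        num_experts=8,
        num_experts_per_tok=2,
        moe_intermediate_size=768,
        ep_communication_type="alltoall",  # "deepep" (default) requires the DeepEP dispatcher
    )
    # With a hybrid-parallel fleet initialized and expert_parallel_degree > 1,
    # Qwen3MoeDecoderLayer builds its MoE block via QuickAccessMoEFactory.create_from_model_name(...);
    # on a single expert-parallel rank it falls back to Qwen3MoeSparseMoeBlock(config).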