Conversation

@hushenwei2000 (Contributor) commented Sep 28, 2025

PR types

New features

PR changes

Models

Description

This PR adds a general-purpose MoE Layer module with a modular design.
The MoE layer of any model can be replaced with this common module; you only need to configure, in moe_config.py, the activation function, the TopK computation method, and other settings used by that model.
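
As a minimal illustration of the configuration idea (a sketch only: gate_activation, train_topk_method and inference_topk_method appear elsewhere in this PR discussion, while the remaining layout is assumed), an entry in moe_config.py could look like this:

# Hypothetical sketch of a per-model entry in moe_config.py.
MOE_CONFIG = {
    "qwen3_moe": {
        "gate_activation": "softmax",        # activation applied to the router logits
        "train_topk_method": "greedy",       # TopK routing strategy during training
        "inference_topk_method": "greedy",   # TopK routing strategy during inference
    },
}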

Testing

The following configurations of SFT on the standard Qwen3MoE model have been verified:

Model loss is aligned; GSM8K inference accuracy (expected 85.4%) per configuration:

TP4PP2SP: 85.4%
EP4PP2: 85.4%
EP4TP4PP2SP: 88.7%
EP4TP2PP2SP: not yet verified

Remaining work

  • Expert Parallel with AllToAll communication
  • Custom loss system
  • aux_loss_weight, z_loss_weight, and expert_dropout as yaml settings

@codecov-commenter commented Sep 28, 2025

Codecov Report

❌ Patch coverage is 26.51297% with 510 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@d80b68d). Learn more about missing BASE report.

Files with missing lines | Patch % | Missing lines
paddleformers/nn/moe_deepep/moe_gate.py | 14.35% | 185 ⚠️
paddleformers/nn/moe_deepep/modular_moe_layer.py | 15.97% | 163 ⚠️
paddleformers/nn/moe_deepep/moe_communication.py | 22.22% | 70 ⚠️
paddleformers/nn/moe_deepep/moe_loss.py | 49.58% | 61 ⚠️
paddleformers/nn/moe_deepep/moe_loss_instance.py | 40.74% | 16 ⚠️
paddleformers/transformers/qwen3_moe/modeling.py | 52.94% | 8 ⚠️
paddleformers/nn/moe_deepep/moe_factory.py | 54.54% | 5 ⚠️
paddleformers/nn/moe_deepep/moe_expert.py | 81.81% | 2 ⚠️

❌ Your patch status has failed because the patch coverage (26.51%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #2702   +/-   ##
==========================================
  Coverage           ?   28.74%           
==========================================
  Files              ?      343           
  Lines              ?    57154           
  Branches           ?        0           
==========================================
  Hits               ?    16427           
  Misses             ?    40727           
  Partials           ?        0           

@hushenwei2000 hushenwei2000 changed the title [Unified MoE]: Add Unified MoE Interface with EP Support [Unified MoE Layer]: Add Unified MoE Layer with DeepEP EP Support (Test on Qwen3MoE) Oct 31, 2025
@hushenwei2000 hushenwei2000 changed the title [Unified MoE Layer]: Add Unified MoE Layer with DeepEP EP Support (Test on Qwen3MoE) [Unified MoE Layer]: Add MoE Layer with DeepEP EP Support; Add Qwen3MoE EP Oct 31, 2025
@ZeyuChen ZeyuChen requested a review from Copilot November 1, 2025 12:18
Copilot AI left a comment

Pull Request Overview

This PR adds a modular MoE (Mixture of Experts) implementation to PaddlePaddle, integrating it with the Qwen3-MoE model. The changes introduce a flexible, extensible MoE framework with customizable gates, experts, communication strategies, and loss functions.

  • Adds a new modular MoE layer system supporting Expert Parallel (EP) mode
  • Integrates the new MoE implementation with Qwen3-MoE model architecture
  • Implements configurable loss functions and combiners for MoE training
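
A rough usage sketch of how these pieces could fit together (illustrative only; the class name ModularMoELayer is assumed from the module file name listed below, not confirmed as the public API):

# Illustrative composition: the modular layer assembles its gate, experts and
# communication strategy from the model config and routes hidden states through them.
from paddleformers.nn.moe_deepep.modular_moe_layer import ModularMoELayer  # assumed class name

moe_layer = ModularMoELayer(config)   # gate/experts/communication built from config
output = moe_layer(hidden_states)     # tokens routed to the selected experts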

Reviewed Changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 54 comments.

File | Description
paddleformers/transformers/qwen3_moe/modeling.py | Integrates QuickAccessMoEFactory and updates tensor parallel mappings to support fused attention and EP mode
paddleformers/nn/moe_deepep/moe_loss_instance.py | Defines global loss registry instance and custom loss functions
paddleformers/nn/moe_deepep/moe_loss.py | Implements flexible loss system with multiple loss types and combiners
paddleformers/nn/moe_deepep/moe_gate.py | Implements standard and flexible MoE gate mechanisms with routing strategies
paddleformers/nn/moe_deepep/moe_factory.py | Factory pattern for creating MoE layers from model configs
paddleformers/nn/moe_deepep/moe_expert.py | Expert network implementations for MoE layers
paddleformers/nn/moe_deepep/moe_config.py | Configuration dictionary for different MoE model types
paddleformers/nn/moe_deepep/moe_communication.py | Communication strategies for Expert Parallel training
paddleformers/nn/moe_deepep/modular_moe_layer.py | Main modular MoE layer implementation
paddleformers/nn/moe_deepep/__init__.py | Module initialization and lazy imports

from ...nn.linear import Linear as GeneralLinear
from ...nn.lm_head import LMHead as GeneralLMHead
from ...nn.mlp import MLP
from ...nn.moe_deepep.moe_factory import QuickAccessMoEFactory

This comment was marked as abuse.

Comment on lines 215 to 219
    def _probs_drop_policy(
        self,
        scores: torch.Tensor,
        capacity: int,
    ) -> torch.Tensor:
Copilot AI commented Nov 1, 2025

Incorrect type annotation. Should use paddle.Tensor instead of torch.Tensor to match the PaddlePaddle framework being used.

2. Its score for that expert is among the top 'capacity' scores for that expert.
Args:
scores (torch.Tensor): [num_tokens, num_total_experts].
Copilot AI commented Nov 1, 2025

Documentation refers to torch.Tensor but should reference paddle.Tensor to match the framework being used.

(Not strictly used here, but good practice to include).
Returns:
torch.Tensor: [num_tokens, num_total_experts] boolean mask (converted to float).
Copilot AI commented Nov 1, 2025

Return type documentation refers to torch.Tensor but should reference paddle.Tensor to match the framework being used.


# --- Step 1: Find the 'capacity' best tokens for *each* expert ---

# Use torch.topk along dim=0 (the token dimension) to find the indices
Copilot AI commented Nov 1, 2025

Comment incorrectly mentions torch.topk but the actual implementation uses paddle.topk.
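
For readers following this thread, here is a minimal sketch of such a capacity-based drop policy (not the PR's implementation; it only mirrors the behaviour the docstring above describes, using paddle.topk over the token dimension):

import paddle

def probs_drop_policy(scores: paddle.Tensor, capacity: int) -> paddle.Tensor:
    # scores: [num_tokens, num_total_experts]
    k = min(capacity, scores.shape[0])
    # Indices of the top-`capacity` scoring tokens for each expert (topk over the token axis).
    _, top_token_idx = paddle.topk(scores, k=k, axis=0)   # [capacity, num_experts]
    keep_mask = paddle.zeros_like(scores)
    # Mark kept (token, expert) pairs with 1.0; all other assignments are dropped.
    keep_mask = paddle.put_along_axis(keep_mask, top_token_idx, 1.0, axis=0)
    return keep_mask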

def __call__(self, losses: Dict[str, paddle.Tensor], configs: Dict[str, LossConfig]) -> paddle.Tensor:
    """Combine multiple loss functions."""
    ...
Copilot AI commented Nov 1, 2025

This statement has no effect.
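
As an illustration of what such a combiner typically does (a sketch only; the weight field on LossConfig is an assumption, not necessarily the PR's actual attribute):

def weighted_sum_combiner(losses, configs):
    # Multiply each named loss by its configured weight and accumulate the sum.
    total = None
    for name, loss in losses.items():
        weighted = loss * configs[name].weight   # `weight` is an assumed LossConfig field
        total = weighted if total is None else total + weighted
    return total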
lugimzzz (Collaborator) commented Nov 2, 2025

# Install pre-commit to make sure there are no code style issues
pre-commit install

config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0
):
self.mlp = Qwen3MoeSparseMoeBlock(config)
self.mlp = QuickAccessMoEFactory.create_from_model_name(config)
Collaborator:

Don't delete the original Qwen3MoeSparseMoeBlock; choose which class to use based on whether EP is enabled.

hushenwei2000 (Contributor, Author):

Added: the new layer is now used only when expert_parallel_degree > 1.
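
A minimal sketch of that selection (the exact config attribute holding the EP degree is an assumption; both classes appear in the diff above):

if config.expert_parallel_degree > 1:
    # EP enabled: use the unified MoE layer created by the factory.
    self.mlp = QuickAccessMoEFactory.create_from_model_name(config)
else:
    # EP disabled: keep the original sparse MoE block.
    self.mlp = Qwen3MoeSparseMoeBlock(config)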


LAYER_ROWWISE = ["self_attn.o_proj.weight"]

FUSE_LAYER_COLWISE = [
Collaborator:

Why add this self_attn.qkv_proj.weight?

hushenwei2000 (Contributor, Author):

The related content has been removed.

"gate_proj.weight",
]

FUSE_EXPERT_LAYER_COLWISE = [
Collaborator:

Same as above.

"self_attn.v_proj.bias",
]

FUSE_BIAS_KEYS = [
Collaborator:

Same as above.

if expert_parallel_degree <= 1:
# # if disable_ffn_model_parallel is True, disable expert layer tp plan
# if not config.disable_ffn_model_parallel:
if not config.fuse_attention_ffn:
Collaborator:

Why add these when there is no fuse_attention_ffn switch?

hushenwei2000 (Contributor, Author):

The related content has been removed.

elif self.expert_parallel_degree > 1 and self.tensor_parallel_degree >= 1:
routed_expert_pretrained_config.tensor_parallel_degree = 1

# self.experts = nn.LayerList(
Collaborator:

If the code is not needed, just delete it.


self.experts = nn.LayerList(
[
MLP(config=routed_expert_pretrained_config, intermediate_size=pretrained_config.moe_intermediate_size)
Collaborator:

Not every model's MLP can be shared. It is recommended to pass the MLP class in from the model at initialization to give more flexibility.

else:
self.communication = DeepEPMoECommunication()

# self.is_dummy_moe = False if self.expert_parallel_degree > 1 else True
Collaborator:

Delete the commented-out code.

Collaborator:

Pay attention to this throughout the whole PR.

from paddle import nn
from paddle.incubate.nn.functional import swiglu as fused_swiglu

from ...nn.mlp import MLP
lugimzzz (Collaborator) commented Nov 2, 2025

When a single-card model is reproduced, its expert is reproduced as well, so it is recommended to directly reuse the expert from the model network to reduce development effort and future maintenance cost (if the network changes, you would otherwise have to change the expert in the network and then change the expert here again). A standard implementation can be kept, but the recommended approach is still to use the MLP that the model network itself uses.

hushenwei2000 (Contributor, Author):

This has now been changed to reuse the experts from the model network.
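
A hypothetical sketch of that reuse pattern (the expert_class parameter name is an assumption, not the PR's exact interface):

# Inside the modular MoE layer: build the experts from the MLP class the model
# network passes in, so changes to the network carry over to the experts automatically.
self.experts = nn.LayerList(
    [expert_class(config) for _ in range(config.num_experts)]
)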


MOE_CONFIG = {
"qwen3_moe": {
"gate_activation": "softmax",
Collaborator:

Could this avoid maintaining a model_config? Fixed parameters could be passed directly in the model network, and variable parameters could be controlled through some new general training fields in config, e.g.:
QuickAccessMoEFactory.create(
    config=config,
    gate_activation="softmax",
    train_topk_method="greedy",
    inference_topk_method="greedy",
    .....
)

hushenwei2000 (Contributor, Author):

Modified as requested.

lugimzzz (Collaborator) commented Nov 2, 2025

A complete document introducing the FlashMoe module is needed.

Co-authored-by: Copilot <[email protected]>
if self.custom_communication is not None:
    self.communication = self.custom_communication
else:
    if os.getenv("USE_DEEPEP", "1") == "0":
Collaborator:

Don't use an environment variable to control which communication method is used; select it via config or a similar mechanism.

hushenwei2000 (Contributor, Author):

Added an ep_communication_type yaml configuration option.
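
A sketch of what the config-driven selection could look like (only ep_communication_type is named in the reply; the value string and the fallback behaviour are assumptions):

if self.custom_communication is not None:
    self.communication = self.custom_communication
elif config.ep_communication_type == "deepep":
    self.communication = DeepEPMoECommunication()
else:
    raise ValueError(f"Unsupported ep_communication_type: {config.ep_communication_type}")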

@hushenwei2000 (Contributor, Author) commented:

Updated a new version; re-ran the Qwen3MoE EP tests and they passed.

lugimzzz (Collaborator) left a comment

LGTM

@lugimzzz lugimzzz merged commit 89cd9e8 into PaddlePaddle:develop Nov 3, 2025
5 checks passed