
Conversation

danleifeng (Contributor) commented Sep 23, 2025

Before submitting

  • Lint code. If there are lint issues, please format the code first.
# Install and register `pre-commit` in the project folder
pip install pre-commit && pre-commit install

# Process previous code files separately
pre-commit run --files XXXX.py
  • Add test cases to the tests folder. If there are codecov issues, please add test cases first.

PR types

PR changes

Description

  1. Add a moe_subbatch_token_num optimization to reduce peak GPU memory; smaller values of moe_subbatch_token_num save more memory.
  2. Add pre_alloc_memory to reduce GPU memory fragmentation.
  3. Fix mlp.gate.weight and input_layernorm.weight being inconsistent across cards before training.

New config options:
"moe_subbatch_token_num": 4096,
"pre_alloc_memory": 60,

The local GLM SFT training changes depend on these PRs:
sp: #2621
fused_loss fix: #2648
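For intuition on the moe_subbatch_token_num knob, here is a minimal, hypothetical sketch (names and structure are illustrative, not the PR's actual code): tokens are pushed through the experts in chunks of at most moe_subbatch_token_num, so peak activation memory scales with the chunk size instead of the full batch, at the cost of more dispatch iterations.

```python
# Illustrative sketch of sub-batched MoE dispatch (hypothetical names).
# Only one chunk's activations are live at a time, which is why smaller
# moe_subbatch_token_num values lower peak memory.

def moe_forward_in_subbatches(tokens, expert_fn, moe_subbatch_token_num):
    """Apply expert_fn to `tokens` in chunks of at most moe_subbatch_token_num."""
    outputs = []
    for start in range(0, len(tokens), moe_subbatch_token_num):
        chunk = tokens[start:start + moe_subbatch_token_num]
        outputs.extend(expert_fn(chunk))  # process one sub-batch at a time
    return outputs

# Toy usage: an "expert" that doubles each token representation.
out = moe_forward_in_subbatches(list(range(10)),
                                lambda c: [2 * t for t in c],
                                moe_subbatch_token_num=4)
```

The result is identical to running the whole batch at once; only the memory/speed trade-off changes.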

codecov-commenter commented Sep 23, 2025

Codecov Report

❌ Patch coverage is 5.88235% with 80 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@804acdd). Learn more about missing BASE report.

Files with missing lines                                  Patch %   Lines
paddleformers/transformers/glm4_moe/modeling.py             4.81%   79 Missing ⚠️
...ddleformers/transformers/glm4_moe/configuration.py       0.00%    1 Missing ⚠️

❌ Your patch status has failed because the patch coverage (5.88%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #2661   +/-   ##
==========================================
  Coverage           ?   29.63%           
==========================================
  Files              ?      311           
  Lines              ?    54682           
  Branches           ?        0           
==========================================
  Hits               ?    16204           
  Misses             ?    38478           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.

@danleifeng danleifeng changed the title 【GLM】add ep subbatch to reduce memory 【GLM】subbatch performance and weight bug fix Sep 24, 2025
saved_signal_path = os.path.join(output_dir, f"saved_signal_{dist.get_rank()}")

- if self.args.unified_checkpoint and self.args.offload_optim:
+ if self.args.unified_checkpoint and (self.args.offload_optim or self.args.tensorwise_offload_optimizer):
Collaborator:

This will cause an abnormal rise in GPU memory when saving the optimizer parameters; writing it this way is not recommended.

Contributor Author:

done

os.environ["USE_CASUAL_MASK"] = "False"


def mock_offload_optimizer():
Collaborator:

unified_checkpoint_config: ignore_merge_optimizer
optim: adamw_custom
tensorwise_offload_optimizer: True

Adding the settings above to the training config also offloads the optimizer and reduces GPU memory. For now, remove the optimizer-related changes from this PR so a first version can be merged.
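For background on why tensorwise_offload_optimizer saves device memory: optimizer states are kept on the host and touched one tensor at a time during the update, so only a single state tensor needs to be resident at once. A minimal NumPy sketch under assumed names (momentum SGD stands in here; this is not PaddleFormers' actual implementation):

```python
import numpy as np

def sgd_step_with_offloaded_momentum(params, grads, momentum_store, lr=0.1, mu=0.9):
    """Momentum SGD where each momentum buffer lives in a host-side store
    and is fetched and written back per tensor, mimicking tensor-wise offload."""
    for name, p in params.items():
        m = momentum_store[name]      # "fetch" only this tensor's state
        m = mu * m + grads[name]      # update momentum
        p -= lr * m                   # apply the update in place
        momentum_store[name] = m      # "offload" the state back to the store
    return params
```

With real device tensors, the fetch/offload steps would be host-device copies; the structure of the loop is what bounds peak device memory.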

Contributor Author:

Removed.

moe_group=moe_group,
)
if hasattr(dist, "fleet") and dist.is_initialized() and expert_parallel_degree > 1:
# for p in self.experts.parameters():
Collaborator:

Delete commented-out code that is no longer needed.

Collaborator:

Check the whole file for commented-out code.

Contributor Author:

done

lugimzzz (Collaborator) left a comment:

LGTM

@lugimzzz lugimzzz merged commit d841351 into PaddlePaddle:develop Sep 24, 2025
4 checks passed