
Conversation


@Difers Difers commented Sep 11, 2025

Before submitting

  • Lint code. If there are lint issues, please format the code first.
# Install and register `pre-commit` in the project folder
pip install pre-commit && pre-commit install

# Process previous code files separately
pre-commit run --file XXXX.py
  • Add test cases into the tests folder. If there are codecov issues, please add test cases first.

PR types

New features

PR changes

Models

Description

Migrate the PaddleNLP DeepSeek V3 model to PaddleFormers

This PR

  • Migrates the DeepSeek V3 model from PaddleNLP to PaddleFormers, keeps basic functionality working, and broadly follows the PaddleFormers modeling conventions

Follow-up TODOs

  • Add a single-GPU MoE model; compare against hf transformers to adapt the modeling conventions and align numerical precision with hf transformers
  • Add unit tests for the model
  • Add model configurations and end-to-end usage documentation
  • Add MoE filtering of padding tokens

Related PRs for the newly added features and changes, for reference

Convergence verification at 4K sequence length

[convergence curve screenshot]


paddle-bot bot commented Sep 11, 2025

Thanks for your contribution!

@Difers Difers force-pushed the add_dsv3_from_nlp branch 2 times, most recently from 7cc21a4 to a0539fe on September 11, 2025 08:22

codecov-commenter commented Sep 11, 2025

Codecov Report

❌ Patch coverage is 16.78161% with 362 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@03b533c). Learn more about missing BASE report.

Files with missing lines Patch % Lines
paddleformers/transformers/deepseek_v2/modeling.py 15.50% 267 Missing ⚠️
paddleformers/trainer/utils/offload_optimizer.py 0.00% 43 Missing ⚠️
paddleformers/trainer/trainer.py 11.76% 15 Missing ⚠️
paddleformers/nn/pp_model.py 30.00% 14 Missing ⚠️
paddleformers/transformers/deepseek_v3/modeling.py 65.21% 8 Missing ⚠️
paddleformers/trainer/trainer_utils.py 14.28% 6 Missing ⚠️
paddleformers/transformers/moe_layer.py 0.00% 6 Missing ⚠️
paddleformers/trl/model_config.py 0.00% 2 Missing ⚠️
paddleformers/transformers/moe_gate.py 0.00% 1 Missing ⚠️

❌ Your patch status has failed because the patch coverage (16.78%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #2593   +/-   ##
==========================================
  Coverage           ?   29.89%           
==========================================
  Files              ?      308           
  Lines              ?    53980           
  Branches           ?        0           
==========================================
  Hits               ?    16136           
  Misses             ?    37844           
  Partials           ?        0           

☔ View full report in Codecov by Sentry.



model_class = AutoModelForCausalLMPipe

model_config.using_flex_token = model_args.using_flex_token
Reviewer:

The model_config setup logic all needs to be moved before line 202.

Author:
done~

if training_args.use_expert_parallel:
callbacks += [MoeExpertsGradScaleCallback(training_args)]

print("callbacks:", callbacks, flush=True)
Reviewer:

Change this to logger.info; print must not appear in production code.
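
For reference, the fix might look like this (a minimal sketch; the logger import path follows the `from ..utils.log import logger` style used elsewhere in this PR):

from paddleformers.utils.log import logger

if training_args.use_expert_parallel:
    callbacks += [MoeExpertsGradScaleCallback(training_args)]

logger.info(f"callbacks: {callbacks}")  # replaces the temporary print(...)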

Author:

done~

attn_mask_startend_row_indices = attn_mask_startend_row_indices[
:,
:,
: -self.config.num_nextn_predict_layers,
Reviewer:

Once the data flow supports this, is the truncation here still needed?

Author:

The data flow adds num_nextn_predict_layers to the output dimensions of input_ids, attn_mask_startend_row_indices, and so on, so the truncation is still needed.
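
For context, the truncation pattern is roughly as follows (a minimal sketch with assumed tensor shapes):

# The data pipeline appends num_nextn_predict_layers extra positions for MTP,
# so model inputs are sliced back to seq_len before the main forward pass.
n = self.config.num_nextn_predict_layers
if n > 0:
    input_ids = input_ids[:, :-n]
    attn_mask_startend_row_indices = attn_mask_startend_row_indices[:, :, :-n]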

"unified_checkpoint": true,
"use_flash_attention": true,
"flash_mask": true,
"using_fake_gate": true,
Reviewer:

Change fake gate to false.

Collaborator:

If fake_gate is not needed, can it be removed?

Reviewer:

Megatron also has a fake gate; every MoE scenario needs it. It is used for performance testing.

"expert_parallel_degree": 16,
"continue_training": true,
"pipeline_parallel_config": "enable_delay_scale_loss disable_partial_send_recv disable_batch_p2p_comm",
"tensor_parallel_config": "enable_delay_scale_loss",
Reviewer:

For TP, also add "tensor_parallel_config": "sync_param sync_grad".
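
A sketch of the combined setting, assuming the options are space-separated as in the existing configs:

"tensor_parallel_config": "enable_delay_scale_loss sync_param sync_grad",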

"do_eval": false,
"disable_tqdm": true,
"use_expert_parallel": true,
"expert_parallel_degree": 8,
Reviewer:

The distributed strategy for 4K is sharding16 ep16 pp8.
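
One plausible way to express sharding16 ep16 pp8 in this JSON config (the degree key names are assumptions based on common trainer arguments):

"sharding": "stage1",
"sharding_parallel_degree": 16,
"expert_parallel_degree": 16,
"pipeline_parallel_degree": 8,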

@Difers Difers force-pushed the add_dsv3_from_nlp branch 4 times, most recently from 6f20001 to d5a1f8e on September 24, 2025 13:20
"continue_training": true,
"pipeline_parallel_config": "enable_delay_scale_loss disable_partial_send_recv disable_batch_p2p_comm",
"tensor_parallel_config": "enable_delay_scale_loss",
"load_best_model_at_end": true,
Collaborator:

"load_best_model_at_end": true and "metric_for_best_model": "loss" can be removed now.

"sharding": "stage1",
"unified_checkpoint": true,
"use_flash_attention": true,
"flash_mask": true,
Collaborator:

Remove these two switches: "use_flash_attention": true and "flash_mask": true.


"use_flash_attention": true,
"flash_mask": true,
"using_fake_gate": true,
"using_flex_token": true,
@lugimzzz (Collaborator) commented Sep 24, 2025:

Keep only the DeepEP implementation in the model and delete the all2all one; then there is no need for the flex_token check.

Reviewer:

Why is the alltoall version no longer needed?

Collaborator:

That MoE path hangs in full-parameter training, and it is not efficient, so it will no longer be maintained.

"flash_mask": true,
"using_fake_gate": true,
"using_flex_token": true,
"use_fused_rms_norm": true,

@@ -0,0 +1,62 @@
{
"model_name_or_path": "/root/paddlejob/tmpspace/huggingface_model/huggingface/deepseek-ai/DeepSeek-V3-bf16/",
Collaborator:

Write the model name directly: opensourcerelease/DeepSeek-V3-bf16

@@ -0,0 +1,62 @@
{
"model_name_or_path": "/root/paddlejob/tmpspace/huggingface_model/huggingface/deepseek-ai/DeepSeek-V3-bf16/",
"dataset_name_or_path": "/root/paddlejob/tmpspace/chenzhichao/PaddleNLP-SFT/llm/en_data",
Collaborator:

Use a generic path here, and check for this throughout the PR.

from paddle.incubate.nn.functional import fused_rms_norm_ext

from ..generation.configuration_utils import PretrainedConfig
from ..transformers.llama import fusion_ops
Collaborator:

Can fusion_ops be removed?

from ..generation.configuration_utils import PretrainedConfig
from ..transformers.llama import fusion_ops
from ..utils.log import logger
from ..utils.tools import get_env_device
Collaborator:

Why does get_env_device need to be imported?

),
)
self.gate_proj = getattr(self, gate_proj_name)
def linear_type_gaurd():
Collaborator:

Wouldn't this FP8 handling fit better in paddleformers.nn.linear? If you want the MoE part not to use TP linears, pass config.tensor_parallel_degree = 1 when creating the MLP:
https://github.com/PaddlePaddle/PaddleFormers/blob/develop/paddleformers/transformers/ernie4_5_moe/modeling.py#L134
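
A minimal sketch of that suggestion (the DeepseekV2MLP call is hypothetical; the point is the config copy):

import copy

# Give expert MLPs a config whose tensor_parallel_degree is 1 so that
# GeneralLinear creates plain (non-TP) linears for the MoE experts.
expert_config = copy.deepcopy(config)
expert_config.tensor_parallel_degree = 1
expert = DeepseekV2MLP(expert_config)  # actual constructor args may differ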

"unified_checkpoint": true,
"use_flash_attention": true,
"flash_mask": true,
"using_fake_gate": true,
Collaborator:

If fake_gate is not needed, can it be removed?

"sequence_parallel": true,
"tensor_parallel_output": true,
"amp_master_grad": true,
"sharding_parallel_config": "split_param",
Collaborator:

With sharding stage1 v2 (split_param) plus tensorwise_offload_optimizer enabled, have you verified that unified-checkpoint warm restart resumes correctly?

"amp_master_grad": true,
"sharding_parallel_config": "split_param",
"num_nextn_predict_layers": 1,
"convert_from_hf": true
Collaborator:

convert_from_hf can be removed now; it defaults to True.


@@ -0,0 +1,60 @@
{
"model_name_or_path": "/root/paddlejob/tmpspace/huggingface_model/huggingface/deepseek-ai/DeepSeek-V3-bf16/",
Collaborator:

Suggest adding examples/best_practice/deepseek_v3_sft/... for this model, holding the concrete model configuration and the end-to-end usage docs.

model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers
model_config._attn_implementation = model_args.attn_impl
model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num
model_config.gradient_accumulation_steps = training_args.gradient_accumulation_steps
Collaborator:

What is this for?

param.split_axis = 0

def forward(self, hidden_states, tensor_parallel_output=None):
def forward(self, hidden_states, tensor_parallel_output=None, gather_hidden_states=True):
Collaborator:

Not a big deal, but why does this need to change?

group=self._hcg.get_pipe_parallel_group(),
)

# logger.info(
Collaborator:

Don't leave unneeded code commented out; delete it.

Collaborator:

Watch for this issue throughout the PR.

__all__ = []


class MoEHybridParallelClipGrad:
Collaborator:

What is MoEHybridParallelClipGrad for?

Reviewer:

This is the ClipGrad that computes the global norm correctly in the sharding-EP scenario; otherwise the computed global_norm is wrong. The original dp-moe and mp-moe setups can keep using the original one.
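
Conceptually, the EP-aware clipping must combine squared norms across the expert-parallel group before taking the square root; a purely illustrative sketch (all names hypothetical, not the actual implementation):

import paddle
import paddle.distributed as dist

def moe_global_grad_norm(params_grads, moe_group):
    # Expert grads are sharded across EP ranks, so their squared norms must
    # be all-reduced over the EP group; dense grads are already replicated.
    expert_sq = paddle.zeros([1], dtype="float32")
    dense_sq = paddle.zeros([1], dtype="float32")
    for param, grad in params_grads:
        sq = paddle.sum(paddle.square(grad.astype("float32")))
        if getattr(param, "no_sync", False):  # assumed marker for expert params
            expert_sq += sq
        else:
            dense_sq += sq
    dist.all_reduce(expert_sq, group=moe_group)
    return paddle.sqrt(dense_sq + expert_sq)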


if getattr(model_config, "topk_method", None) == "noaux_tc":
callbacks += [MoECorrectionBiasAdjustCallback(lr=0)]
# deepseek_v3 finetune do not update the bias, so set lr to 0.0
Collaborator:

Does DeepSeek V3 actually use this strategy? The config doesn't seem to enable it.

Reviewer:

See the if check above; this is enabled via model_config.

Collaborator:

I see topk_method defaults to greedy? In what cases does noaux_tc need to be enabled?

if self._decoder_layer_cls is None:
raise ValueError("_decoder_layer_cls must be set before init.")
DecoderLayerPipe = make_decoder_layer_pipe(self._decoder_layer_cls)

Collaborator:

Add a guard for if config.num_nextn_predict_layers > 0 and _mtp_layer_pipe_cls is None: _mtp_layer_pipe_cls needs to be defined.
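
That is, something along these lines (a sketch; _mtp_layer_pipe_cls is assumed to mirror _decoder_layer_cls):

if self._decoder_layer_cls is None:
    raise ValueError("_decoder_layer_cls must be set before init.")
if config.num_nextn_predict_layers > 0 and self._mtp_layer_pipe_cls is None:
    raise ValueError("_mtp_layer_pipe_cls must be set when num_nextn_predict_layers > 0.")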

hidden_states, _, _, _, _ = parse_args(args)
hidden_states = super().forward(hidden_states)
return hidden_states

Collaborator:

Suggest keeping this inside the DeepSeek model code for now and passing in an _rmsnorm_pipe_cls; it is not yet clear whether this pattern suits other models.

[batch_size, sequence_length, vocab_size]
representing unnormalized log probabilities for each token
"""
if self.config.num_nextn_predict_layers > 0:
Collaborator:

Same as the rmsnorm comment.

group=self._hcg.get_pipe_parallel_group(),
)

# logger.info(
Collaborator:

Watch for this issue throughout the PR.

return self._dygraph_clip(params_grads)


class MoEHybridParallelOptimizer(HPBase):
Collaborator:

In what scenario is MoEHybridParallelOptimizer needed?

from ...nn.mlp import MLP as DeepseekV2MLP
from ...nn.norm import Norm as GeneralNorm
from ...nn.pp_model import EmbeddingPipe, GeneralModelForCausalLMPipe, parse_args

Collaborator:

Confirm whether fused_rotary_position_embedding still exists in Paddle 3.2 and later.

from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...nn.criterion.interface import CriterionLayer
from ...nn.embedding import Embedding as GeneralEmbedding
Collaborator:

Add unit tests, following qwen2.
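
A minimal sketch of what such a test might look like (the config arguments, output indexing, and availability of tiny defaults are assumptions; mirror the real qwen2 tests under the tests folder):

import unittest

import paddle

from paddleformers.transformers.deepseek_v3 import DeepseekV3Config, DeepseekV3Model


class DeepseekV3ModelTest(unittest.TestCase):
    def test_forward_shape(self):
        # Tiny hypothetical config so the test runs quickly on one card.
        config = DeepseekV3Config(
            hidden_size=64,
            num_hidden_layers=2,
            num_attention_heads=4,
            vocab_size=128,
        )
        model = DeepseekV3Model(config)
        input_ids = paddle.randint(0, config.vocab_size, shape=[2, 8])
        hidden = model(input_ids=input_ids)[0]
        self.assertEqual(list(hidden.shape), [2, 8, config.hidden_size])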

self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.use_rmsnorm = use_rmsnorm
Collaborator:

There is no need to add use_rmsnorm to the config; just pass the type when creating the rmsnorm: https://github.com/PaddlePaddle/PaddleFormers/blob/develop/paddleformers/transformers/ernie4_5_moe/modeling.py#L320
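
That is, roughly (a sketch; the exact keyword is an assumption based on the linked ernie4_5_moe example and the norm_type suggestion later in this thread):

self.enorm = GeneralNorm.create(config, norm_type="rms_norm")
self.hnorm = GeneralNorm.create(config, norm_type="rms_norm")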

norm_topk_prob=False,
scoring_func="softmax",
aux_loss_alpha=0.001,
aux_loss_alpha=0.0001,
Collaborator:

Compare the transformers and paddleformers configs for the DeepSeek V2 & V3 models, and list what each extra config option does.

Reviewer:

The paper uses 0.0001 for this coefficient; the earlier value was simply a mistake.

from ...nn.mlp import MLP as DeepseekV2MLP
from ...nn.norm import Norm as GeneralNorm
from ...nn.pp_model import EmbeddingPipe, GeneralModelForCausalLMPipe, parse_args

Collaborator:

Confirm which of these are unused, or already exist in Paddle 3.2+ and no longer need a try/except import.

Collaborator:

Can the dependency on llama be removed?

@lugimzzz (Collaborator):

Fix the CI and codestyle issues; run pre-commit install before committing code.

self.vocab_size = config.vocab_size
self.lm_head = DeepseekV2LMHead(config)
self.criterion = DeepseekV2PretrainingCriterion(config)
self.lm_head = GeneralLMHead(config)
Collaborator:

DeepSeek V2 did not change base_model_prefix; update it and the other corresponding parts as well.

Collaborator:

get_input_embeddings, set_input_embeddings, get_output_embeddings, and so on can all use the parent-class PretrainedModel implementations directly.

)
using_flex_token: bool = field(default=False, metadata={"help": "Whether to use deepep moe_layer"})
using_fake_gate: bool = field(default=False, metadata={"help": "Whether to fake gate"})
moe_subbatch_token_num: int = field(
Collaborator:

This has already been added in an existing PR, so it can be removed.

"down_proj",
"gate",
"eh_proj",
"lm_head",
Collaborator:

Remove lm_head; LM heads are now [vocab_size, hidden_size] and no longer need a transpose.


# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.deepseek_v3(
outputs = self.model(
Collaborator:

Does DeepseekV3ForCausalLM pass MTP verification with PP disabled and a reduced number of layers?

f"Implementation of fused_rms_norm is not available on {get_env_device()}. Please install paddle_xpu to use this feature"
)
if self.config.use_fused_rms_norm:
if get_env_device() == "xpu":
Collaborator:

Is DeepseekV2RMSNorm still needed? The generic norm can be used directly.


if self.using_flex_token:
scores, routing_map, exp_counts, l_aux, l_zloss = self.topkgating_nodrop(scores)
with paddle.no_grad():
Collaborator:

For DeepseekV2MoE, suggest keeping a single-GPU implementation as a reference against transformers, alongside the DeepEP EP-parallel version.

"DeepseekV2ForSequenceClassification",
"DeepseekV2Model",
"DeepseekV2PretrainedModel",
"DeepseekV2ForCausalLMPipe",

self.kv_b_proj = ColumnParallelLinear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), has_bias=False, gather_output=True)
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, self.hidden_size, has_bias=config.attention_bias, input_is_parallel=False)
self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank, use_sequence_parallel=False)
self.q_proj = GeneralLinear.create(
Collaborator:

Move linear_dtype_gaurd() into GeneralLinear and control it via config.

pg.allreduce(param.main_grad).wait()
else:
pg.allreduce(param.grad).wait()

Collaborator:

What does grad_allreduce_hook do?

)


class DeepseekV2MTPLayerPipe(DeepseekV2MTPLayer):
Collaborator:

There is quite a lot of PP-related code; put DeepseekV2MTPLayerPipe, DeepseekV2EmbeddingPipe, DeepseekV2DecoderLayerPipe, and similar classes in modeling_pp.py and import them from there.

@lugimzzz (Collaborator):

The PaddleFormers model code differs somewhat from PaddleNLP's. When writing the model, reuse the existing PaddleFormers modules wherever possible; refer to the qwen2 & ernie4.5 models in PaddleFormers and the DeepSeek V3 implementation in transformers. Apart from the configuration required for EP training (please list those options, and check whether other models need them generalized too), delete all redundant code.
https://github.com/PaddlePaddle/PaddleFormers/blob/develop/paddleformers/transformers/qwen2/modeling.py


self.enorm = DeepseekV2RMSNorm(config)
self.hnorm = DeepseekV2RMSNorm(config)
self.enorm = GeneralNorm.create(
Collaborator:

Suggest passing norm_type to specify the norm type.

@Difers Difers force-pushed the add_dsv3_from_nlp branch 2 times, most recently from a1631f3 to 9adce24 on September 28, 2025 09:40
@Difers Difers closed this Sep 28, 2025
@Difers Difers reopened this Sep 28, 2025
@Difers Difers closed this Sep 29, 2025
@Difers Difers reopened this Sep 29, 2025
Difers commented Sep 29, 2025

/re-run all-failed



self.kv_b_proj = Linear(config.kv_lora_rank, self.num_heads * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), bias_attr=False)
self.o_proj = Linear(self.num_heads * self.v_head_dim, self.hidden_size, bias_attr=config.attention_bias)
self.kv_a_layernorm = DeepseekV2RMSNorm(config=config, hidden_size=config.kv_lora_rank)
self.q_a_proj = GeneralLinear.create(
Collaborator:

FP8 isn't supported right now, so remove linear_dtype_guard for the time being?


def get_input_embeddings(self):
return self.deepseek_v3.embed_tokens
return self.model.embed_tokens
Collaborator:

PretrainedModel already provides these functions, so there is no need to redefine them; update DeepSeek V2 the same way: https://github.com/PaddlePaddle/PaddleFormers/blob/develop/paddleformers/transformers/model_utils.py#L1455

pp_seg_method: Optional[str] = field(
default="layer:DecoderLayer|EmptyLayer", metadata={"help": "PP Segmentation Method"}
)
using_fake_gate: bool = field(default=False, metadata={"help": "Whether to fake gate"})
Collaborator:

using_fake_gate and aux_loss_alpha are parameters that any MoE model can use; suggest adding them directly to LlmMetaConfig & PretrainedConfig. See section 2.1.1 of the PaddleFormers model-contribution guide.

Collaborator:

Once they are added to LlmMetaConfig & PretrainedConfig, no new code is needed in model_config.py & run_finetune.py.


if getattr(model_config, "topk_method", None) == "noaux_tc":
callbacks += [MoECorrectionBiasAdjustCallback(lr=0)]
# deepseek_v3 finetune do not update the bias, so set lr to 0.0
Collaborator:

glm4.5 will also use this; don't name a specific model in the comment.

return f"hidden_size={self.hidden_size}, dtype={self.weight.dtype}"


class DeepseekV2RotaryEmbedding(nn.Layer):
Collaborator:

Rotary embeddings are now used by precomputing position_embeddings and passing them into the model, and the pp_model layer-to-layer argument parsing also assumes position_embeddings by default. Check whether the current code has redundancy or hidden issues, and switch to the new style:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py#L535
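
A compact sketch of the precomputed style (illustrative only; names and shapes are assumptions modeled on the HF implementation linked above):

import paddle
import paddle.nn as nn


class RotaryEmbedding(nn.Layer):
    def __init__(self, dim, base=10000.0):
        super().__init__()
        self.inv_freq = 1.0 / (base ** (paddle.arange(0, dim, 2, dtype="float32") / dim))

    def forward(self, position_ids):
        # [seq_len] x [dim/2] -> [seq_len, dim/2]; duplicated to cover the full dim.
        freqs = paddle.outer(position_ids.astype("float32"), self.inv_freq)
        emb = paddle.concat([freqs, freqs], axis=-1)
        return emb.cos(), emb.sin()


# In the top-level model forward, compute once and thread through the layers:
#   position_embeddings = self.rotary_emb(position_ids)
#   for layer in self.layers:
#       hidden_states = layer(hidden_states, position_embeddings=position_embeddings)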

if len(args) == 5:


from ...utils.tools import get_env_device
from ..activations import ACT2FN
from ..conversion_utils import StateDictNameMapping, init_name_mappings
from ..llama import fusion_ops

"DeepseekV2ForSequenceClassification",
"DeepseekV2Model",
"DeepseekV2PretrainedModel",
"DeepseekV2ForCausalLMPipe",
Collaborator:

Redundant functions such as get_triangle_upper_mask, assign_kv_heads, and so on can be deleted; compare with the transformers implementation:
https://github.com/huggingface/transformers/blob/main/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py#L386

from ..utils import device_guard
from . import fp8_linear as linear_utils
from .configuration import DeepseekV2Config
from .fp8_linear import Linear
Collaborator:

FP8 hasn't been verified; remove it for now.

from ..activations import ACT2FN
from ..conversion_utils import StateDictNameMapping, init_name_mappings
from ..llama import fusion_ops
from ..llama.modeling import get_use_casual_mask
Collaborator:

Remove the dependencies on llama-related code.

@lugimzzz (Collaborator):

The code needs further changes to conform to the PaddleFormers code conventions. If that is not finished in this PR, leave a note listing what will be fixed in the next PR.

@Difers (Author) commented Sep 30, 2025:

> The code needs further changes to conform to the PaddleFormers code conventions. If that is not finished in this PR, leave a note listing what will be fixed in the next PR.

Many of the current issues actually come from needing to align the model migrated from PaddleNLP with the HF transformers implementation; that work is already on the TODO list and will be handled together in the next PR.
