
Commit 9adce24

Commit message: add dsv3 from paddlenlp-sft
1 parent: 2123476

23 files changed: +1326 / -2365 lines

examples/run_finetune.py

Lines changed: 10 additions & 2 deletions
@@ -140,6 +140,10 @@ def main():
     model_config.max_sequence_length = training_args.max_seq_len
     model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers
     model_config._attn_implementation = model_args.attn_impl
+    model_config.gradient_accumulation_steps = training_args.gradient_accumulation_steps
+    model_config.using_flex_token = model_args.using_flex_token
+    model_config.using_fake_gate = model_args.using_fake_gate
+    model_config.aux_loss_alpha = model_args.aux_loss_alpha
     logger.info(f"Final model config: {model_config}")
     logger.info("Creating model")

@@ -278,13 +282,16 @@ def neft_post_hook(module, input, output):
     training_args.logging_steps = int(training_args.max_steps / training_args.num_train_epochs)

     callbacks = []
+
     if getattr(model_config, "topk_method", None) == "noaux_tc":
-        callbacks += [MoECorrectionBiasAdjustCallback(lr=0)]
+        # deepseek_v3 finetuning does not update the bias, so set lr to 0.0
+        callbacks += [MoECorrectionBiasAdjustCallback(lr=0.0)]

     if training_args.use_expert_parallel:
         callbacks += [MoeExpertsGradScaleCallback(training_args)]

-    print("callbacks:", callbacks, flush=True)
+    logger.info(f"callbacks: {callbacks}")
+
     trainer = SFTTrainer(
         model=model,
         args=training_args,

@@ -295,6 +302,7 @@ def neft_post_hook(module, input, output):
         data_collator=data_collator,
         do_generation=data_args.eval_with_do_generation,
         data_args=data_args,
+        callbacks=callbacks,
     )
     trainable_parameters = [
         p for p in model.parameters() if not p.stop_gradient or ("quantization_linear" in p.name and "w_1" in p.name)
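
Note: the hunks above both build the MoE-specific callbacks and, via the new callbacks=callbacks argument, actually hand them to SFTTrainer. Below is a minimal sketch of that assembly pattern, with stand-in classes that only mirror the constructor shapes; the real MoECorrectionBiasAdjustCallback, MoeExpertsGradScaleCallback, and SFTTrainer live in the PaddleFormers codebase, and build_callbacks is a hypothetical helper for illustration.

class MoECorrectionBiasAdjustCallback:
    def __init__(self, lr: float):
        # lr=0.0 keeps the expert-selection correction bias frozen during finetuning
        self.lr = lr

class MoeExpertsGradScaleCallback:
    def __init__(self, training_args):
        self.training_args = training_args

def build_callbacks(model_config, training_args):
    callbacks = []
    # Only routers using "noaux_tc" top-k selection (DeepSeek-V3 style) need the bias callback.
    if getattr(model_config, "topk_method", None) == "noaux_tc":
        callbacks.append(MoECorrectionBiasAdjustCallback(lr=0.0))
    if getattr(training_args, "use_expert_parallel", False):
        callbacks.append(MoeExpertsGradScaleCallback(training_args))
    return callbacks

The list returned by such a helper is what gets passed through the trainer's callbacks argument, so an empty list simply means no extra hooks are registered.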

paddleformers/nn/linear.py

Lines changed: 12 additions & 1 deletion
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
+
 import paddle.nn as nn
 from paddle.incubate.nn import FusedLinear

@@ -22,6 +24,7 @@
     RowParallelLinear,
     RowSequenceParallelLinear,
 )
+from ..transformers.model_utils import dtype_guard
 from .general import GeneralInterface

 __all__ = ["Linear"]

@@ -51,6 +54,12 @@ def create(
         input_is_parallel: bool = True,
         fuse_matmul_bias: bool = False,
     ):
+        def linear_type_gaurd():
+            if config.use_fp8:
+                return dtype_guard("float8_e4m3fn")
+            else:
+                return contextlib.nullcontext()
+
         if linear_type is None and config is None:
             raise ValueError("linear_type or config must be specified")

@@ -59,7 +68,9 @@ def create(

        linear_cls = self._global_mapping[linear_type]
        kwargs = self.get_linear_kwargs(linear_type, has_bias, gather_output, input_is_parallel, fuse_matmul_bias)
-        return linear_cls(in_features=in_features, out_features=out_features, weight_attr=weight_attr, **kwargs)
+
+        with linear_type_gaurd():
+            return linear_cls(in_features=in_features, out_features=out_features, weight_attr=weight_attr, **kwargs)

     @classmethod
     def get_linear_type(self, config: PretrainedConfig, tp_plan=None):
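
Note: the effect of the new guard is that the linear layer is constructed under a float8 default dtype only when config.use_fp8 is set; otherwise construction is unchanged. A rough sketch of the conditional context-manager pattern follows, with a stand-in dtype_guard (the real one is imported from paddleformers.transformers.model_utils and switches Paddle's default dtype for the duration of the block).

import contextlib

@contextlib.contextmanager
def dtype_guard(dtype: str):
    # Stand-in for paddleformers' dtype_guard: the real one changes the
    # framework's default dtype on entry and restores it on exit.
    print(f"default dtype temporarily set to {dtype}")
    try:
        yield
    finally:
        print("default dtype restored")

def linear_type_guard(use_fp8: bool):
    # Same shape as the diff's linear_type_gaurd(): return a dtype guard when
    # FP8 is requested, otherwise a no-op context manager.
    return dtype_guard("float8_e4m3fn") if use_fp8 else contextlib.nullcontext()

# Usage: layer construction happens inside the guard, so weight creation
# picks up the guarded default dtype.
with linear_type_guard(use_fp8=True):
    pass  # linear_cls(...) would be constructed here

Using contextlib.nullcontext() for the non-FP8 branch keeps the call site identical in both cases, which is why the return statement can simply be wrapped in a with block.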

paddleformers/nn/mlp.py

Lines changed: 8 additions & 0 deletions
@@ -45,6 +45,10 @@ def __init__(
         self.act_type = config.get("hidden_act", "silu")
         self.act_fn = ACT2FN[self.act_type]
         self.fuse_up_gate = fuse_up_gate
+        self.is_moe = kwargs.get("is_moe", False)
+        linear_type = None
+        if self.is_moe:
+            linear_type = "default"

         if self.fuse_up_gate:
             setattr(

@@ -57,6 +61,7 @@ def __init__(
                     config=config,
                     fuse_matmul_bias=config.fuse_linear,
                     tp_plan="colwise",
+                    linear_type=linear_type,
                 ),
             )
             self.up_gate_proj = getattr(self, gate_up_proj_name)

@@ -72,6 +77,7 @@ def __init__(
                     config=config,
                     fuse_matmul_bias=config.fuse_linear,
                     tp_plan="colwise",
+                    linear_type=linear_type,
                 ),
             )
             self.gate_proj = getattr(self, gate_proj_name)

@@ -87,6 +93,7 @@ def __init__(
                     config=config,
                     fuse_matmul_bias=config.fuse_linear,
                     tp_plan="colwise",
+                    linear_type=linear_type,
                 ),
             )
             self.up_proj = getattr(self, up_proj_name)

@@ -102,6 +109,7 @@ def __init__(
                     config=config,
                     fuse_matmul_bias=config.fuse_linear,
                     tp_plan="rowwise",
+                    linear_type=linear_type,
                 ),
             )
             self.down_proj = getattr(self, down_proj_name)
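
Note: when the MLP is built as an MoE expert (is_moe=True), every projection passes linear_type="default" so the plain linear implementation is used instead of whatever the tensor-parallel plan would otherwise select. The sketch below illustrates that dispatch idea with hypothetical class and mapping names; the actual selection lives in Linear.create / get_linear_type in paddleformers/nn/linear.py.

# Hypothetical registry showing why an explicit linear_type matters: "default"
# bypasses the tensor-parallel classes a tp_plan lookup would otherwise pick.
class DefaultLinear: ...
class ColumnParallelLinear: ...
class RowParallelLinear: ...

LINEAR_MAPPING = {
    "default": DefaultLinear,
    "colwise": ColumnParallelLinear,
    "rowwise": RowParallelLinear,
}

def pick_linear_cls(tp_plan, linear_type=None):
    # An explicit linear_type wins; otherwise fall back to the tp_plan lookup.
    return LINEAR_MAPPING[linear_type or tp_plan]

assert pick_linear_cls("colwise") is ColumnParallelLinear
assert pick_linear_cls("colwise", linear_type="default") is DefaultLinear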

paddleformers/nn/pp_model.py

Lines changed: 31 additions & 7 deletions
@@ -507,12 +507,28 @@ class GeneralModelForCausalLMPipe(PipelinePretrainedModel, PipelineLayer):
     transpose_weight_keys = None
     _embed_cls = None
     _rotary_emb_cls = None
+    _mtp_layer_pipe_cls = None
+    _embedding_pipe_cls = None
+    _decoder_layer_pipe_cls = None
+    _criterion_pipe_cls = None
+    _lmhead_pipe_cls = None
+    _rms_norm_pipe_cls = None

     def __init__(self, config: PretrainedConfig, **kwargs):
         # dynamic inherit DecoderLayer
         if self._decoder_layer_cls is None:
             raise ValueError("_decoder_layer_cls must be set before init.")
-        DecoderLayerPipe = make_decoder_layer_pipe(self._decoder_layer_cls)
+
+        EmbeddingPipeCls = self._embedding_pipe_cls if self._embedding_pipe_cls is not None else Embedding
+
+        if self._decoder_layer_pipe_cls is None:
+            DecoderLayerPipe = make_decoder_layer_pipe(self._decoder_layer_cls)
+        else:
+            DecoderLayerPipe = self._decoder_layer_pipe_cls
+
+        LMHeadPipeCls = self._lmhead_pipe_cls if self._lmhead_pipe_cls is not None else LMHeadPipe
+        MTPLayerPipeCls = self._mtp_layer_pipe_cls if self._mtp_layer_pipe_cls is not None else None
+        RMSNormPipeCls = self._rms_norm_pipe_cls if self._rms_norm_pipe_cls is not None else RMSNormPipe

         new_initializer_range = math.sqrt(0.3333 / config.hidden_size)
         logger.info(f"change initializer-range from {config.initializer_range} to {new_initializer_range}")

@@ -559,7 +575,7 @@ def __init__(self, config: PretrainedConfig, **kwargs):
         else:
             self.add_sequential_layer(
                 LayerDesc(
-                    EmbeddingPipe, config=config, embed_cls=self._embed_cls, rotary_emb_cls=self._rotary_emb_cls
+                    EmbeddingPipeCls, config=config, embed_cls=self._embed_cls, rotary_emb_cls=self._rotary_emb_cls
                 ),
                 "model",
             )

@@ -573,6 +589,12 @@ def __init__(self, config: PretrainedConfig, **kwargs):
                 ),
                 f"model.layers.{i}",
             )
+        for i in range(config.num_nextn_predict_layers):
+            if MTPLayerPipeCls is not None:
+                self.add_sequential_layer(
+                    LayerDesc(MTPLayerPipeCls, config=config, layer_idx=config.num_hidden_layers + i),
+                    f"model.layers.{config.num_hidden_layers + i}",
+                )
         for i in range(config.add_tail_layers):
             self.add_sequential_layer(
                 LayerDesc(

@@ -582,22 +604,22 @@ def __init__(self, config: PretrainedConfig, **kwargs):
             )

         self.add_sequential_layer(
-            LayerDesc(RMSNormPipe if config.use_rmsnorm else LayerNormPipe, config=config),
+            LayerDesc(RMSNormPipeCls if config.use_rmsnorm else LayerNormPipe, config=config),
             "model.norm",
         )

         if config.tie_word_embeddings:
             self.add_sequential_layer(
                 SharedLayerDesc(
                     "model_shared_weight",
-                    LMHeadPipe,
+                    LMHeadPipeCls,
                     shared_weight_attr="embedding_weight",
                     config=config,
                 ),
                 "lm_head",
             )
         else:
-            self.add_sequential_layer(LayerDesc(LMHeadPipe, config=config), "lm_head")
+            self.add_sequential_layer(LayerDesc(LMHeadPipeCls, config=config), "lm_head")
         recompute_interval = 0

         seg_method = config.pp_seg_method if hasattr(config, "pp_seg_method") else "layer:DecoderLayer|EmptyLayer"

@@ -630,10 +652,12 @@ def __init__(self, config: PretrainedConfig, **kwargs):
         )

     def get_loss_fn(self, config):
+        CriterionPipeCls = self._criterion_pipe_cls if self._criterion_pipe_cls is not None else CriterionLayerPipe
+
         if config.get("dpo_config", None) is not None:
-            loss_fn = CriterionLayerPipe(config, use_infohub=True)
+            loss_fn = CriterionPipeCls(config, use_infohub=True)
         else:
-            loss_fn = CriterionLayerPipe(config)
+            loss_fn = CriterionPipeCls(config)

         return loss_fn
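
Note: the new *_pipe_cls class attributes turn GeneralModelForCausalLMPipe into a template that subclasses can specialize without re-implementing __init__: a model such as DeepSeek-V3 can swap in its own embedding, decoder, MTP, LM-head, norm, and criterion pipe layers, and the optional MTP hook adds multi-token-prediction layers after the regular decoder stack. A toy sketch of this class-attribute hook pattern follows; all class names below are stand-ins for illustration, not the paddleformers classes.

# Sketch of the class-attribute override pattern the diff introduces.
class EmbeddingPipe: ...
class LMHeadPipe: ...

class GeneralPipeBase:
    _embedding_pipe_cls = None   # subclasses may override these hooks
    _lmhead_pipe_cls = None
    _mtp_layer_pipe_cls = None   # optional multi-token-prediction layers

    def build(self, num_mtp_layers=0):
        embedding_cls = self._embedding_pipe_cls or EmbeddingPipe
        lmhead_cls = self._lmhead_pipe_cls or LMHeadPipe
        layers = [embedding_cls.__name__]
        # MTP layers are appended only when a subclass provides a pipe class.
        if self._mtp_layer_pipe_cls is not None:
            layers += [self._mtp_layer_pipe_cls.__name__] * num_mtp_layers
        layers.append(lmhead_cls.__name__)
        return layers

# Hypothetical specialization: override only the hooks that differ.
class DeepseekV3EmbeddingPipe(EmbeddingPipe): ...
class DeepseekV3MTPLayerPipe: ...

class DeepseekV3Pipe(GeneralPipeBase):
    _embedding_pipe_cls = DeepseekV3EmbeddingPipe
    _mtp_layer_pipe_cls = DeepseekV3MTPLayerPipe

print(DeepseekV3Pipe().build(num_mtp_layers=1))
# ['DeepseekV3EmbeddingPipe', 'DeepseekV3MTPLayerPipe', 'LMHeadPipe']

Defaulting each hook to None and falling back to the generic class keeps existing subclasses working unchanged, which is why the base __init__ resolves every pipe class up front.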
