    get_last_checkpoint,
    set_seed,
)
+from paddleformers.trainer.trainer_callback import (
+    MoECorrectionBiasAdjustCallback,
+    MoeExpertsGradScaleCallback,
+)
from paddleformers.transformers import (
    AutoConfig,
    AutoModelForCausalLM,
@@ -86,6 +90,19 @@
]


+def mock_offload_optimizer():
+    """
+    Patch the optimizer with hack_offload_optimizer (tensorwise optimizer offload), if available.
+    """
+    try:
+        from paddleformers.trainer.utils.offload_optimizer import hack_offload_optimizer
+
+        hack_offload_optimizer()
+        logger.warning("hack_offload_optimizer called.")
+    except ImportError:
+        logger.warning("hack_offload_optimizer is not imported")
+
+
def main():
    parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig))
    if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
@@ -97,9 +114,18 @@ def main():
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

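+    # Tensor-wise offload of optimizer states: when the flag is set, mock_offload_optimizer
+    # applies hack_offload_optimizer and only warns if the utility cannot be imported.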
+    if training_args.tensorwise_offload_optimizer:
+        mock_offload_optimizer()
+
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

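+    # Pre-allocate pre_alloc_memory GiB as a uint8 tensor and free it immediately,
+    # presumably so the device memory pool grows up front before training starts.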
+    if training_args.pre_alloc_memory > 0:
+        memory_size = int(training_args.pre_alloc_memory * 1024 * 1024 * 1024)
+        x = paddle.empty([memory_size], dtype=paddle.uint8)
+        logger.info(f"pre_alloc_memory size {x.shape}")
+        del x
+
    # Setup GPU & distributed training
    paddle.set_device(training_args.device)
    set_seed(seed=training_args.seed)
@@ -171,6 +197,7 @@ def main():
    model_config.max_sequence_length = training_args.max_seq_len
    model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers
    model_config._attn_implementation = model_args.attn_impl
+    model_config.gradient_accumulation_steps = training_args.gradient_accumulation_steps
    logger.info(f"Final model config: {model_config}")
    logger.info("Creating model")

@@ -181,6 +208,11 @@ def main():

        model_class = AutoModelForCausalLMPipe

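+    # Copy the MoE-related options from ModelConfig onto the model config used to build
+    # the model (flex-token mode, fake gate, MoE sub-batch token count, aux-loss alpha).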
+    model_config.using_flex_token = model_args.using_flex_token
+    model_config.using_fake_gate = model_args.using_fake_gate
+    model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num
+    model_config.aux_loss_alpha = model_args.aux_loss_alpha
+
    if model_args.continue_training and not training_args.autotuner_benchmark:
        model = model_class.from_pretrained(
            model_args.model_name_or_path,
@@ -309,6 +341,17 @@ def neft_post_hook(module, input, output):
        training_args.logging_strategy = IntervalStrategy.STEPS
        training_args.logging_steps = int(training_args.max_steps / training_args.num_train_epochs)

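+    # Collect MoE-specific trainer callbacks; they are passed to SFTTrainer below.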
+    callbacks = []
+
+    if getattr(model_config, "topk_method", None) == "noaux_tc":
+        # deepseek_v3 finetuning does not update the bias, so set its lr to 0.0
+        callbacks += [MoECorrectionBiasAdjustCallback(lr=0.0)]
+
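+    # Assumed behavior: with expert parallelism, expert-parameter gradients need an extra
+    # scaling factor, which MoeExpertsGradScaleCallback derives from training_args.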
+    if training_args.use_expert_parallel:
+        callbacks += [MoeExpertsGradScaleCallback(training_args)]
+
+    print("callbacks:", callbacks, flush=True)
+
    trainer = SFTTrainer(
        model=model,
        args=training_args,
@@ -319,6 +362,7 @@ def neft_post_hook(module, input, output):
        data_collator=data_collator,
        do_generation=data_args.eval_with_do_generation,
        data_args=data_args,
+        callbacks=callbacks,
    )
    trainable_parameters = [
        p for p in model.parameters() if not p.stop_gradient or ("quantization_linear" in p.name and "w_1" in p.name)