     get_last_checkpoint,
     set_seed,
 )
+from paddleformers.trainer.trainer_callback import (
+    MoECorrectionBiasAdjustCallback,
+    MoeExpertsGradScaleCallback,
+)
 from paddleformers.transformers import (
     AutoConfig,
     AutoModelForCausalLM,
 ]


+def mock_offload_optimizer():
+    """
+    Enable tensor-wise optimizer-state offloading by patching the optimizer in place.
+    """
+    try:
+        from paddleformers.trainer.utils.offload_optimizer import hack_offload_optimizer
+
+        hack_offload_optimizer()
+        logger.warning("hack_offload_optimizer called.")
+    except ImportError:
+        logger.warning("hack_offload_optimizer could not be imported; optimizer offloading is skipped.")
+
+
 def main():
     parser = PdArgumentParser((ModelConfig, DataConfig, SFTConfig))
     if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
@@ -97,9 +114,18 @@ def main():
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()

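+    # tensorwise_offload_optimizer patches the optimizer before it is built, so optimizer
+    # states can presumably live in host memory and be moved to device only when updated
+    # (assumption based on the flag and helper names, not verified here).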
+    if training_args.tensorwise_offload_optimizer:
+        mock_offload_optimizer()
+
     training_args.print_config(model_args, "Model")
     training_args.print_config(data_args, "Data")

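+    # pre_alloc_memory is interpreted as GB; allocating one large uint8 tensor and deleting
+    # it right away leaves the block in Paddle's allocator pool, presumably so later
+    # allocations reuse it and fragmentation is reduced (assumption from the alloc/del pattern).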
+    if training_args.pre_alloc_memory > 0:
+        memory_size = int(training_args.pre_alloc_memory * 1024 * 1024 * 1024)
+        x = paddle.empty([memory_size], dtype=paddle.uint8)
+        logger.info(f"pre_alloc_memory size {x.shape}")
+        del x
+
     # Setup GPU & distributed training
     paddle.set_device(training_args.device)
     set_seed(seed=training_args.seed)
@@ -171,6 +197,7 @@ def main():
     model_config.max_sequence_length = training_args.max_seq_len
     model_config.num_nextn_predict_layers = model_args.num_nextn_predict_layers
     model_config._attn_implementation = model_args.attn_impl
+    model_config.gradient_accumulation_steps = training_args.gradient_accumulation_steps
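+    # Carrying gradient_accumulation_steps on the model config presumably lets MoE layers
+    # rescale per-micro-batch terms such as the aux loss (assumption; this is not a
+    # standard config field).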
     logger.info(f"Final model config: {model_config}")
     logger.info("Creating model")

@@ -181,6 +208,11 @@ def main():

         model_class = AutoModelForCausalLMPipe

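+    # MoE-related knobs forwarded to the model config; judging by the names, they control
+    # flex-token dispatch, a fake (debug) gate, token sub-batching inside MoE layers, and
+    # the aux-loss weight (interpretation is an assumption based on the flag names).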
+    model_config.using_flex_token = model_args.using_flex_token
+    model_config.using_fake_gate = model_args.using_fake_gate
+    model_config.moe_subbatch_token_num = model_args.moe_subbatch_token_num
+    model_config.aux_loss_alpha = model_args.aux_loss_alpha
+
     if model_args.continue_training and not training_args.autotuner_benchmark:
         model = model_class.from_pretrained(
             model_args.model_name_or_path,
@@ -309,6 +341,17 @@ def neft_post_hook(module, input, output):
         training_args.logging_strategy = IntervalStrategy.STEPS
         training_args.logging_steps = int(training_args.max_steps / training_args.num_train_epochs)

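+    # MoE-specific callbacks: router-bias adjustment for noaux_tc models and gradient
+    # rescaling for expert-parallel parameters, added only when the corresponding
+    # settings are enabled.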
+    callbacks = []
+
+    if getattr(model_config, "topk_method", None) == "noaux_tc":
+        # DeepSeek-V3 finetuning does not update the bias, so set lr to 0.0
+        callbacks += [MoECorrectionBiasAdjustCallback(lr=0.0)]
+
+    if training_args.use_expert_parallel:
+        callbacks += [MoeExpertsGradScaleCallback(training_args)]
+
+    print("callbacks:", callbacks, flush=True)
+
     trainer = SFTTrainer(
         model=model,
         args=training_args,
@@ -319,6 +362,7 @@ def neft_post_hook(module, input, output):
         data_collator=data_collator,
         do_generation=data_args.eval_with_do_generation,
         data_args=data_args,
+        callbacks=callbacks,
     )
     trainable_parameters = [
         p for p in model.parameters() if not p.stop_gradient or ("quantization_linear" in p.name and "w_1" in p.name)