diff --git a/examples/experiments/paddlefleet/glm45_provider.py b/examples/experiments/paddlefleet/glm45_provider.py
index 556ff8eaf7..b885f198d3 100644
--- a/examples/experiments/paddlefleet/glm45_provider.py
+++ b/examples/experiments/paddlefleet/glm45_provider.py
@@ -145,6 +145,12 @@ class GLM45AirModelDebugProvider(GLM45AirModelProvider106B):
     apply_rope_fusion: bool = True
 
 
+@dataclass
+class GLM45AirModelDebugProviderFP8(GLM45AirModelDebugProvider):
+    fp8: str = "e4m3"
+    moe_shared_expert_overlap: bool = True
+
+
 @dataclass
 class GLM45AirModelSingleCardDebugProvider(GLMMoEModelProvider):
     """
diff --git a/examples/experiments/paddlefleet/run_pretrain.py b/examples/experiments/paddlefleet/run_pretrain.py
index 82e473a3d2..e07cb0996b 100644
--- a/examples/experiments/paddlefleet/run_pretrain.py
+++ b/examples/experiments/paddlefleet/run_pretrain.py
@@ -51,6 +51,7 @@
 from glm45_provider import (
     GLM45AirModelDebugProvider,
+    GLM45AirModelDebugProviderFP8,
     GLM45AirModelSingleCardDebugProvider,
 )
 from qwen_provider import Qwen3MoEModelSingleCardProvider
@@ -544,6 +545,8 @@ def main():
         model_provider = GLM45AirModelSingleCardDebugProvider()
     elif training_args.model_provider_type == "GLM_muiti_cards":
         model_provider = GLM45AirModelDebugProvider()
+    elif training_args.model_provider_type == "GLM_muiti_cards_fp8":
+        model_provider = GLM45AirModelDebugProviderFP8()
     elif training_args.model_provider_type == "qwen_single_card":
         training_args.save_checkpoint_format = None
         model_provider = Qwen3MoEModelSingleCardProvider()
diff --git a/paddleformers/trainer/trainer_callback.py b/paddleformers/trainer/trainer_callback.py
index 9c28720b1e..31afcf6d92 100644
--- a/paddleformers/trainer/trainer_callback.py
+++ b/paddleformers/trainer/trainer_callback.py
@@ -35,6 +35,7 @@ from paddle.distributed.fleet.utils.sequence_parallel_utils import (
     is_sequence_parallel_parameter,
 )
 
+from paddlefleet.models.gpt import GPTModel
 from tqdm.auto import tqdm
 
 from ..transformers.moe_gate import PretrainedMoEGate
@@ -688,14 +689,20 @@ def on_step_begin(self, args, state, control, **kwargs):
         global skip_count
         if (not g_shard_bypass_dygraph_optimizer or skip_count == 0) and hasattr(model, "fp8_quant_weight"):
+            self.moe_weights_name = []
+            self.use_fp8 = True
+            if isinstance(model, GPTModel):
+                self.use_fp8 = model.use_fp8()
+            if not self.use_fp8:
+                return
             model.fp8_quant_weight(True, quant_transpose=True)
             optimizer.clear_param_storage("moe_expert")
             optimizer.clear_param_storage("rms_linear")
             optimizer.clear_param_storage("memory_attn")
             optimizer.clear_param_storage("attn_out_project")
             optimizer.clear_param_storage("shared_expert")
-
-            self.moe_weights_name = []
+            if not args.offload_fp8_expert_master_weight:
+                return
             for param in optimizer._inner_opt._parameter_list:
                 color = getattr(param, "color", -1)
                 if isinstance(color, dict) and color["color"] == "moe_expert":
diff --git a/paddleformers/trainer/training_args.py b/paddleformers/trainer/training_args.py
index 014ed70807..503f3b8d20 100644
--- a/paddleformers/trainer/training_args.py
+++ b/paddleformers/trainer/training_args.py
@@ -1231,6 +1231,10 @@ class TrainingArguments:
         default=True,
         metadata={"help": "whether to use auto_parallel intermediate API."},
     )
+    offload_fp8_expert_master_weight: bool = field(
+        default=True,
+        metadata={"help": "whether to offload FP8 expert master weights."},
+    )
     use_cache: bool = field(
         default=False,
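
For review purposes only, a condensed sketch (not part of the patch) of the control flow the trainer-callback hunk introduces. It reuses the names that appear in the diff above (GPTModel.use_fp8(), fp8_quant_weight, clear_param_storage, offload_fp8_expert_master_weight); the standalone helper, its signature, and the omission of the g_shard_bypass_dygraph_optimizer check are hypothetical simplifications that only illustrate the ordering of the new early-return guards.

# Illustrative sketch of the FP8 gating added to on_step_begin.
# The helper name and arguments are hypothetical; the called methods
# and the training argument come from the patch above.
from paddlefleet.models.gpt import GPTModel  # import added by this patch


def fp8_step_begin_gate(callback, model, optimizer, args):
    if not hasattr(model, "fp8_quant_weight"):
        return

    callback.moe_weights_name = []
    # GPTModel can report whether it actually runs in FP8; other models
    # are assumed to, matching the previous unconditional behaviour.
    callback.use_fp8 = model.use_fp8() if isinstance(model, GPTModel) else True
    if not callback.use_fp8:
        return  # skip weight quantization entirely for non-FP8 GPT models

    model.fp8_quant_weight(True, quant_transpose=True)
    for storage in ("moe_expert", "rms_linear", "memory_attn", "attn_out_project", "shared_expert"):
        optimizer.clear_param_storage(storage)

    # Offloading the MoE expert master weights is now opt-out via the new
    # offload_fp8_expert_master_weight training argument (default True).
    if not args.offload_fp8_expert_master_weight:
        return
    # ... collect MoE expert weight names from the optimizer as before ...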