Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions examples/experiments/paddlefleet/glm45_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,12 @@ class GLM45AirModelDebugProvider(GLM45AirModelProvider106B):
apply_rope_fusion: bool = True


@dataclass
class GLM45AirModelDebugProviderFP8(GLM45AirModelDebugProvider):
fp8: str = "e4m3"
moe_shared_expert_overlap: True


@dataclass
class GLM45AirModelSingleCardDebugProvider(GLMMoEModelProvider):
"""
Expand Down
3 changes: 3 additions & 0 deletions examples/experiments/paddlefleet/run_pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@

from glm45_provider import (
GLM45AirModelDebugProvider,
GLM45AirModelDebugProviderFP8,
GLM45AirModelSingleCardDebugProvider,
)
from qwen_provider import Qwen3MoEModelSingleCardProvider
Expand Down Expand Up @@ -544,6 +545,8 @@ def main():
model_provider = GLM45AirModelSingleCardDebugProvider()
elif training_args.model_provider_type == "GLM_muiti_cards":
model_provider = GLM45AirModelDebugProvider()
elif training_args.model_provider_type == "GLM_muiti_cards_fp8":
model_provider = GLM45AirModelDebugProviderFP8()
elif training_args.model_provider_type == "qwen_single_card":
training_args.save_checkpoint_format = None
model_provider = Qwen3MoEModelSingleCardProvider()
Expand Down
11 changes: 9 additions & 2 deletions paddleformers/trainer/trainer_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from paddle.distributed.fleet.utils.sequence_parallel_utils import (
is_sequence_parallel_parameter,
)
from paddlefleet.models.gpt import GPTModel
from tqdm.auto import tqdm

from ..transformers.moe_gate import PretrainedMoEGate
Expand Down Expand Up @@ -688,14 +689,20 @@ def on_step_begin(self, args, state, control, **kwargs):
global skip_count

if (not g_shard_bypass_dygraph_optimizer or skip_count == 0) and hasattr(model, "fp8_quant_weight"):
self.moe_weights_name = []
self.use_fp8 = True
if isinstance(model, GPTModel):
self.use_fp8 = model.use_fp8()
if not self.use_fp8:
return
model.fp8_quant_weight(True, quant_transpose=True)
optimizer.clear_param_storage("moe_expert")
optimizer.clear_param_storage("rms_linear")
optimizer.clear_param_storage("memory_attn")
optimizer.clear_param_storage("attn_out_project")
optimizer.clear_param_storage("shared_expert")

self.moe_weights_name = []
if not args.offload_fp8_expert_master_weight:
return
for param in optimizer._inner_opt._parameter_list:
color = getattr(param, "color", -1)
if isinstance(color, dict) and color["color"] == "moe_expert":
Expand Down
4 changes: 4 additions & 0 deletions paddleformers/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,10 @@ class TrainingArguments:
default=True,
metadata={"help": "whether to use auto_parallel intermediate API."},
)
offload_fp8_expert_master_weight: bool = field(
default=True,
metadata={"help": "Offload FP8 expert weights."},
)

use_cache: bool = field(
default=False,
Expand Down