2 changes: 2 additions & 0 deletions paddleformers/trainer/__init__.py
@@ -75,6 +75,8 @@
"TrainerState",
"DEFAULT_PROGRESS_CALLBACK",
"TrainerCallback",
"StepFlexToken",
"FP8QuantWeightCallback",
],
"trainer_utils": [
"get_last_checkpoint",
12 changes: 12 additions & 0 deletions paddleformers/trainer/trainer.py
@@ -92,6 +92,10 @@
    RowParallelQuantizationLinear,
)

try:
    from ..quantization.quantization_linear import QuantizationLinear
except ImportError:
    QuantizationLinear = None
try:
    from paddle.distributed.fleet.utils.sequence_parallel_utils import (
        register_sequence_parallel_allreduce_hooks,
@@ -201,6 +205,14 @@
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

# Name of the files used for checkpointing
TRAINING_ARGS_NAME = "training_args.bin"
TRAINER_STATE_NAME = "trainer_state.json"

SCHEDULER_NAME = "scheduler.pdparams"
SCALER_NAME = "scaler.pdparams"
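
As an aside, a minimal sketch of how checkpoint file-name constants like these are typically used when persisting training state. This is illustrative only: the directory name and the scheduler/scaler objects below are assumptions, not code from trainer.py, and SCHEDULER_NAME / SCALER_NAME refer to the constants defined just above.

import os

import paddle

# Illustrative objects; inside the Trainer these come from the training loop.
lr_scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.1, step_size=10)
scaler = paddle.amp.GradScaler()

checkpoint_dir = "output/checkpoint-100"  # assumed checkpoint directory layout
os.makedirs(checkpoint_dir, exist_ok=True)
paddle.save(lr_scheduler.state_dict(), os.path.join(checkpoint_dir, SCHEDULER_NAME))
paddle.save(scaler.state_dict(), os.path.join(checkpoint_dir, SCALER_NAME))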


if is_datasets_available():
    import datasets

66 changes: 66 additions & 0 deletions paddleformers/trainer/trainer_callback.py
@@ -20,12 +20,14 @@
"""
import dataclasses
import json
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import numpy as np
from tqdm.auto import tqdm

from paddleformers.transformers.moe_utils import offload, reload
from ..utils.log import logger
from .trainer_utils import IntervalStrategy, has_length
from .training_args import TrainingArguments
@@ -39,6 +41,8 @@
"ProgressCallback",
"PrinterCallback",
"EarlyStoppingCallback",
"StepFlexToken",
"FP8QuantWeightCallback",
]


@@ -608,3 +612,65 @@ def on_evaluate(self, args, state, control, metrics, **kwargs):
        self.check_metric_value(args, state, control, metric_value)
        if self.early_stopping_patience_counter >= self.early_stopping_patience:
            control.should_training_stop = True


class StepFlexToken(TrainerCallback):
    """Call ``model.step_flex_token(global_step)`` at the start of each step when the model provides it."""

    def on_step_begin(
        self,
        args: TrainingArguments,
        state: TrainerState,
        control: TrainerControl,
        **kwargs,
    ):
        model = kwargs.pop("model")
        if hasattr(model, "step_flex_token"):
            model.step_flex_token(state.global_step)


g_shard_bypass_dygraph_optimizer = int(os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0))


def enable_in_dict_config(config, key):
    """Return True if ``key`` is present in ``config`` and its value is truthy."""
    return key in config and config[key]


skip_count = 0


class FP8QuantWeightCallback(TrainerCallback):
    """
    Quantize MoE expert weights to FP8 before each training step and offload
    their master weights until the optimizer needs them.
    """

    def on_step_begin(self, args, state, control, **kwargs):
        """
        Quantize the expert weights to FP8 before each step begins.
        """
        model = kwargs["model"]
        optimizer = kwargs["optimizer"]
        global skip_count

        # Re-quantize every step, or only on the first step when the
        # shard-bypass-dygraph-optimizer flag is set.
        if not g_shard_bypass_dygraph_optimizer or skip_count == 0:
            model.fp8_quant_weight(True)
            optimizer.clear_param_storage("moe_expert")
            optimizer.clear_param_storage("rms_linear")
            optimizer.clear_param_storage("memory_attn")
            optimizer.clear_param_storage("attn_out_project")
            optimizer.clear_param_storage("shared_expert")

        # Collect MoE expert parameters (tagged via a ``color`` dict on each
        # parameter) and offload their master weights.
        self.moe_weights_name = []
        for param in optimizer._inner_opt._parameter_list:
            color = getattr(param, "color", -1)
            if isinstance(color, dict) and color["color"] == "moe_expert":
                self.moe_weights_name.append(param.name)

        for name in self.moe_weights_name:
            offload(optimizer._master_weights[name])

        skip_count += 1

    def on_optimizer_begin(self, args, state, control, **kwargs):
        optimizer = kwargs["optimizer"]
        for name in self.moe_weights_name:
            reload(optimizer._master_weights[name])
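
For orientation, a minimal usage sketch of the two new callbacks. It assumes the PaddleFormers Trainer accepts a ``callbacks`` list (as in the Hugging Face-style API it mirrors), that ``model``, ``training_args``, and ``train_dataset`` are constructed elsewhere, and that the model implements ``step_flex_token`` and ``fp8_quant_weight``:

from paddleformers.trainer import FP8QuantWeightCallback, StepFlexToken, Trainer

# model, training_args and train_dataset are assumed to be built elsewhere.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    callbacks=[StepFlexToken(), FP8QuantWeightCallback()],
)
trainer.train()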
11 changes: 5 additions & 6 deletions paddleformers/trainer/training_args.py
@@ -32,6 +32,7 @@
from paddle.distributed import fleet

from ..utils.env import PREFIX_CHECKPOINT_DIR
from ..utils.fault_tolerance import is_ft_env
from ..utils.log import logger
from ..utils.pdc_sdk import FLASH_DEVICE
from .trainer_utils import (
@@ -1397,12 +1398,7 @@ def is_segment_parallel_supported():
            else:
                order = ["dp", "sharding", "pp", "mp"]
            if self.use_expert_parallel:
                if self.moe_sharding_parallel_degree >= 1 and self.expert_parallel_degree > 1:
                    order.insert(-1, "ep")
                    sd_idx = order.index("sharding")
                    # if pp_first, the order is ["dp", "pp", "moe_sharding", "sharding", "sep", "ep", "mp"]
                    # if sharding_first, the order is ["dp", "moe_sharding", "sharding", "pp", "sep", "ep", "mp"]
                    order.insert(sd_idx, "moe_sharding")
                order = order[1:-1] + ["dp", "mp"]
Contributor:

Why is this block being deleted? Will removing it affect the original logic?

Author:

Without this change it fails with:

  File "/PaddleFormers/paddleformers/trainer/training_args.py", line 1561, in __post_init__
    self.add_moe_comm_group()
  File "/PaddleFormers/paddleformers/trainer/training_args.py", line 2071, in add_moe_comm_group
    sharding_parallel_groups = topo.get_comm_list("sharding")
  File "/py3.10/lib/python3.10/site-packages/paddle/distributed/fleet/base/topology.py", line 227, in get_comm_list
    assert axis_name in self._parallel_names
AssertionError
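
For reference, a standalone sketch of what the axis-reordering in this hunk computes, starting from the ``order`` list built in the ``else`` branch above (the commented values are simply the result of each line):

order = ["dp", "sharding", "pp", "mp"]  # sharding_first case from the else branch
order.insert(-1, "ep")                  # ["dp", "sharding", "pp", "ep", "mp"]
sd_idx = order.index("sharding")
order.insert(sd_idx, "moe_sharding")    # ["dp", "moe_sharding", "sharding", "pp", "ep", "mp"]
order = order[1:-1] + ["dp", "mp"]      # ["moe_sharding", "sharding", "pp", "ep", "dp", "mp"]
print(order)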


            if is_segment_parallel_supported():
                hybrid_configs = {
@@ -1556,6 +1552,9 @@ def is_segment_parallel_supported():
            fleet.init(is_collective=True, strategy=strategy)
            logger.info(strategy)

            if self.expert_parallel_degree > 1:
                self.add_moe_comm_group()
Author:

Removing this causes:

  File "/lib/python3.10/site-packages/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py", line 79, in build_layer
    return self.layer_func(*self.inputs, **{**self.kwargs, **extra_kwargs})
  File "/PaddleFormers/paddleformers/transformers/deepseek_v2/modeling.py", line 2275, in __init__
    DeepseekV2MoE(
  File "/PaddleFormers/paddleformers/transformers/deepseek_v2/modeling.py", line 1018, in __init__
    super().__init__(
  File "/PaddleFormers/paddleformers/transformers/moe_layer.py", line 225, in __init__
    self.moe_group = dist.fleet.get_hybrid_communicate_group().expert_parallel_group
AttributeError: 'HybridCommunicateGroup' object has no attribute 'expert_parallel_group'. Did you mean: 'get_data_parallel_group'?
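
For context on this traceback, a sketch of the access that fails (only the call shown in the traceback is assumed; it requires fleet to be initialized with a hybrid strategy first): the MoE layer resolves its communication group from the hybrid topology at model-build time, so the expert-parallel group must be registered beforehand.

import paddle.distributed as dist

# The lookup performed in moe_layer.py; it raises the AttributeError above
# unless add_moe_comm_group() has registered the expert-parallel group.
hcg = dist.fleet.get_hybrid_communicate_group()
moe_group = hcg.expert_parallel_group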


        elif self.enable_auto_parallel:
            self.tensor_parallel_degree = max(self.tensor_parallel_degree, 1)
            self.sep_parallel_degree = max(self.sep_parallel_degree, 1)