diff --git a/examples/experiments/auto_parallel/qwen2/run_pretrain.py b/examples/experiments/auto_parallel/qwen2/run_pretrain.py
new file mode 100644
index 00000000000..a920cb183ea
--- /dev/null
+++ b/examples/experiments/auto_parallel/qwen2/run_pretrain.py
@@ -0,0 +1,621 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+GPT/Llama auto parallel pretraining scripts.
+"""
+import os
+import sys
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+import paddle
+
+from paddleformers.data.causal_dataset import (
+    build_train_valid_test_datasets,
+    check_data_split,
+    print_rank_0,
+)
+from paddleformers.trainer import PdArgumentParser, get_last_checkpoint
+from paddleformers.trainer.trainer import Trainer, set_seed
+from paddleformers.trainer.trainer_utils import IntervalStrategy
+from paddleformers.trainer.training_args import TrainingArguments
+from paddleformers.trainer.utils.doc import add_start_docstrings
+from paddleformers.transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoModelForCausalLMPipe,
+    AutoTokenizer,
+    CosineAnnealingWithWarmupDecay,
+    LinearAnnealingWithWarmupDecay,
+)
+from paddleformers.transformers.configuration_utils import LlmMetaConfig
+from paddleformers.utils.log import logger
+
+
+@dataclass
+@add_start_docstrings(TrainingArguments.__doc__)
+class PreTrainingArguments(TrainingArguments):
+    min_learning_rate: float = field(
+        default=1e-5,
+        metadata={"help": "Minimum learning rate decayed to."},
+    )
+    decay_steps: float = field(
+        default=None,
+        metadata={
+            "help": "The number of steps used to control the learning rate. Once the step exceeds decay_steps, min_learning_rate is used."
+        },
+    )
+    enable_linear_fused_grad_add: bool = field(
+        default=False,
+        metadata={
+            "help": "Enable fused linear grad add strategy, which will reduce elementwise add for grad accumulation in the backward of nn.Linear."
+        },
+    )
+    pipeline_schedule_mode: str = field(
+        default="1F1B", metadata={"help": "The pipeline schedule mode, support FThenB, 1F1B, VPP and Eager-1F1B."}
+    )
+    sr: Optional[int] = field(default=0, metadata={"help": "The count of chunks without recompute."})
+    virtual_pipeline_seg_method: str = field(
+        default="LlamaDecoderLayerAuto",
+        metadata={"help": "The seg method of splitting pp layer for virtual pipeline."},
+    )
+    # NOTE(gongenlei): new add autotuner_benchmark
+    autotuner_benchmark: bool = field(
+        default=False,
+        metadata={"help": "Whether to run benchmark by autotuner. 
True for from_scratch and pad_max_length."}, + ) + fine_grained_log: bool = field( + default=False, + metadata={"help": "whether print find-grained performance log"}, + ) + lazy_init: bool = field( + default=False, + metadata={"help": "whether use lazy init for model parameters"}, + ) + n_microbatches: int = field( + default=1, + metadata={"help": "Control the num of microbatches in one pp step."}, + ) + + unified_checkpoint: bool = field( + default=True, + metadata={"help": "Enable fused linear grad add strategy."}, + ) + + def __post_init__(self): + super().__post_init__() + # NOTE(gongenlei): new add autotuner_benchmark + if self.autotuner_benchmark: + self.max_steps = 5 + self.do_train = True + self.do_export = False + self.do_predict = False + self.do_eval = False + self.overwrite_output_dir = True + self.load_best_model_at_end = False + self.report_to = [] + self.save_strategy = IntervalStrategy.NO + self.evaluation_strategy = IntervalStrategy.NO + + if self.enable_auto_parallel: + logger.info(self.strategy) + + +@dataclass +class DataArguments: + """ + Arguments pertaining to what data we are going to input our model for training and evaluating. + Using `PdArgumentParser` we can turn this class into argparse arguments to be able to + specify them on the command line. + """ + + input_dir: str = field( + default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} + ) + split: str = field(default="949,50,1", metadata={"help": "Train/valid/test data split."}) + + max_seq_length: int = field( + default=1024, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." + }, + ) + share_folder: bool = field( + default=False, + metadata={"help": "Use share folder for data dir and output dir on multi machine."}, + ) + + data_impl: str = field(default="mmap", metadata={"help": "The format of the preprocessed data."}) + skip_warmup: bool = field( + default=True, + metadata={"help": "Whether to skip the warmup process of mmap files."}, + ) + data_cache: str = field(default=None, metadata={"help": "The path of the cached dataset."}) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to pre-train from. + """ + + model_name_or_path: str = field( + default="__internal_testing__/tiny-random-llama", + metadata={ + "help": "Path to pretrained model or model identifier from https://paddleformers.readthedocs.io/zh/latest/model_zoo/transformers.html" + }, + ) + tokenizer_name_or_path: Optional[str] = field( + default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} + ) + + use_fast_layer_norm: bool = field( + default=False, + metadata={"help": "GPT3 model, use fast layernorm"}, + ) + + config_name: Optional[str] = field( + default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} + ) + vocab_size: Optional[int] = field( + default=None, + metadata={ + "help": ".Vocabulary size of the Llama model. 
Defines the number of different tokens that can be represented by the `inputs_ids`" + }, + ) + hidden_size: Optional[int] = field(default=None, metadata={"help": "Dimension of the hidden representations."}) + intermediate_size: Optional[int] = field(default=None, metadata={"help": "Dimension of the MLP representations."}) + num_hidden_layers: Optional[int] = field( + default=None, metadata={"help": "Number of hidden layers in the Transformer encoder."} + ) + num_attention_heads: Optional[int] = field( + default=None, + metadata={"help": "Number of attention heads for each attention layer in the Transformer encoder."}, + ) + use_flash_attention: bool = field( + default=False, + metadata={"help": "use_flash_attention"}, + ) + use_fused_rms_norm: bool = field( + default=False, + metadata={"help": "llama, use_fused_rms_norm"}, + ) + fuse_attention_qkv: bool = field( + default=False, + metadata={"help": "whether to fuse attention qkv"}, + ) + fuse_attention_ffn: bool = field( + default=False, + metadata={"help": "whether to fuse first up and gate proj in mlp block"}, + ) + recompute_granularity: str = field( + default="full", + metadata={"help": "Choose among ['full', 'core_attn', 'full_attn']"}, + ) + virtual_pp_degree: int = field( + default=1, + metadata={"help": "virtual_pp_degree"}, + ) + continue_training: bool = field( + default=False, + metadata={ + "help": "Pre-training from existing paddlenlp model weights. Default False and model will train from scratch. If set True, the model_name_or_path argument must exist in the paddlenlp models." + }, + ) + use_fused_rope: Optional[bool] = field( + default=False, + metadata={"help": "Enable rope fusion or not."}, + ) + no_recompute_layers: Optional[List[int]] = field( + default=None, + metadata={"help": "Specify the full transformer layers that should not be recomputed."}, + ) + pp_recompute_interval: int = field( + default=1, + metadata={ + "help": "The interval for the number of layers at which recomputation occurs. A value of 0 indicates no recomputation. Default is 0." + }, + ) + recompute_use_reentrant: bool = field( + default=False, + metadata={"help": "recompute_use_reentrant"}, + ) + hidden_dropout_prob: float = field(default=0.1, metadata={"help": "The hidden dropout prob."}) + attention_probs_dropout_prob: float = field(default=0.1, metadata={"help": "The attention hidden dropout prob."}) + + +def create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=True, +): + + check_data_split(data_args.split, training_args.do_train, training_args.do_eval, training_args.do_predict) + + train_val_test_num_samples = [ + training_args.per_device_train_batch_size + * training_args.dataset_world_size + * training_args.max_steps + * training_args.gradient_accumulation_steps, + training_args.per_device_eval_batch_size + * training_args.dataset_world_size + * training_args.eval_iters + * (training_args.max_steps // training_args.eval_steps + 1), + training_args.per_device_eval_batch_size * training_args.dataset_world_size * training_args.test_iters, + ] + + print_rank_0(" > datasets target sizes (minimum size):") + if training_args.do_train: + print_rank_0(" train: {}".format(train_val_test_num_samples[0])) + if training_args.do_eval: + print_rank_0(" validation: {}".format(train_val_test_num_samples[1])) + if training_args.do_predict: + print_rank_0(" test: {}".format(train_val_test_num_samples[2])) + + # Build the datasets. 
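+    # Each split must provide at least the target sizes computed above; for
+    # example, with per_device_train_batch_size=2, dataset_world_size=4,
+    # gradient_accumulation_steps=8 and max_steps=1000 (illustrative values,
+    # not defaults), the train split targets 2 * 4 * 1000 * 8 = 64000 samples.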
+ train_dataset, valid_dataset, test_dataset = build_train_valid_test_datasets( + data_prefix=data_file, + data_impl=data_args.data_impl, + splits_string=data_args.split, + train_val_test_num_samples=train_val_test_num_samples, + seq_length=data_args.max_seq_length, + seed=training_args.seed, + skip_warmup=data_args.skip_warmup, + share_folder=data_args.share_folder, + data_cache_path=data_args.data_cache, + need_data=need_data, + ) + + def print_dataset(data, mode="train"): + logger.info(f"Sample data for {mode} mode.") + input_ids = data["text"] + + logger.info(tokenizer._decode(input_ids)) + + from paddleformers.data import Stack + + def _collate_data(data, stack_fn=Stack()): + tokens_ = stack_fn([x["text"] for x in data]) + + labels = tokens_[:, 1:] + tokens = tokens_[:, :-1] + + return { + "input_ids": tokens, + "labels": labels, + } + + if need_data: + if training_args.do_train: + print_dataset(train_dataset[0], "train") + if training_args.do_eval: + print_dataset(valid_dataset[0], "valid") + if training_args.do_predict: + print_dataset(test_dataset[0], "test") + + return train_dataset, valid_dataset, test_dataset, _collate_data + + +def get_train_data_file(args): + if len(args.input_dir.split()) > 1: + # weight-1 data-prefix-1 weight-2 data-prefix-2 ... + return args.input_dir.split() + else: + files = [ + os.path.join(args.input_dir, f) + for f in os.listdir(args.input_dir) + if (os.path.isfile(os.path.join(args.input_dir, f)) and ("_idx.npz" in str(f) or ".idx" in str(f))) + ] + files = [x.replace("_idx.npz", "") for x in files] + files = [x.replace(".idx", "") for x in files] + + if len(files) > 1: + ret = [] + logger.info("You are using multi-dataset:") + for x in files: + ret.append(1.0) + ret.append(x) + logger.info(" > set weight of %s dataset to 1.0" % x) + return ret + + return files + + +class PretrainingTrainer(Trainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_pretraining = True + + +def main(): + parser = PdArgumentParser((ModelArguments, DataArguments, PreTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + do_enable_linear_fused_grad_add = training_args.enable_linear_fused_grad_add + do_enable_mp_async_allreduce = ( + training_args.enable_auto_parallel + and training_args.tensor_parallel_degree > 1 + and "enable_mp_async_allreduce" in training_args.tensor_parallel_config + and not training_args.sequence_parallel + ) + do_enable_sp_async_reduce_scatter = ( + training_args.enable_auto_parallel + and training_args.tensor_parallel_degree > 1 + and training_args.sequence_parallel + and "enable_sp_async_reduce_scatter" in training_args.tensor_parallel_config + ) + if ( + do_enable_linear_fused_grad_add or do_enable_mp_async_allreduce or do_enable_sp_async_reduce_scatter + ) and not training_args.to_static: + from llm.utils.fused_layers import mock_layers + + mock_layers(do_enable_linear_fused_grad_add, do_enable_mp_async_allreduce, do_enable_sp_async_reduce_scatter) + + if model_args.tokenizer_name_or_path is None: + model_args.tokenizer_name_or_path = model_args.model_name_or_path + + if data_args.data_cache is not None: + os.makedirs(data_args.data_cache, exist_ok=True) + + set_seed(seed=training_args.seed) + paddle.set_device(training_args.device) + if paddle.distributed.get_world_size() > 1: + 
+        paddle.distributed.init_parallel_env()
+
+    training_args.eval_iters = 10
+    training_args.test_iters = training_args.eval_iters * 10
+
+    # Log model and data config
+    training_args.print_config(model_args, "Model")
+    training_args.print_config(data_args, "Data")
+
+    # Log on each process the small summary:
+    logger.warning(
+        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
+        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16 or training_args.bf16}"
+    )
+
+    # Detecting last checkpoint.
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
+        last_checkpoint = get_last_checkpoint(training_args.output_dir)
+        if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+            )
+
+    tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name_or_path)
+    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+    LlmMetaConfig.set_llm_config(config, training_args)
+    config.use_fast_layer_norm = model_args.use_fast_layer_norm
+
+    config.seq_length = data_args.max_seq_length
+    # Some techniques extend the RotaryEmbedding context, so don't change max_position_embeddings.
+    if not model_args.continue_training:
+        config.max_position_embeddings = max(config.max_position_embeddings, data_args.max_seq_length)
+
+    if not model_args.continue_training:
+        config.vocab_size = max(config.vocab_size, ((tokenizer.vocab_size - 1) // 128 + 1) * 128)
+        logger.info(f"Reset vocab size to {config.vocab_size} for better amp performance.")
+
+    # Config for model using dropout, such as GPT.
+    if hasattr(config, "use_dualpipev"):
+        # NOTE(zhangyuqin): In Paddle, the segmentation and scheduling of pipeline parallel
+        # models are separate. Therefore, first we need to set the flag in the model config
+        # to perform V-shape segmentation. Second, we need to set the flag in the training_args
+        # to configure strategy.hybrid_configs to choose the DualPipeV schedule.
+ config.use_dualpipev = "use_dualpipev" in training_args.pipeline_parallel_config + if hasattr(config, "hidden_dropout_prob"): + config.hidden_dropout_prob = model_args.hidden_dropout_prob + if hasattr(config, "attention_probs_dropout_prob"): + config.attention_probs_dropout_prob = model_args.attention_probs_dropout_prob + + if model_args.no_recompute_layers is not None: + model_args.no_recompute_layers.sort() + + # config.vocab_size = model_args.vocab_size if model_args.vocab_size is not None else config.vocab_size + config.hidden_size = model_args.hidden_size if model_args.hidden_size is not None else config.hidden_size + config.intermediate_size = ( + model_args.intermediate_size if model_args.intermediate_size is not None else config.intermediate_size + ) + config.num_hidden_layers = ( + model_args.num_hidden_layers if model_args.num_hidden_layers is not None else config.num_hidden_layers + ) + config.num_attention_heads = ( + model_args.num_attention_heads if model_args.num_attention_heads is not None else config.num_attention_heads + ) + + config.use_flash_attention = model_args.use_flash_attention + config.use_fused_rms_norm = model_args.use_fused_rms_norm + config.fuse_attention_qkv = model_args.fuse_attention_qkv + config.fuse_attention_ffn = model_args.fuse_attention_ffn + config.recompute_granularity = model_args.recompute_granularity + config.virtual_pp_degree = model_args.virtual_pp_degree + config.sequence_parallel = training_args.sequence_parallel + config.fuse_sequence_parallel_allreduce = training_args.fuse_sequence_parallel_allreduce + config.use_fused_rope = model_args.use_fused_rope + config.no_recompute_layers = model_args.no_recompute_layers + config.pp_recompute_interval = model_args.pp_recompute_interval + config.recompute_use_reentrant = model_args.recompute_use_reentrant + + config.use_recompute = training_args.recompute + config.tensor_parallel_degree = training_args.tensor_parallel_degree + config.tensor_parallel_rank = training_args.tensor_parallel_rank + + config.dp_degree = training_args.data_parallel_degree + config.mp_degree = training_args.tensor_parallel_degree + config.pp_degree = training_args.pipeline_parallel_degree + config.to_static = training_args.to_static + config.fine_grained_log = training_args.fine_grained_log + config.lazy_init = training_args.lazy_init + + if config.sequence_parallel: + assert config.tensor_parallel_degree > 1, "tensor_parallel_degree must be larger than 1 for sequence parallel." 
+ assert ( + config.num_attention_heads % config.sep_parallel_degree == 0 + ), f"num_attention_heads:{config.num_attention_heads} must be divisible by sep_parallel_degree {config.sep_parallel_degree}" + assert ( + config.seq_length % config.context_parallel_degree == 0 + ), f"seq_length:{config.seq_length} must be divisible by context_parallel_degree {config.context_parallel_degree}" + + if training_args.sharding_parallel_config is not None: + # for stage1 overlap optimization + if ( + "enable_stage1_allgather_overlap" in training_args.sharding_parallel_config + or "enable_stage1_broadcast_overlap" in training_args.sharding_parallel_config + ): + from paddle.io.reader import use_pinned_memory + + use_pinned_memory(False) + + if ( + "replace_with_parallel_cross_entropy" in training_args.tensor_parallel_config + and config.tensor_parallel_degree > 1 + and config.to_static is False + ): + from llm.utils.replace_ops import replace_cross_entropy + + replace_cross_entropy() + + if training_args.use_intermediate_api: + config.run_single_model = True + config.tensor_parallel_degree = 1 + config.sharding_parallel_degree = 1 + config.sep_parallel_degree = 1 + config.context_parallel_degree = 1 + + print("Final pre-training config:", config) + + # Set the dtype for loading model + dtype = "float32" + if training_args.fp16_opt_level == "O2": + if training_args.fp16: + dtype = "float16" + if training_args.bf16: + dtype = "bfloat16" + + model_class = AutoModelForCausalLM + if not training_args.enable_auto_parallel and training_args.pipeline_parallel_degree > 1: + model_class = AutoModelForCausalLMPipe + if "LLama" in str(config.architectures): + try: + from utils.register_reshard import register_pp_reshard_information + + register_pp_reshard_information(config.num_hidden_layers) + except: + print("Not register llama pp reshard information.") + + architectures_to_check = {"Qwen2Moe", "DeepseekV2", "DeepseekV3"} + if ( + any(architecture in str(config.architectures) for architecture in architectures_to_check) + and training_args.data_parallel_degree > 1 + ): + training_args.use_expert_parallel = True + + if model_args.continue_training: + # NOTE(gongenlei): new add + if training_args.autotuner_benchmark: + model = model_class.from_config(config, dtype=dtype) + else: + model = model_class.from_pretrained( + model_args.model_name_or_path, + config=config, + dtype=dtype, + ) + else: + if training_args.enable_auto_parallel: + with paddle.LazyGuard(): + model = model_class.from_config(config, dtype=dtype) + else: + model = model_class.from_config(config, dtype=dtype) + + # Create the learning_rate scheduler and optimizer + if training_args.decay_steps is None: + training_args.decay_steps = training_args.max_steps + + if training_args.warmup_steps > 0: + warmup_steps = training_args.warmup_steps + else: + warmup_steps = training_args.warmup_ratio * training_args.max_steps + + lr_scheduler = None + if training_args.lr_scheduler_type.value == "cosine": + lr_scheduler = CosineAnnealingWithWarmupDecay( + max_lr=training_args.learning_rate, + min_lr=training_args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=training_args.decay_steps, + last_epoch=0, + ) + elif training_args.lr_scheduler_type.value == "linear": + lr_scheduler = LinearAnnealingWithWarmupDecay( + max_lr=training_args.learning_rate, + min_lr=training_args.min_learning_rate, + warmup_step=warmup_steps, + decay_step=training_args.decay_steps, + last_epoch=0, + ) + + data_file = get_train_data_file(data_args) + train_dataset, eval_dataset, 
test_dataset, data_collator = create_pretrained_dataset( + data_args, + training_args, + data_file, + tokenizer, + need_data=training_args.should_load_dataset, + ) + + trainer = PretrainingTrainer( + model=model, + args=training_args, + data_collator=data_collator, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + optimizers=(None, lr_scheduler), + tokenizer=tokenizer, + ) + + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=checkpoint) + + # NOTE(gongenlei): new add + if not training_args.autotuner_benchmark: + metrics = train_result.metrics + if not int(os.getenv("test_ci_no_save_model", 0)): + trainer.save_model() + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + if training_args.do_predict: + test_ret = trainer.predict(test_dataset) + trainer.log_metrics("test", test_ret.metrics) + + +if __name__ == "__main__": + main() diff --git a/paddleformers/transformers/configuration_utils.py b/paddleformers/transformers/configuration_utils.py index 29eb147214e..d8ef493ed7a 100644 --- a/paddleformers/transformers/configuration_utils.py +++ b/paddleformers/transformers/configuration_utils.py @@ -537,6 +537,9 @@ class PretrainedConfig: Whether the model's input and output word embeddings should be tied. Note that this is only relevant if the model has a output word embedding layer. + run_single_model (`bool`, *optional*, defaults to `False`): + Whether to run the model in single card mode. When enabled, all parallel degree configurations will be disabled. + dtype (`str`, *optional*): The `dtype` of the weights. This attribute can be used to initialize the model to a non-default `dtype` (which is normally `float32`) and thus allow for optimal storage allocation. For example, if the saved @@ -601,6 +604,13 @@ def __init__(self, **kwargs): self.use_cache = kwargs.pop("use_cache", False) self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True) + # for run model in single card mode + self.run_single_model = kwargs.pop("run_single_model", False) + if self.run_single_model: + self.tensor_parallel_degree = 1 + self.sep_parallel_degree = 1 + self.context_parallel_degree = 1 + # for transformers fuse self.fuse_linear = kwargs.pop("fuse_linear", False) self.fuse_attention_qkv = kwargs.pop("fuse_attention_qkv", False) diff --git a/paddleformers/transformers/qwen2/auto_dist_config.py b/paddleformers/transformers/qwen2/auto_dist_config.py new file mode 100644 index 00000000000..3dd0460d996 --- /dev/null +++ b/paddleformers/transformers/qwen2/auto_dist_config.py @@ -0,0 +1,37 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
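+
+# This module only describes a tensor-parallel `parallelize_plan`: parameter-name
+# patterns mapped to ColWiseParallel / RowWiseParallel placements. It is exposed
+# through the model's `auto_dist_config()` hook (added in qwen2/modeling.py) when
+# `use_intermediate_api` sets `run_single_model` in run_pretrain.py; how Paddle's
+# auto-parallel intermediate API ultimately applies the plan is outside this file.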
+ +import paddle.distributed as dist + + +def get_dist_config(model, prefix="model."): + if prefix != "": + assert prefix.endswith(".") + config = { + "mp_config": { + "parallelize_plan": { + f"{prefix}model.embed_tokens": dist.RowWiseParallel(), + f"{prefix}model.layers.*.self_attn.q_proj": dist.ColWiseParallel(), + f"{prefix}model.layers.*.self_attn.k_proj": dist.ColWiseParallel(), + f"{prefix}model.layers.*.self_attn.v_proj": dist.ColWiseParallel(), + f"{prefix}model.layers.*.self_attn.o_proj": dist.RowWiseParallel(), + f"{prefix}model.layers.*.mlp.gate_proj": dist.ColWiseParallel(), + f"{prefix}model.layers.*.mlp.up_proj": dist.ColWiseParallel(), + f"{prefix}model.layers.*.mlp.down_proj": dist.RowWiseParallel(), + f"{prefix}lm_head.weight": dist.RowWiseParallel(), + } + }, + } + + return config diff --git a/paddleformers/transformers/qwen2/modeling.py b/paddleformers/transformers/qwen2/modeling.py index 03b880ee6b7..1ef5d29a5d4 100644 --- a/paddleformers/transformers/qwen2/modeling.py +++ b/paddleformers/transformers/qwen2/modeling.py @@ -48,6 +48,7 @@ TokenClassifierOutput, ) from ..model_utils import PretrainedModel, register_base_model +from .auto_dist_config import get_dist_config from .configuration import Qwen2Config @@ -696,6 +697,10 @@ def forward( attentions=outputs.attentions, ) + def auto_dist_config(self, prefix=""): + assert self.config.run_single_model, "Use `get_dist_config` only in single card mode." + return get_dist_config(self, prefix) + class Qwen2ForSequenceClassification(Qwen2PretrainedModel): def __init__(self, config: Qwen2Config): diff --git a/paddleformers/transformers/tensor_parallel_utils.py b/paddleformers/transformers/tensor_parallel_utils.py index aa7055bb8f6..1610fd605e7 100644 --- a/paddleformers/transformers/tensor_parallel_utils.py +++ b/paddleformers/transformers/tensor_parallel_utils.py @@ -72,20 +72,14 @@ def parallel_matmul( except ImportError: pass - is_fleet_init = True - try: - hcg = fleet.get_hybrid_communicate_group() - model_parallel_group = hcg.get_model_parallel_group() - tensor_parallel_degree = hcg.get_model_parallel_world_size() - except: - is_fleet_init = False - is_logit_weight_distributed = logit_weights.is_distributed # `is_distributed` in static mode is always False if in_declarative_mode() and tensor_parallel_degree > 1: is_logit_weight_distributed = True - if is_fleet_init and tensor_parallel_degree > 1 and is_logit_weight_distributed: + if tensor_parallel_degree > 1 and is_logit_weight_distributed: + hcg = fleet.get_hybrid_communicate_group() + model_parallel_group = hcg.get_model_parallel_group() input_parallel = paddle.distributed.collective._c_identity(lm_output, group=model_parallel_group) if transpose_y:
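# Illustrative sketch (not part of the patch) of how the new pieces fit together:
# `run_single_model=True` marks the config as a single-card build, `LazyGuard`
# defers parameter allocation as run_pretrain.py does, and `auto_dist_config()`
# returns the tensor-parallel plan from qwen2/auto_dist_config.py. The checkpoint
# name and dtype below are assumptions, and how the trainer applies the plan
# (Paddle's auto-parallel intermediate API) is not shown here.
import paddle

from paddleformers.transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("Qwen/Qwen2-0.5B")  # any Qwen2 checkpoint id (assumed)
config.run_single_model = True  # required by auto_dist_config(); normally set via use_intermediate_api

with paddle.LazyGuard():  # parameters are created lazily, mirroring run_pretrain.py
    model = AutoModelForCausalLM.from_config(config, dtype="bfloat16")

dist_config = model.auto_dist_config()
for name, placement in dist_config["mp_config"]["parallelize_plan"].items():
    print(name, type(placement).__name__)
# e.g. "model.layers.*.self_attn.q_proj ColWiseParallel"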