diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index fa71907e3c..6b2a41e9d1 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -126,6 +126,30 @@ class LoraArguments:
     )
 
 
+@dataclass
+class DistillationArguments:
+    r"""
+    Arguments pertaining to knowledge distillation training.
+    """
+
+    distilling_lambda: float = field(
+        default=0.5,
+        metadata={"help": "The weight (lambda) of the distillation loss."},
+    )
+    distilling_temperature: float = field(
+        default=1.0,
+        metadata={"help": "The softmax temperature used for distillation."},
+    )
+    teacher_model: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the teacher model used for distillation."},
+    )
+    teacher_model_adapters: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to the adapters of the teacher model."},
+    )
+
+
 @dataclass
 class RLHFArguments:
     r"""
@@ -387,7 +411,14 @@ class SwanLabArguments:
 
 @dataclass
 class FinetuningArguments(
-    FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, ApolloArguments, BAdamArgument, SwanLabArguments
+    FreezeArguments,
+    LoraArguments,
+    RLHFArguments,
+    GaloreArguments,
+    ApolloArguments,
+    BAdamArgument,
+    SwanLabArguments,
+    DistillationArguments,
 ):
     r"""
     Arguments pertaining to which techniques we are going to fine-tuning with.
@@ -397,7 +428,7 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
     )
-    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
+    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "distillation"] = field(
         default="sft",
         metadata={"help": "Which stage will be performed in training."},
     )
diff --git a/src/llamafactory/train/distil/__init__.py b/src/llamafactory/train/distil/__init__.py
new file mode 100644
index 0000000000..53311fcdc2
--- /dev/null
+++ b/src/llamafactory/train/distil/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2024 the LlamaFactory team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .workflow import run_distillation
+
+
+__all__ = ["run_distillation"]
diff --git a/src/llamafactory/train/distil/trainer.py b/src/llamafactory/train/distil/trainer.py
new file mode 100644
index 0000000000..7c93f90bd5
--- /dev/null
+++ b/src/llamafactory/train/distil/trainer.py
@@ -0,0 +1,158 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+from types import MethodType
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
+
+import torch
+from transformers import Seq2SeqTrainer
+from typing_extensions import override
+
+from ...extras import logging
+from ...extras.packages import is_transformers_version_greater_than
+from ..callbacks import SaveProcessorCallback
+from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
+
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin
+
+    from ...hparams import FinetuningArguments
+
+
+logger = logging.get_logger(__name__)
+
+
+class CustomDistillationTrainer(Seq2SeqTrainer):
+    r"""
+    Inherits Seq2SeqTrainer to compute the knowledge distillation loss against a frozen teacher model.
+    """
+
+    def __init__(
+        self,
+        teacher_model: Union["PreTrainedModel", torch.nn.Module],
+        finetuning_args: "FinetuningArguments",
+        processor: Optional["ProcessorMixin"],
+        **kwargs,
+    ):
+        if is_transformers_version_greater_than("4.46"):
+            kwargs["processing_class"] = kwargs.pop("tokenizer")
+        else:
+            self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer")
+
+        super().__init__(**kwargs)
+
+        if teacher_model is not None:
+            if self.is_deepspeed_enabled:
+                if not (
+                    getattr(teacher_model, "is_loaded_in_8bit", False)
+                    or getattr(teacher_model, "is_loaded_in_4bit", False)
+                ):  # quantized models are already set on the correct device
+                    self.teacher_model = self._prepare_deepspeed(teacher_model)
+            else:
+                self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True)
+        self.finetuning_args = finetuning_args
+
+        if processor is not None:
+            self.add_callback(SaveProcessorCallback(processor))
+
+        if finetuning_args.use_badam:
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
+
+            self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
+            self.add_callback(BAdamCallback)
+
+    @override
+    def create_optimizer(self) -> "torch.optim.Optimizer":
+        if self.optimizer is None:
+            self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
+        return super().create_optimizer()
+
+    @override
+    def create_scheduler(
+        self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
+    ) -> "torch.optim.lr_scheduler.LRScheduler":
+        create_custom_scheduler(self.args, num_training_steps, optimizer)
+        return super().create_scheduler(num_training_steps, optimizer)
+
+    @override
+    def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
+        if self.finetuning_args.disable_shuffling:
+            return torch.utils.data.SequentialSampler(self.train_dataset)
+
+        return super()._get_train_sampler()
+
+    @override
+    def compute_loss(
+        self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs
+    ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
+        labels = inputs.get("labels")
+        padding_mask = labels.eq(-100)
+        label_loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs)
+        self.teacher_model.eval()
+        with torch.no_grad():
+            teacher_outputs = self.teacher_model(**inputs)
+
+        # Shape: (batch_size, seq_len, vocab_size)
+        teacher_prob = torch.nn.functional.softmax(
+            teacher_outputs.logits / self.finetuning_args.distilling_temperature, dim=-1
+        )
+        student_logprob = torch.nn.functional.log_softmax(
+            outputs.logits / self.finetuning_args.distilling_temperature, dim=-1
+        )
+        # Token-level KL(teacher || student), zeroed on padded positions and averaged over active tokens
+        kl_losses = (teacher_prob * (teacher_prob.log() - student_logprob)).sum(dim=-1)
+        kl_losses.masked_fill_(padding_mask, 0)
+        num_active_elements = padding_mask.numel() - padding_mask.long().sum()
+        loss = self.finetuning_args.distilling_lambda * kl_losses.sum() / num_active_elements + label_loss
+
+        if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False):
+            loss = loss / self.args.gradient_accumulation_steps
+
+        return (loss, outputs) if return_outputs else loss
+
+    def _prepare_deepspeed(self, model: "PreTrainedModel"):
+        import deepspeed  # type: ignore
+
+        deepspeed_plugin = self.accelerator.state.deepspeed_plugin
+        config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config)
+
+        if model is not None:
+            if hasattr(model, "config"):
+                hidden_size = (
+                    max(model.config.hidden_sizes)
+                    if getattr(model.config, "hidden_sizes", None)
+                    else getattr(model.config, "hidden_size", None)
+                )
+                if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3:
+                    config_kwargs.update(
+                        {
+                            "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
+                            "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
+                            "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size,
+                        }
+                    )
+
+        # If ZeRO-3 is not used, keep the teacher model unsharded (ZeRO-0) for inference.
+        if config_kwargs["zero_optimization"]["stage"] != 3:
+            config_kwargs["zero_optimization"]["stage"] = 0
+
+        model, *_ = deepspeed.initialize(model=model, config=config_kwargs)
+        model.eval()
+        return model
diff --git a/src/llamafactory/train/distil/workflow.py b/src/llamafactory/train/distil/workflow.py
new file mode 100644
index 0000000000..cb64865852
--- /dev/null
+++ b/src/llamafactory/train/distil/workflow.py
@@ -0,0 +1,140 @@
+# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, List, Optional
+
+from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
+from ...extras.constants import IGNORE_INDEX
+from ...extras.logging import get_logger
+from ...extras.misc import calculate_tps, get_logits_processor
+from ...extras.ploting import plot_loss
+from ...hparams import FinetuningArguments, ModelArguments
+from ...model import load_model, load_tokenizer
+from ..trainer_utils import create_modelcard_and_push
+from .trainer import CustomDistillationTrainer
+
+
+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments, TrainerCallback
+
+    from ...hparams import DataArguments, GeneratingArguments
+
+
+logger = get_logger(__name__)
+
+
+def run_distillation(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    generating_args: "GeneratingArguments",
+    callbacks: Optional[List["TrainerCallback"]] = None,
+):
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
+    template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module)
+    model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train)
+
+    # Load teacher model
+    # TODO teacher_model_quantization_bit
+    teacher_model_args = ModelArguments.copyfrom(
+        model_args,
+        model_name_or_path=finetuning_args.teacher_model,
+        adapter_name_or_path=finetuning_args.teacher_model_adapters,
+    )
+    teacher_finetuning_args = FinetuningArguments()
+    teacher_model = load_model(tokenizer, teacher_model_args, teacher_finetuning_args, is_trainable=False)
+    # Compare model and teacher tokenizer
+    teacher_tokenizer = load_tokenizer(teacher_model_args)["tokenizer"]
+    assert (
+        teacher_tokenizer.get_vocab() == tokenizer.get_vocab()
+    ), "The teacher's and student's tokenizers must have the same vocabulary dictionary."
+
+    if getattr(model, "is_quantized", False) and not training_args.do_train:
+        setattr(model, "_hf_peft_config_loaded", True)  # hack here: make model compatible with prediction
+
+    # TODO handling `prepare_decoder_input_ids_from_labels`.
+    data_collator = SFTDataCollatorWith4DAttentionMask(
+        template=template,
+        model=model if not training_args.predict_with_generate else None,
+        pad_to_multiple_of=8 if training_args.do_train else None,  # for shift short attention
+        label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
+        block_diag_attn=model_args.block_diag_attn,
+        attn_implementation=getattr(model.config, "_attn_implementation", None),
+        compute_dtype=model_args.compute_dtype,
+        **tokenizer_module,
+    )
+
+    # Override the decoding parameters of Seq2SeqTrainer
+    training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len
+    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
+    training_args.remove_unused_columns = False  # important for multimodal dataset
+
+    # Initialize our Trainer
+    trainer = CustomDistillationTrainer(
+        model=model,
+        teacher_model=teacher_model,
+        args=training_args,
+        finetuning_args=finetuning_args,
+        data_collator=data_collator,
+        callbacks=callbacks,
+        **dataset_module,
+        **tokenizer_module,
+    )
+
+    # Keyword arguments for `model.generate`
+    gen_kwargs = generating_args.to_dict(obey_generation_config=True)
+    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
+    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
+    gen_kwargs["logits_processor"] = get_logits_processor()
+
+    # Training
+    if training_args.do_train:
+        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.save_model()
+        if finetuning_args.include_effective_tokens_per_second:
+            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
+                dataset_module["train_dataset"], train_result.metrics, stage="sft"
+            )
+
+        trainer.log_metrics("train", train_result.metrics)
+        trainer.save_metrics("train", train_result.metrics)
+        trainer.save_state()
+        if trainer.is_world_process_zero() and finetuning_args.plot_loss:
+            plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
+
+    if training_args.predict_with_generate:
+        tokenizer.padding_side = "left"  # use left-padding in generation
+
+    # Evaluation
+    if training_args.do_eval:
+        metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
+        trainer.log_metrics("eval", metrics)
+        trainer.save_metrics("eval", metrics)
+
+    # Predict
+    if training_args.do_predict:
+        logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
+        predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
+        trainer.log_metrics("predict", predict_results.metrics)
+        trainer.save_metrics("predict", predict_results.metrics)
+        trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens)
+
+    # Create model card
+    create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py
index bbbef1cf54..a51926a5b4 100644
--- a/src/llamafactory/train/tuner.py
+++ b/src/llamafactory/train/tuner.py
@@ -26,6 +26,7 @@ from ..hparams import get_infer_args, get_ray_args, get_train_args, read_args
 from ..model import load_model, load_tokenizer
 from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback
+from .distil import run_distillation
 from .dpo import run_dpo
 from .kto import run_kto
 from .ppo import run_ppo
@@ -72,6 +73,8 @@ def _training_function(config: Dict[str, Any]) -> None:
         run_dpo(model_args, data_args, training_args, finetuning_args, callbacks)
     elif finetuning_args.stage == "kto":
         run_kto(model_args, data_args, training_args, finetuning_args, callbacks)
+    elif finetuning_args.stage == "distillation":
+        run_distillation(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
     else:
         raise ValueError(f"Unknown task: {finetuning_args.stage}.")
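
The core objective added by this patch is the mix of the usual label (cross-entropy) loss with a temperature-softened KL term against the frozen teacher, as implemented in CustomDistillationTrainer.compute_loss. Below is a minimal standalone sketch of that objective on toy tensors, using only plain PyTorch; the shapes and the names lam and temperature are illustrative stand-ins for distilling_lambda and distilling_temperature, not part of the patch.

# Minimal sketch of the distillation objective (toy shapes, illustrative names).
import torch
import torch.nn.functional as F

batch_size, seq_len, vocab_size = 2, 5, 11
lam, temperature = 0.5, 1.0  # stand-ins for distilling_lambda / distilling_temperature

student_logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
teacher_logits = torch.randn(batch_size, seq_len, vocab_size)  # frozen teacher, no grad needed
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, :2] = -100  # pretend the first two positions are prompt/padding tokens

padding_mask = labels.eq(-100)  # positions excluded from both loss terms

# Label (cross-entropy) loss, normally produced by the student's own forward pass.
label_loss = F.cross_entropy(student_logits.view(-1, vocab_size), labels.view(-1), ignore_index=-100)

# Token-level KL(teacher || student) on temperature-softened distributions.
teacher_prob = F.softmax(teacher_logits / temperature, dim=-1)
student_logprob = F.log_softmax(student_logits / temperature, dim=-1)
kl_per_token = (teacher_prob * (teacher_prob.log() - student_logprob)).sum(dim=-1)
kl_per_token = kl_per_token.masked_fill(padding_mask, 0.0)
kl_loss = kl_per_token.sum() / (~padding_mask).sum()  # average over non-padded tokens

loss = lam * kl_loss + label_loss
loss.backward()
print(f"label_loss={label_loss.item():.4f}  kl_loss={kl_loss.item():.4f}  total={loss.item():.4f}")

Note that classic knowledge-distillation recipes (Hinton et al.) often scale the KL term by temperature ** 2 when the temperature is greater than 1 to keep its gradient magnitude comparable to the label loss; the trainer in this patch leaves that balancing to distilling_lambda.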