From a5497b4dbc6eac33833e39ab582b9fd96425823c Mon Sep 17 00:00:00 2001 From: marko1616 Date: Sat, 4 Jan 2025 00:13:09 +0800 Subject: [PATCH 1/4] Basic distilling. --- src/llamafactory/hparams/finetuning_args.py | 24 ++- src/llamafactory/train/distilling/__init__.py | 18 ++ src/llamafactory/train/distilling/trainer.py | 177 ++++++++++++++++++ src/llamafactory/train/distilling/workflow.py | 140 ++++++++++++++ src/llamafactory/train/tuner.py | 3 + 5 files changed, 360 insertions(+), 2 deletions(-) create mode 100644 src/llamafactory/train/distilling/__init__.py create mode 100644 src/llamafactory/train/distilling/trainer.py create mode 100644 src/llamafactory/train/distilling/workflow.py diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 29e91a27c3..9247a685d2 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -126,6 +126,26 @@ class LoraArguments: ) +@dataclass +class DistillingArguments: + r""" + Arguments pertaining to the distilling training. + """ + + distilling_lambda: float = field( + default=0.5, + metadata={"help": "The lambda parameter in the distilling loss."}, + ) + teacher_model: Optional[str] = field( + default=None, + metadata={"help": "Path to the teacher model used for the distilling."}, + ) + teacher_model_adapters: Optional[str] = field( + default=None, + metadata={"help": "Path to the adapters of the teacher model."}, + ) + + @dataclass class RLHFArguments: r""" @@ -334,7 +354,7 @@ class SwanLabArguments: @dataclass class FinetuningArguments( - FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments + FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments, DistillingArguments ): r""" Arguments pertaining to which techniques we are going to fine-tuning with. @@ -344,7 +364,7 @@ class FinetuningArguments( default=False, metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."}, ) - stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field( + stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "distilling"] = field( default="sft", metadata={"help": "Which stage will be performed in training."}, ) diff --git a/src/llamafactory/train/distilling/__init__.py b/src/llamafactory/train/distilling/__init__.py new file mode 100644 index 0000000000..5cc170c3bf --- /dev/null +++ b/src/llamafactory/train/distilling/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .workflow import run_distilling + + +__all__ = ["distilling"] diff --git a/src/llamafactory/train/distilling/trainer.py b/src/llamafactory/train/distilling/trainer.py new file mode 100644 index 0000000000..45e6a73c73 --- /dev/null +++ b/src/llamafactory/train/distilling/trainer.py @@ -0,0 +1,177 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. 
+# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from types import MethodType +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import Seq2SeqTrainer +from typing_extensions import override + +from ...extras import logging +from ...extras.constants import IGNORE_INDEX +from ...extras.packages import is_transformers_version_greater_than +from ..callbacks import SaveProcessorCallback +from ..trainer_utils import create_custom_optimizer, create_custom_scheduler + + +if TYPE_CHECKING: + from torch.utils.data import Dataset + from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin + from transformers.trainer import PredictionOutput + + from ...hparams import FinetuningArguments + + +logger = logging.get_logger(__name__) + + +class CustomDistillingTrainer(Seq2SeqTrainer): + r""" + Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE. + """ + + def __init__( + self, + teacher_model: Union["PreTrainedModel", torch.nn.Module], + finetuning_args: "FinetuningArguments", + processor: Optional["ProcessorMixin"], + **kwargs, + ): + if is_transformers_version_greater_than("4.46"): + kwargs["processing_class"] = kwargs.pop("tokenizer") + else: + self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer") + + self.teacher_model = teacher_model + + super().__init__(**kwargs) + self.finetuning_args = finetuning_args + + if processor is not None: + self.add_callback(SaveProcessorCallback(processor)) + + if finetuning_args.use_badam: + from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) + self.add_callback(BAdamCallback) + + @override + def create_optimizer(self) -> "torch.optim.Optimizer": + if self.optimizer is None: + self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) + return super().create_optimizer() + + @override + def create_scheduler( + self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None + ) -> "torch.optim.lr_scheduler.LRScheduler": + create_custom_scheduler(self.args, num_training_steps, optimizer) + return super().create_scheduler(num_training_steps, optimizer) + + @override + def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: + if self.finetuning_args.disable_shuffling: + return torch.utils.data.SequentialSampler(self.train_dataset) + + return super()._get_train_sampler() + + @override + def compute_loss( + self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs + ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: + label_loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs) + 
with torch.no_grad(): + teacher_outputs = self.teacher_model(**inputs) + # Shape: (batch_size, seq_len, vocab_size) + teacher_prob = torch.nn.functional.softmax(teacher_outputs.logits, dim=-1) + student_logprob = torch.nn.functional.log_softmax(outputs.logits, dim=-1) + kl_loss = teacher_prob * (teacher_prob.log() - student_logprob) + loss = self.finetuning_args.distilling_lambda * kl_loss.mean() + label_loss + + if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False): + loss = loss / self.args.gradient_accumulation_steps + + return (loss, outputs) if return_outputs else loss + + @override + def prediction_step( + self, + model: "torch.nn.Module", + inputs: Dict[str, Union["torch.Tensor", Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + **gen_kwargs, + ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]: + r""" + Removes the prompt part in the generated tokens. + + Subclass and override to inject custom behavior. + """ + if self.args.predict_with_generate: # do not pass labels to model when generate + labels = inputs.pop("labels", None) + else: + labels = inputs.get("labels") + + loss, generated_tokens, _ = super().prediction_step( + model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs + ) + if generated_tokens is not None and self.args.predict_with_generate: + generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id + generated_tokens = generated_tokens.contiguous() + + return loss, generated_tokens, labels + + def save_predictions( + self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True + ) -> None: + r""" + Saves model predictions to `output_dir`. + + A custom behavior that not contained in Seq2SeqTrainer. + """ + if not self.is_world_process_zero(): + return + + output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") + logger.info_rank0(f"Saving prediction results to {output_prediction_file}") + + labels = np.where( + predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id + ) + preds = np.where( + predict_results.predictions != IGNORE_INDEX, + predict_results.predictions, + self.processing_class.pad_token_id, + ) + + for i in range(len(preds)): + pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0] + if len(pad_len): # move pad token to last + preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1) + + decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False) + decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens) + decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens) + + with open(output_prediction_file, "w", encoding="utf-8") as f: + for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels): + f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n") diff --git a/src/llamafactory/train/distilling/workflow.py b/src/llamafactory/train/distilling/workflow.py new file mode 100644 index 0000000000..8df5d27463 --- /dev/null +++ b/src/llamafactory/train/distilling/workflow.py @@ -0,0 +1,140 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. 
+# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, List, Optional + +from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer +from ...extras.constants import IGNORE_INDEX +from ...extras.logging import get_logger +from ...extras.misc import calculate_tps, get_logits_processor +from ...extras.ploting import plot_loss +from ...model import load_model, load_tokenizer +from ...hparams import ModelArguments, FinetuningArguments +from ..trainer_utils import create_modelcard_and_push +from .trainer import CustomDistillingTrainer + + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + + from ...hparams import DataArguments, GeneratingArguments + + +logger = get_logger(__name__) + + +def run_distilling( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + callbacks: Optional[List["TrainerCallback"]] = None, +): + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + template = get_template_and_fix_tokenizer(tokenizer, data_args) + dataset_module = get_dataset(template, model_args, data_args, training_args, stage="sft", **tokenizer_module) + model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train) + + # Load teacher model + # TODO teacher_model_quantization_bit + teacher_model_args = ModelArguments.copyfrom( + model_args, + model_name_or_path=finetuning_args.teacher_model, + adapter_name_or_path=finetuning_args.teacher_model_adapters, + ) + teacher_finetuning_args = FinetuningArguments() + teacher_model = load_model( + tokenizer, teacher_model_args, teacher_finetuning_args, is_trainable=False + ) + # Compare model and teacher tokenizer + teacher_tokenizer = load_tokenizer(teacher_model_args)["tokenizer"] + assert teacher_tokenizer.get_vocab() == tokenizer.get_vocab(), "The teacher's and student's tokenizers must have the same vocabulary dictionary." + + if getattr(model, "is_quantized", False) and not training_args.do_train: + setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction + + # TODO handling `prepare_decoder_input_ids_from_labels`. 
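+    # Note: the teacher reuses these exact batches inside compute_loss; the vocabulary check above
+    # ensures its logits line up with the student's token ids, so the regular SFT collator suffices.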
+ data_collator = SFTDataCollatorWith4DAttentionMask( + template=template, + model=model if not training_args.predict_with_generate else None, + pad_to_multiple_of=8 if training_args.do_train else None, # for shift short attention + label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id, + block_diag_attn=model_args.block_diag_attn, + attn_implementation=getattr(model.config, "_attn_implementation", None), + compute_dtype=model_args.compute_dtype, + **tokenizer_module, + ) + + # Override the decoding parameters of Seq2SeqTrainer + training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len + training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams + training_args.remove_unused_columns = False # important for multimodal dataset + + # Initialize our Trainer + trainer = CustomDistillingTrainer( + model=model, + teacher_model=teacher_model, + args=training_args, + finetuning_args=finetuning_args, + data_collator=data_collator, + callbacks=callbacks, + **dataset_module, + **tokenizer_module + ) + + # Keyword arguments for `model.generate` + gen_kwargs = generating_args.to_dict(obey_generation_config=True) + gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids + gen_kwargs["pad_token_id"] = tokenizer.pad_token_id + gen_kwargs["logits_processor"] = get_logits_processor() + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() + if finetuning_args.include_effective_tokens_per_second: + train_result.metrics["effective_tokens_per_sec"] = calculate_tps( + dataset_module["train_dataset"], train_result.metrics, stage="sft" + ) + + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + if trainer.is_world_process_zero() and finetuning_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"]) + + if training_args.predict_with_generate: + tokenizer.padding_side = "left" # use left-padding in generation + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Predict + if training_args.do_predict: + logger.warning_once("Batch generation can be very slow. 
Consider using `scripts/vllm_infer.py` instead.") + predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs) + trainer.log_metrics("predict", predict_results.metrics) + trainer.save_metrics("predict", predict_results.metrics) + trainer.save_predictions(dataset_module["eval_dataset"], predict_results, generating_args.skip_special_tokens) + + # Create model card + create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 6c79320e7c..554219d0d2 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -31,6 +31,7 @@ from .pt import run_pt from .rm import run_rm from .sft import run_sft +from .distilling import run_distilling from .trainer_utils import get_swanlab_callback @@ -65,6 +66,8 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb run_dpo(model_args, data_args, training_args, finetuning_args, callbacks) elif finetuning_args.stage == "kto": run_kto(model_args, data_args, training_args, finetuning_args, callbacks) + elif finetuning_args.stage == "distilling": + run_distilling(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) else: raise ValueError(f"Unknown task: {finetuning_args.stage}.") From cb65544cbae9dbad4a32ffcb9af75542492a7b8a Mon Sep 17 00:00:00 2001 From: marko1616 Date: Sat, 4 Jan 2025 00:25:47 +0800 Subject: [PATCH 2/4] Linter. --- src/llamafactory/hparams/finetuning_args.py | 12 +++++++++++- src/llamafactory/train/distilling/__init__.py | 2 +- src/llamafactory/train/distilling/trainer.py | 10 +++++++--- src/llamafactory/train/distilling/workflow.py | 12 ++++++------ src/llamafactory/train/tuner.py | 2 +- 5 files changed, 26 insertions(+), 12 deletions(-) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 9247a685d2..58f6be5d5a 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -136,6 +136,10 @@ class DistillingArguments: default=0.5, metadata={"help": "The lambda parameter in the distilling loss."}, ) + distilling_temperature: float = field( + default=1.0, + metadata={"help": "The temperature parameter in the distilling softmax."}, + ) teacher_model: Optional[str] = field( default=None, metadata={"help": "Path to the teacher model used for the distilling."}, @@ -354,7 +358,13 @@ class SwanLabArguments: @dataclass class FinetuningArguments( - FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument, SwanLabArguments, DistillingArguments + FreezeArguments, + LoraArguments, + RLHFArguments, + GaloreArguments, + BAdamArgument, + SwanLabArguments, + DistillingArguments, ): r""" Arguments pertaining to which techniques we are going to fine-tuning with. 
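A quick illustration of the new fields (not part of the diff; the values and path below are placeholders, and in practice they arrive through the YAML/CLI arguments parsed by get_train_args):

    from llamafactory.hparams import FinetuningArguments

    # hypothetical instantiation for illustration only; patch 4/4 later renames the stage to "distillation"
    finetuning_args = FinetuningArguments(
        stage="distilling",
        teacher_model="path/to/teacher-model",   # placeholder path
        distilling_lambda=0.5,                   # weight of the KL term added to the label loss
        distilling_temperature=2.0,              # >1 softens teacher/student distributions before the KL
    )
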
diff --git a/src/llamafactory/train/distilling/__init__.py b/src/llamafactory/train/distilling/__init__.py index 5cc170c3bf..ca95c88b54 100644 --- a/src/llamafactory/train/distilling/__init__.py +++ b/src/llamafactory/train/distilling/__init__.py @@ -15,4 +15,4 @@ from .workflow import run_distilling -__all__ = ["distilling"] +__all__ = ["run_distilling"] diff --git a/src/llamafactory/train/distilling/trainer.py b/src/llamafactory/train/distilling/trainer.py index 45e6a73c73..28920e2b56 100644 --- a/src/llamafactory/train/distilling/trainer.py +++ b/src/llamafactory/train/distilling/trainer.py @@ -102,11 +102,15 @@ def compute_loss( with torch.no_grad(): teacher_outputs = self.teacher_model(**inputs) # Shape: (batch_size, seq_len, vocab_size) - teacher_prob = torch.nn.functional.softmax(teacher_outputs.logits, dim=-1) - student_logprob = torch.nn.functional.log_softmax(outputs.logits, dim=-1) + teacher_prob = torch.nn.functional.softmax( + teacher_outputs.logits / self.finetuning_args.distilling_temperature, dim=-1 + ) + student_logprob = torch.nn.functional.log_softmax( + outputs.logits / self.finetuning_args.distilling_temperature, dim=-1 + ) kl_loss = teacher_prob * (teacher_prob.log() - student_logprob) loss = self.finetuning_args.distilling_lambda * kl_loss.mean() + label_loss - + if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False): loss = loss / self.args.gradient_accumulation_steps diff --git a/src/llamafactory/train/distilling/workflow.py b/src/llamafactory/train/distilling/workflow.py index 8df5d27463..e24dcb5dcb 100644 --- a/src/llamafactory/train/distilling/workflow.py +++ b/src/llamafactory/train/distilling/workflow.py @@ -22,8 +22,8 @@ from ...extras.logging import get_logger from ...extras.misc import calculate_tps, get_logits_processor from ...extras.ploting import plot_loss +from ...hparams import FinetuningArguments, ModelArguments from ...model import load_model, load_tokenizer -from ...hparams import ModelArguments, FinetuningArguments from ..trainer_utils import create_modelcard_and_push from .trainer import CustomDistillingTrainer @@ -59,12 +59,12 @@ def run_distilling( adapter_name_or_path=finetuning_args.teacher_model_adapters, ) teacher_finetuning_args = FinetuningArguments() - teacher_model = load_model( - tokenizer, teacher_model_args, teacher_finetuning_args, is_trainable=False - ) + teacher_model = load_model(tokenizer, teacher_model_args, teacher_finetuning_args, is_trainable=False) # Compare model and teacher tokenizer teacher_tokenizer = load_tokenizer(teacher_model_args)["tokenizer"] - assert teacher_tokenizer.get_vocab() == tokenizer.get_vocab(), "The teacher's and student's tokenizers must have the same vocabulary dictionary." + assert ( + teacher_tokenizer.get_vocab() == tokenizer.get_vocab() + ), "The teacher's and student's tokenizers must have the same vocabulary dictionary." 
if getattr(model, "is_quantized", False) and not training_args.do_train: setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction @@ -95,7 +95,7 @@ def run_distilling( data_collator=data_collator, callbacks=callbacks, **dataset_module, - **tokenizer_module + **tokenizer_module, ) # Keyword arguments for `model.generate` diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 554219d0d2..f5f0df7a0d 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -25,13 +25,13 @@ from ..hparams import get_infer_args, get_train_args from ..model import load_model, load_tokenizer from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback +from .distilling import run_distilling from .dpo import run_dpo from .kto import run_kto from .ppo import run_ppo from .pt import run_pt from .rm import run_rm from .sft import run_sft -from .distilling import run_distilling from .trainer_utils import get_swanlab_callback From fca671bff7f656ee13b9d2976942a9de642dd0d5 Mon Sep 17 00:00:00 2001 From: marko1616 Date: Sat, 4 Jan 2025 01:30:38 +0800 Subject: [PATCH 3/4] Label mask. --- src/llamafactory/train/distilling/trainer.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/train/distilling/trainer.py b/src/llamafactory/train/distilling/trainer.py index 28920e2b56..e67a26ce34 100644 --- a/src/llamafactory/train/distilling/trainer.py +++ b/src/llamafactory/train/distilling/trainer.py @@ -98,6 +98,8 @@ def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: def compute_loss( self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False, **kwargs ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: + labels = inputs.get("labels") + padding_mask = labels.eq(-100) label_loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs) with torch.no_grad(): teacher_outputs = self.teacher_model(**inputs) @@ -108,8 +110,15 @@ def compute_loss( student_logprob = torch.nn.functional.log_softmax( outputs.logits / self.finetuning_args.distilling_temperature, dim=-1 ) - kl_loss = teacher_prob * (teacher_prob.log() - student_logprob) - loss = self.finetuning_args.distilling_lambda * kl_loss.mean() + label_loss + kl_losses = (teacher_prob * (teacher_prob.log() - student_logprob)).sum(dim=-1) + kl_losses.masked_fill_(padding_mask, 0) + num_active_elements = padding_mask.numel() - padding_mask.long().sum() + loss = ( + self.finetuning_args.distilling_lambda + * kl_losses.mean() + / (num_active_elements * student_logprob.shape[-1]) + + label_loss + ) if kwargs.get("num_items_in_batch") and not getattr(self, "model_accepts_loss_kwargs", False): loss = loss / self.args.gradient_accumulation_steps From 779a0debfebb20445881281d5a3740abadad583c Mon Sep 17 00:00:00 2001 From: marko1616 Date: Sat, 11 Jan 2025 16:34:35 +0800 Subject: [PATCH 4/4] Remove predict. 
--- src/llamafactory/hparams/finetuning_args.py | 8 +- .../train/{distilling => distil}/__init__.py | 4 +- .../train/{distilling => distil}/trainer.py | 114 +++++++----------- .../train/{distilling => distil}/workflow.py | 6 +- src/llamafactory/train/tuner.py | 6 +- 5 files changed, 53 insertions(+), 85 deletions(-) rename src/llamafactory/train/{distilling => distil}/__init__.py (89%) rename src/llamafactory/train/{distilling => distil}/trainer.py (58%) rename src/llamafactory/train/{distilling => distil}/workflow.py (98%) diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py index 58f6be5d5a..ebfddebe5f 100644 --- a/src/llamafactory/hparams/finetuning_args.py +++ b/src/llamafactory/hparams/finetuning_args.py @@ -127,9 +127,9 @@ class LoraArguments: @dataclass -class DistillingArguments: +class DistillationArguments: r""" - Arguments pertaining to the distilling training. + Arguments pertaining to the distillation training. """ distilling_lambda: float = field( @@ -364,7 +364,7 @@ class FinetuningArguments( GaloreArguments, BAdamArgument, SwanLabArguments, - DistillingArguments, + DistillationArguments, ): r""" Arguments pertaining to which techniques we are going to fine-tuning with. @@ -374,7 +374,7 @@ class FinetuningArguments( default=False, metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."}, ) - stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "distilling"] = field( + stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto", "distillation"] = field( default="sft", metadata={"help": "Which stage will be performed in training."}, ) diff --git a/src/llamafactory/train/distilling/__init__.py b/src/llamafactory/train/distil/__init__.py similarity index 89% rename from src/llamafactory/train/distilling/__init__.py rename to src/llamafactory/train/distil/__init__.py index ca95c88b54..53311fcdc2 100644 --- a/src/llamafactory/train/distilling/__init__.py +++ b/src/llamafactory/train/distil/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .workflow import run_distilling +from .workflow import run_distillation -__all__ = ["run_distilling"] +__all__ = ["run_distillation"] diff --git a/src/llamafactory/train/distilling/trainer.py b/src/llamafactory/train/distil/trainer.py similarity index 58% rename from src/llamafactory/train/distilling/trainer.py rename to src/llamafactory/train/distil/trainer.py index e67a26ce34..7c93f90bd5 100644 --- a/src/llamafactory/train/distilling/trainer.py +++ b/src/llamafactory/train/distil/trainer.py @@ -15,27 +15,22 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import json -import os +from copy import deepcopy from types import MethodType -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union -import numpy as np import torch from transformers import Seq2SeqTrainer from typing_extensions import override from ...extras import logging -from ...extras.constants import IGNORE_INDEX from ...extras.packages import is_transformers_version_greater_than from ..callbacks import SaveProcessorCallback from ..trainer_utils import create_custom_optimizer, create_custom_scheduler if TYPE_CHECKING: - from torch.utils.data import Dataset from transformers import PreTrainedModel, PreTrainedTokenizer, ProcessorMixin - from transformers.trainer import PredictionOutput from ...hparams import FinetuningArguments @@ -43,7 +38,7 @@ logger = logging.get_logger(__name__) -class CustomDistillingTrainer(Seq2SeqTrainer): +class CustomDistillationTrainer(Seq2SeqTrainer): r""" Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE. """ @@ -60,9 +55,17 @@ def __init__( else: self.processing_class: "PreTrainedTokenizer" = kwargs.get("tokenizer") - self.teacher_model = teacher_model - super().__init__(**kwargs) + + if teacher_model is not None: + if self.is_deepspeed_enabled: + if not ( + getattr(teacher_model, "is_loaded_in_8bit", False) + or getattr(teacher_model, "is_loaded_in_4bit", False) + ): # quantized models are already set on the correct device + self.teacher_model = self._prepare_deepspeed(teacher_model) + else: + self.teacher_model = self.accelerator.prepare_model(teacher_model, evaluation_mode=True) self.finetuning_args = finetuning_args if processor is not None: @@ -101,6 +104,7 @@ def compute_loss( labels = inputs.get("labels") padding_mask = labels.eq(-100) label_loss, outputs = super().compute_loss(model, inputs, return_outputs=True, **kwargs) + self.teacher_model.eval() with torch.no_grad(): teacher_outputs = self.teacher_model(**inputs) # Shape: (batch_size, seq_len, vocab_size) @@ -125,66 +129,30 @@ def compute_loss( return (loss, outputs) if return_outputs else loss - @override - def prediction_step( - self, - model: "torch.nn.Module", - inputs: Dict[str, Union["torch.Tensor", Any]], - prediction_loss_only: bool, - ignore_keys: Optional[List[str]] = None, - **gen_kwargs, - ) -> Tuple[Optional[float], Optional["torch.Tensor"], Optional["torch.Tensor"]]: - r""" - Removes the prompt part in the generated tokens. - - Subclass and override to inject custom behavior. - """ - if self.args.predict_with_generate: # do not pass labels to model when generate - labels = inputs.pop("labels", None) - else: - labels = inputs.get("labels") - - loss, generated_tokens, _ = super().prediction_step( - model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys, **gen_kwargs - ) - if generated_tokens is not None and self.args.predict_with_generate: - generated_tokens[:, : inputs["input_ids"].size(-1)] = self.processing_class.pad_token_id - generated_tokens = generated_tokens.contiguous() - - return loss, generated_tokens, labels - - def save_predictions( - self, dataset: "Dataset", predict_results: "PredictionOutput", skip_special_tokens: bool = True - ) -> None: - r""" - Saves model predictions to `output_dir`. - - A custom behavior that not contained in Seq2SeqTrainer. 
- """ - if not self.is_world_process_zero(): - return - - output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") - logger.info_rank0(f"Saving prediction results to {output_prediction_file}") - - labels = np.where( - predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.processing_class.pad_token_id - ) - preds = np.where( - predict_results.predictions != IGNORE_INDEX, - predict_results.predictions, - self.processing_class.pad_token_id, - ) - - for i in range(len(preds)): - pad_len = np.nonzero(preds[i] != self.processing_class.pad_token_id)[0] - if len(pad_len): # move pad token to last - preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1) - - decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False) - decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=skip_special_tokens) - decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=skip_special_tokens) - - with open(output_prediction_file, "w", encoding="utf-8") as f: - for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels): - f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n") + def _prepare_deepspeed(self, model: "PreTrainedModel"): + import deepspeed # type: ignore + + deepspeed_plugin = self.accelerator.state.deepspeed_plugin + config_kwargs = deepcopy(deepspeed_plugin.deepspeed_config) + + if model is not None: + if hasattr(model, "config"): + hidden_size = ( + max(model.config.hidden_sizes) + if getattr(model.config, "hidden_sizes", None) + else getattr(model.config, "hidden_size", None) + ) + if hidden_size is not None and config_kwargs["zero_optimization"]["stage"] == 3: + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + } + ) + + if config_kwargs["zero_optimization"]["stage"] != 3: + config_kwargs["zero_optimization"]["stage"] = 0 + model, *_ = deepspeed.initialize(model=model, config=config_kwargs) + model.eval() + return model diff --git a/src/llamafactory/train/distilling/workflow.py b/src/llamafactory/train/distil/workflow.py similarity index 98% rename from src/llamafactory/train/distilling/workflow.py rename to src/llamafactory/train/distil/workflow.py index e24dcb5dcb..cb64865852 100644 --- a/src/llamafactory/train/distilling/workflow.py +++ b/src/llamafactory/train/distil/workflow.py @@ -25,7 +25,7 @@ from ...hparams import FinetuningArguments, ModelArguments from ...model import load_model, load_tokenizer from ..trainer_utils import create_modelcard_and_push -from .trainer import CustomDistillingTrainer +from .trainer import CustomDistillationTrainer if TYPE_CHECKING: @@ -37,7 +37,7 @@ logger = get_logger(__name__) -def run_distilling( +def run_distillation( model_args: "ModelArguments", data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments", @@ -87,7 +87,7 @@ def run_distilling( training_args.remove_unused_columns = False # important for multimodal dataset # Initialize our Trainer - trainer = CustomDistillingTrainer( + trainer = CustomDistillationTrainer( model=model, teacher_model=teacher_model, args=training_args, diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index f5f0df7a0d..527239d544 100644 --- 
a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -25,7 +25,7 @@ from ..hparams import get_infer_args, get_train_args from ..model import load_model, load_tokenizer from .callbacks import LogCallback, PissaConvertCallback, ReporterCallback -from .distilling import run_distilling +from .distil import run_distillation from .dpo import run_dpo from .kto import run_kto from .ppo import run_ppo @@ -66,8 +66,8 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallb run_dpo(model_args, data_args, training_args, finetuning_args, callbacks) elif finetuning_args.stage == "kto": run_kto(model_args, data_args, training_args, finetuning_args, callbacks) - elif finetuning_args.stage == "distilling": - run_distilling(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) + elif finetuning_args.stage == "distillation": + run_distillation(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) else: raise ValueError(f"Unknown task: {finetuning_args.stage}.")
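
For reference, a self-contained sketch of the objective that the final CustomDistillationTrainer.compute_loss assembles: a temperature-scaled forward KL per token, zeroed wherever the label is IGNORE_INDEX (-100), rescaled as in the last revision of the patch, and mixed into the ordinary cross-entropy label loss with distilling_lambda. Function and tensor names here are illustrative, not part of the codebase:

    import torch
    import torch.nn.functional as F

    def distillation_loss(
        student_logits: torch.Tensor,    # (batch, seq_len, vocab) from the student forward pass
        teacher_logits: torch.Tensor,    # (batch, seq_len, vocab) computed under torch.no_grad()
        labels: torch.Tensor,            # (batch, seq_len); -100 marks prompt/padding positions
        label_loss: torch.Tensor,        # scalar cross-entropy already returned by compute_loss
        distilling_lambda: float = 0.5,
        distilling_temperature: float = 1.0,
    ) -> torch.Tensor:
        padding_mask = labels.eq(-100)
        teacher_prob = F.softmax(teacher_logits / distilling_temperature, dim=-1)
        student_logprob = F.log_softmax(student_logits / distilling_temperature, dim=-1)

        # forward KL(teacher || student), summed over the vocabulary -> one value per token position
        kl_losses = (teacher_prob * (teacher_prob.log() - student_logprob)).sum(dim=-1)
        kl_losses = kl_losses.masked_fill(padding_mask, 0.0)

        # the final patch averages over all positions, then rescales by the count of supervised
        # tokens and the vocabulary size before weighting the term with distilling_lambda
        num_active = padding_mask.numel() - padding_mask.long().sum()
        kl_term = kl_losses.mean() / (num_active * student_logprob.shape[-1])
        return distilling_lambda * kl_term + label_loss

With distilling_lambda = 0 this reduces to plain supervised fine-tuning; larger temperatures flatten both distributions, which emphasizes the teacher's relative ranking of tokens over its top-1 choice.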