[SFT] fix neftune_noise_alpha in SFTTrainer (#1841)
* fix neftune_noise_alpha

* del neftune_noise_alpha first

* check len after removing handle

* make sure we do not load twice

* Update trl/trainer/sft_trainer.py

Co-authored-by: lewtun <[email protected]>

* remove neftune from SFTTrainer as the superclass has it now

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: lewtun <[email protected]>
3 people authored Sep 19, 2024
1 parent 3cec013 commit 9fb871f
Showing 4 changed files with 8 additions and 86 deletions.
4 changes: 2 additions & 2 deletions tests/test_sft_trainer.py
@@ -828,7 +828,7 @@ def test_sft_trainer_with_model_neftune(self):
eval_dataset=self.eval_dataset,
)

-trainer.model = trainer._trl_activate_neftune(trainer.model)
+trainer.model = trainer._activate_neftune(trainer.model)

device = trainer.model.get_input_embeddings().weight.device
trainer.model.train()
@@ -992,7 +992,7 @@ def test_peft_sft_trainer_neftune(self):
peft_config=peft_config,
)

-trainer.model = trainer._trl_activate_neftune(trainer.model)
+trainer.model = trainer._activate_neftune(trainer.model)

assert isinstance(trainer.model, PeftModel)

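The updated tests call the `Trainer`-provided `_activate_neftune` directly and then compare embedding outputs in train and eval mode. A hedged, self-contained restatement of that check (assuming a `trainer` built as in the tests above, with `_activate_neftune` already applied to its model):

```python
import torch

embeddings = trainer.model.get_input_embeddings()
ids = torch.LongTensor([[1, 0, 1]]).to(embeddings.weight.device)

trainer.model.train()
# Noise is sampled on every forward pass, so two calls should differ.
assert not torch.allclose(embeddings(ids), embeddings(ids))

trainer.model.eval()
# Outside training mode the hook is a no-op, so outputs are deterministic.
assert torch.allclose(embeddings(ids), embeddings(ids))
```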
5 changes: 0 additions & 5 deletions trl/trainer/sft_config.py
@@ -40,10 +40,6 @@ class SFTConfig(TrainingArguments):
dataset_batch_size (`Union[int, None]`, *optional*, defaults to `1000`):
Number of examples to tokenize per batch. If `dataset_batch_size <= 0` or `dataset_batch_size is None`,
tokenizes the full dataset as a single batch.
-neftune_noise_alpha (`Optional[float]`, *optional*, defaults to `None`):
-    Scale of the noise for NEFTune embeddings. The [NEFTune paper](https://huggingface.co/papers/2310.05914)
-    suggests using values between `5` and `15`. If set to `None`, NEFTune is not activated. Activating NEFTune
-    can significantly improve model performance for instruction fine-tuning.
model_init_kwargs (`Optional[Dict[str, Any]]`, *optional*, defaults to `None`):
Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
string.
@@ -65,7 +61,6 @@ class SFTConfig(TrainingArguments):
max_seq_length: Optional[int] = None
dataset_num_proc: Optional[int] = None
dataset_batch_size: int = 1000
-neftune_noise_alpha: Optional[float] = None
model_init_kwargs: Optional[Dict[str, Any]] = None
dataset_kwargs: Optional[Dict[str, Any]] = None
eval_packing: Optional[bool] = None
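Dropping the duplicated field is safe because `SFTConfig` subclasses `transformers.TrainingArguments`, which already defines `neftune_noise_alpha`, so configs that set it keep working unchanged. A minimal, hedged sketch of the intended usage (model name and toy dataset are placeholders, not part of this commit):

```python
from datasets import Dataset
from trl import SFTConfig, SFTTrainer

# Tiny placeholder dataset with a "text" column.
train_dataset = Dataset.from_dict(
    {"text": ["NEFTune adds noise to embeddings.", "It can help instruction tuning."]}
)

config = SFTConfig(
    output_dir="sft-output",
    dataset_text_field="text",
    neftune_noise_alpha=5.0,  # resolved via TrainingArguments, no longer a TRL-specific field
)
trainer = SFTTrainer(
    "facebook/opt-350m",  # placeholder model name
    args=config,
    train_dataset=train_dataset,
)
trainer.train()  # NEFTune is activated and deactivated by the transformers Trainer around training
```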
56 changes: 6 additions & 50 deletions trl/trainer/sft_trainer.py
@@ -34,7 +34,6 @@
PreTrainedTokenizerBase,
Trainer,
)
-from transformers.modeling_utils import unwrap_model
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
from transformers.utils import is_peft_available
@@ -45,7 +44,6 @@
from .utils import (
ConstantLengthDataset,
DataCollatorForCompletionOnlyLM,
-neftune_post_forward_hook,
peft_module_casting_to_bf16,
trl_sanitze_kwargs_for_tagging,
)
@@ -154,6 +152,12 @@ def __init__(
args_as_dict.update({k: getattr(args, k) for k in args_as_dict.keys() if k.endswith("_token")})
args = SFTConfig(**args_as_dict)

+if neftune_noise_alpha is not None:
+    warnings.warn(
+        "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
+    )
+    args.neftune_noise_alpha = neftune_noise_alpha
+
if model_init_kwargs is not None:
warnings.warn(
"You passed `model_init_kwargs` to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
@@ -307,16 +311,6 @@ def make_inputs_require_grad(module, input, output):
args.dataset_batch_size = dataset_batch_size
self.dataset_batch_size = args.dataset_batch_size

-self._trainer_supports_neftune = hasattr(args, "neftune_noise_alpha")
-if neftune_noise_alpha is not None and self._trainer_supports_neftune:
-    args.neftune_noise_alpha = neftune_noise_alpha
-    warnings.warn(
-        "You passed a `neftune_noise_alpha` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
-    )
-    # self.neftune_noise_alpha is done at Trainer level
-elif not self._trainer_supports_neftune:
-    self.neftune_noise_alpha = neftune_noise_alpha

if dataset_text_field is not None:
warnings.warn(
"You passed a `dataset_text_field` argument to the SFTTrainer, the value you passed will override the one in the `SFTConfig`."
@@ -425,28 +419,6 @@ def make_inputs_require_grad(module, input, output):
elif self.args.max_steps == -1 and args.packing:
self.train_dataset.infinite = False

-@wraps(Trainer.train)
-def train(self, *args, **kwargs):
-    # Activate neftune right before training.
-    if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:
-        self.model = self._trl_activate_neftune(self.model)
-
-    output = super().train(*args, **kwargs)
-
-    # After training we make sure to retrieve back the original forward pass method
-    # for the embedding layer by removing the forward post hook.
-    if self.neftune_noise_alpha is not None and not self._trainer_supports_neftune:
-        unwrapped_model = unwrap_model(self.model)
-        if is_peft_available() and isinstance(unwrapped_model, PeftModel):
-            embeddings = unwrapped_model.base_model.model.get_input_embeddings()
-        else:
-            embeddings = unwrapped_model.get_input_embeddings()
-
-        self.neftune_hook_handle.remove()
-        del embeddings.neftune_noise_alpha
-
-    return output

@wraps(Trainer.push_to_hub)
def push_to_hub(
self,
@@ -639,19 +611,3 @@ def data_generator(constant_length_iterator):
raise ValueError(
"You need to pass a `dataset_text_field` or `formatting_func` argument to the SFTTrainer if you want to use the `ConstantLengthDataset`."
)

-def _trl_activate_neftune(self, model):
-    r"""
-    Activates the neftune as presented in this code: https://github.com/neelsjain/NEFTune and paper: https://huggingface.co/papers/2310.05914
-    Since in transformers Trainer we do have an `_activate_neftune` method, we need to rename this method to avoid conflicts.
-    """
-    unwrapped_model = unwrap_model(model)
-    if is_peft_available() and isinstance(unwrapped_model, PeftModel):
-        embeddings = unwrapped_model.base_model.model.get_input_embeddings()
-    else:
-        embeddings = unwrapped_model.get_input_embeddings()
-
-    embeddings.neftune_noise_alpha = self.neftune_noise_alpha
-    hook_handle = embeddings.register_forward_hook(neftune_post_forward_hook)
-    self.neftune_hook_handle = hook_handle
-    return model
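With the check relocated to the top of `__init__` (and the old capability guard removed), passing `neftune_noise_alpha` directly to `SFTTrainer` still works: it emits the warning and overrides the value stored in the config. A hedged sketch of that precedence, reusing the imports and placeholder `train_dataset` from the previous sketch:

```python
import warnings

config = SFTConfig(output_dir="sft-output", dataset_text_field="text", neftune_noise_alpha=5.0)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    trainer = SFTTrainer(
        "facebook/opt-350m",  # placeholder model name
        args=config,
        train_dataset=train_dataset,
        neftune_noise_alpha=10.0,  # constructor argument wins over the config value
    )

assert trainer.args.neftune_noise_alpha == 10.0
assert any("neftune_noise_alpha" in str(w.message) for w in caught)
```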
29 changes: 0 additions & 29 deletions trl/trainer/utils.py
@@ -826,35 +826,6 @@ def get_stats(self):
return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()}


-def neftune_post_forward_hook(module, input, output):
-    """
-    Implements the NEFTune forward pass for the model using forward hooks. Note this works only for
-    torch.nn.Embedding layers. This method is slightly adapted from the original source code
-    that can be found here: https://github.com/neelsjain/NEFTune
-    Simply add it to your model as follows:
-    ```python
-    model = ...
-    model.embed_tokens.neftune_noise_alpha = 0.1
-    model.embed_tokens.register_forward_hook(neftune_post_forward_hook)
-    ```
-    Args:
-        module (`torch.nn.Module`):
-            The embedding module where the hook is attached. Note that you need to set
-            `module.neftune_noise_alpha` to the desired noise alpha value.
-        input (`torch.Tensor`):
-            The input tensor to the model.
-        output (`torch.Tensor`):
-            The output tensor of the model (i.e. the embeddings).
-    """
-    if module.training:
-        dims = torch.tensor(output.size(1) * output.size(2))
-        mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
-        output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
-    return output


def peft_module_casting_to_bf16(model):
for name, module in model.named_modules():
if isinstance(module, torch.nn.LayerNorm) or "norm" in name:
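The helper can be deleted from `trl.trainer.utils` because the transformers `Trainer` now attaches an equivalent NEFTune hook itself. For reference, a self-contained sketch of what the removed hook does, applied to a bare embedding layer the way its old docstring suggested (the hook function here is a local copy for illustration, not an import):

```python
import torch
import torch.nn as nn


def neftune_like_hook(module, input, output):
    # Add uniform noise with magnitude neftune_noise_alpha / sqrt(seq_len * hidden_dim),
    # but only while the module is in training mode.
    if module.training:
        dims = torch.tensor(output.size(1) * output.size(2))
        mag_norm = module.neftune_noise_alpha / torch.sqrt(dims)
        output = output + torch.zeros_like(output).uniform_(-mag_norm, mag_norm)
    return output


embed = nn.Embedding(100, 16)
embed.neftune_noise_alpha = 5.0
handle = embed.register_forward_hook(neftune_like_hook)

ids = torch.arange(8).unsqueeze(0)  # shape (batch=1, seq_len=8)
embed.train()
noisy = embed(ids)  # hook injects fresh noise
embed.eval()
clean = embed(ids)  # hook runs but leaves the output unchanged
handle.remove()     # detach the hook, as the removed train() override used to do
```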
