Enable gradient checkpointing to be disabled for reward modelling (#725)
* Enable gradient checkpointing to be disabled for reward modelling

* Update examples/scripts/reward_trainer.py

Co-authored-by: Leandro von Werra <[email protected]>

* Apply suggestions from code review

Co-authored-by: Younes Belkada <[email protected]>

* Tidy docs

* Remove commas

---------

Co-authored-by: Leandro von Werra <[email protected]>
Co-authored-by: Younes Belkada <[email protected]>
3 people authored Sep 6, 2023
1 parent decc832 commit 453c4ec
Showing 6 changed files with 36 additions and 28 deletions.
docs/source/reward_trainer.mdx (14 changes: 6 additions & 8 deletions)
@@ -6,29 +6,27 @@ Check out a complete flexible example inside [`examples/scripts`](https://github

## Expected dataset format

-The reward trainer expects a very specific format for the dataset. Since the model will be trained to predict which sentence is the most relevant, given two sentences. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:
+The [`RewardTrainer`] expects a very specific format for the dataset since the model will be trained on pairs of examples to predict which of the two is preferred. We provide an example from the [`Anthropic/hh-rlhf`](https://huggingface.co/datasets/Anthropic/hh-rlhf) dataset below:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/trl-internal-testing/example-images/resolve/main/images/rlhf-antropic-example.png", width="50%">
</div>

-Therefore the final dataset object should contain two 4 entries at least if you use the default `RewardDataCollatorWithPadding` data collator. The entries should be named:
+Therefore the final dataset object should contain two 4 entries at least if you use the default [`RewardDataCollatorWithPadding`] data collator. The entries should be named:

- `input_ids_chosen`
- `attention_mask_chosen`
- `input_ids_rejected`
- `attention_mask_rejected`

-The `j` and `k` suffixes are used to denote the two sentences in the paired dataset.

## Using the `RewardTrainer`

-After standardizing your dataset, you can use the `RewardTrainer` as a classic Hugging Face Trainer.
-You should pass an `AutoModelForSequenceClassification` model to the `RewardTrainer`.
+After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers.
+You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training.

-### Leveraging the `peft` library to train a reward model
+### Leveraging 🤗 PEFT to train a reward model

-Just pass a `peft_config` in the key word arguments of `RewardTrainer`, and the trainer should automatically take care of converting the model into a PEFT model!
+Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model!

```python
from peft import LoraConfig, task_type
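For orientation, here is a minimal end-to-end sketch of the usage the updated docs describe (illustrative only, not part of this commit; the model name, output directory, and batch size are placeholders):

```python
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer

model_name = "facebook/opt-350m"  # placeholder; any sequence-classification-capable model
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
tokenizer = AutoTokenizer.from_pretrained(model_name)


# Build the four columns the default RewardDataCollatorWithPadding expects.
def preprocess(examples):
    chosen = tokenizer(examples["chosen"], truncation=True)
    rejected = tokenizer(examples["rejected"], truncation=True)
    return {
        "input_ids_chosen": chosen["input_ids"],
        "attention_mask_chosen": chosen["attention_mask"],
        "input_ids_rejected": rejected["input_ids"],
        "attention_mask_rejected": rejected["attention_mask"],
    }


train_dataset = load_dataset("Anthropic/hh-rlhf", split="train").map(preprocess, batched=True)

# RewardConfig extends transformers.TrainingArguments with reward-specific options.
training_args = RewardConfig(
    output_dir="reward_model",       # placeholder output directory
    per_device_train_batch_size=8,   # placeholder hyperparameter
    gradient_checkpointing=True,     # after this commit, this can also be set to False
    max_length=512,
    remove_unused_columns=False,
)

trainer = RewardTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)
trainer.train()
```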
examples/scripts/reward_trainer.py (31 changes: 15 additions & 16 deletions)
@@ -15,6 +15,7 @@
from dataclasses import dataclass, field
from typing import Optional

+from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
@@ -26,15 +27,14 @@
tqdm.pandas()


-# Define and parse arguments.
@dataclass
class ScriptArguments:
"""
-The name of the Casual LM model we wish to fine with RewardTrainer
+Hyperparameters to fine-tune a reward model on a given dataset with the `RewardTrainer`.
"""

model_name: Optional[str] = field(default="facebook/opt-350m", metadata={"help": "the model name"})
-dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the model name"})
+dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"})
dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"})
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
logging_steps: Optional[int] = field(default=500, metadata={"help": "the number of update steps between two logs"})
@@ -48,6 +48,7 @@ class ScriptArguments:
gradient_accumulation_steps: Optional[int] = field(
default=16, metadata={"help": "the number of gradient accumulation steps"}
)
+gradient_checkpointing: Optional[bool] = field(default=True, metadata={"help": "Enable gradient checkpointing"})
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
use_peft: Optional[bool] = field(default=False, metadata={"help": "Wether to use PEFT or not to train adapters"})
@@ -65,8 +66,8 @@ class ScriptArguments:
quantization_config = BitsAndBytesConfig(
load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
)
-# This means: fit the entire model on the GPU:0
-device_map = {"": 0}
+# Copy the model to each device
+device_map = {"": Accelerator().local_process_index}
else:
device_map = None
quantization_config = None
@@ -84,11 +85,8 @@ class ScriptArguments:
train_dataset = load_dataset(script_args.dataset_name, split="train")


-# Turn the dataset into pairs of post + summaries, where text_j is the preferred question + answer and text_k is the other.
-# Then tokenize the dataset.
+# Tokenize chosen/rejected pairs of inputs
# Adapt this section to your needs for custom datasets


def preprocess_function(examples):
new_examples = {
"input_ids_chosen": [],
@@ -97,18 +95,18 @@ def preprocess_function(examples):
"attention_mask_rejected": [],
}
for chosen, rejected in zip(examples["chosen"], examples["rejected"]):
-tokenized_j = tokenizer(chosen, truncation=True)
-tokenized_k = tokenizer(rejected, truncation=True)
+tokenized_chosen = tokenizer(chosen, truncation=True)
+tokenized_rejected = tokenizer(rejected, truncation=True)

-new_examples["input_ids_chosen"].append(tokenized_j["input_ids"])
-new_examples["attention_mask_chosen"].append(tokenized_j["attention_mask"])
-new_examples["input_ids_rejected"].append(tokenized_k["input_ids"])
-new_examples["attention_mask_rejected"].append(tokenized_k["attention_mask"])
+new_examples["input_ids_chosen"].append(tokenized_chosen["input_ids"])
+new_examples["attention_mask_chosen"].append(tokenized_chosen["attention_mask"])
+new_examples["input_ids_rejected"].append(tokenized_rejected["input_ids"])
+new_examples["attention_mask_rejected"].append(tokenized_rejected["attention_mask"])

return new_examples


-# preprocess the dataset and filter out QAs that are longer than script_args.max_length
+# Preprocess the dataset and filter out examples that are longer than script_args.max_length
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
@@ -141,6 +139,7 @@ def preprocess_function(examples):
per_device_train_batch_size=script_args.batch_size,
num_train_epochs=script_args.num_train_epochs,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
+gradient_checkpointing=script_args.gradient_checkpointing,
learning_rate=script_args.learning_rate,
report_to="wandb" if script_args.log_with == "wandb" else "tensorboard",
remove_unused_columns=False,
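Outside the diff, here is a minimal sketch of how the new script flag is parsed, assuming the script feeds `ScriptArguments` through `HfArgumentParser` as these examples usually do; running `python examples/scripts/reward_trainer.py --gradient_checkpointing False` would then disable checkpointing in the `RewardConfig` above:

```python
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class ScriptArguments:
    # Mirrors the field added in this commit; checkpointing stays on by default.
    gradient_checkpointing: Optional[bool] = field(
        default=True, metadata={"help": "Enable gradient checkpointing"}
    )


# Passing `--gradient_checkpointing False` on the command line flips this to False.
script_args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
print(script_args.gradient_checkpointing)
```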
examples/scripts/sentiment_tuning.py (4 changes: 3 additions & 1 deletion)
@@ -16,6 +16,7 @@
from typing import Optional

import torch
+from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
@@ -149,7 +150,8 @@ def collator(data):
task_type="CAUSAL_LM",
)
ref_model = None
device_map = {"": 0}
# Copy the model to each device
device_map = {"": Accelerator().local_process_index}

model = trl_model_class.from_pretrained(
config.model_name,
examples/scripts/sft_trainer.py (5 changes: 3 additions & 2 deletions)
@@ -16,6 +16,7 @@
from typing import Optional

import torch
+from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
@@ -75,8 +76,8 @@ class ScriptArguments:
quantization_config = BitsAndBytesConfig(
load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
)
-# This means: fit the entire model on the GPU:0
-device_map = {"": 0}
+# Copy the model to each device
+device_map = {"": Accelerator().local_process_index}
torch_dtype = torch.bfloat16
else:
device_map = None
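The `device_map` change, repeated across the three example scripts, replaces a hard-coded GPU 0 with the per-process device index so that each process launched by `accelerate` loads its own full copy of the model. A rough sketch of the idea (the model name is a placeholder):

```python
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

# Each process started by `accelerate launch` has its own local_process_index (0, 1, ...).
# Mapping the empty module name "" to that index places the whole model on that process's
# GPU, instead of every process loading onto GPU 0 as the old {"": 0} mapping did.
device_map = {"": Accelerator().local_process_index}
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", device_map=device_map)
```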
trl/trainer/reward_trainer.py (2 changes: 1 addition & 1 deletion)
@@ -113,7 +113,7 @@ def __init__(
)
elif is_peft_available() and peft_config is not None:
if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_quantized", False):
-model = prepare_model_for_int8_training(model)
+model = prepare_model_for_int8_training(model, use_gradient_checkpointing=args.gradient_checkpointing)

model = get_peft_model(model, peft_config)

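For context, a simplified sketch of the code path this one-line trainer change sits in (not the trainer's exact code; the base model, LoRA settings, and 8-bit loading are illustrative assumptions):

```python
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "facebook/opt-350m", num_labels=1, load_in_8bit=True
)

# Previously the trainer always called prepare_model_for_int8_training(model), whose
# default enables gradient checkpointing. It now forwards args.gradient_checkpointing,
# so checkpointing can be switched off for quantized PEFT reward models too.
model = prepare_model_for_int8_training(model, use_gradient_checkpointing=False)
model = get_peft_model(model, LoraConfig(task_type="SEQ_CLS", r=16, lora_alpha=16))
```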
trl/trainer/training_configs.py (8 changes: 8 additions & 0 deletions)
@@ -31,6 +31,8 @@ class RewardConfig(TrainingArguments):
Parameters:
max_length (`int`, *optional*, defaults to `None`):
The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator.
+gradient_checkpointing (`bool`, *optional*, defaults to `True`):
+If True, use gradient checkpointing to save memory at the expense of slower backward pass.
"""

max_length: Optional[int] = field(
@@ -39,3 +41,9 @@ class RewardConfig(TrainingArguments):
"help": "The maximum length of the sequences in the batch. This argument is required if you want to use the default data collator."
},
)
+gradient_checkpointing: Optional[bool] = field(
+default=True,
+metadata={
+"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
+},
+)
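With the new field in place, the behaviour can also be set directly when building the config (a small sketch; the `output_dir` and `max_length` values are placeholders, and `RewardConfig` is assumed to be importable from `trl` as the example script uses it):

```python
from trl import RewardConfig

# Trade extra memory for a faster backward pass by turning checkpointing off.
config = RewardConfig(output_dir="reward_model", max_length=512, gradient_checkpointing=False)
```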
