Skip to content

Commit

Permalink
Use uniform config (#817)
Browse files Browse the repository at this point in the history
* Use uniform config

* quick fix

* refactor

* update docs
  • Loading branch information
vwxyzjn authored Oct 9, 2023
1 parent eda1f36 commit 95aea7c
Show file tree
Hide file tree
Showing 7 changed files with 232 additions and 304 deletions.
33 changes: 4 additions & 29 deletions docs/source/ddpo_trainer.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -37,38 +37,13 @@ Almost every configuration parameter has a default. There is only one commandlin
python stable_diffusion_tuning.py --hf_user_access_token <token>
```

Again, the script uses a small subset of parameters to configure the trainer. And all of these are configurable via the commandline.
It should be noted (in general) that because the trainer uses `accelerate` as a core component, some parameters are those of accelerate's.
The commandline flags that are associated with the example script's parameters are listed below.

|parameter|description|default|
| ---- | ---- | ---- |
|`--hf_hub_aesthetic_model_id`|The HuggingFace model hub id of the aesthetic scorer model|`"trl-lib/ddpo-aesthetic-predictor"`|
|`--hf_hub_aesthetic_model_filename`|The filename of the aesthetic scorer model |`"aesthetic-model.pth"`|
|`--pretrained_model`|The string id of the pretrained Stable Diffusion model|`"runwayml/stable-diffusion-v1-5"`|
|`--pretrained_revision`|The revision of the pretrained Stable Diffusion model|`"main"`|
|`--num_epochs`|The number of epochs to train for|`200`|
|`--train_batch_size`|The batch size to use for training|`3`|
|`--sample_batch_size`|The batch size to use for sampling|`6`|
|`--gradient_accumulation_steps`|The number of accelerator based gradient accumulation steps to use|`1`|
|`--sample_num_steps`| The number of steps to sample for|`50`|
|`--sample_num_batches_per_epoch`|The number of batches to sample per epoch|`4`|
|`--log_with`|The logger to use. Either `wandb` or `tensorboard`|`wandb`|
|`--per_prompt_stat_tracking`|Whether to track stats per prompt. If false, advantages will be calculated using the mean and std of the entire batch as opposed to tracking per prompt|`True`|
|`--per_prompt_stat_tracking_buffer_size`|The size of the buffer to use for tracking stats per prompt|`32`|
|`--tracker_project_name`|The name of the project for use on the tracking platform (wandb/tensorboard/etc) |`"stable_diffusion_training"`|
| `--logging_dir`|The directory to use for logging|`"logs"`|
| `--project_dir`|The directory to use for saving the model|`"save"`|
| `--automatic_checkpoint_naming`|Whether to automatically name model checkpoints|`True`|
| `--total_limit`| Number of checkpoints to keep before overwriting old ones|`5`|
| `--hf_hub_model_id`|The HuggingFace model hub id to use for saving the model|`"ddpo-finetuned-sd-model"`|
| `--hf_user_access_token`| The HuggingFace user access token|`None`|
To obtain the documentation of `stable_diffusion_tuning.py`, please run `python stable_diffusion_tuning.py --help`

The following are things to keep in mind (The code checks this for you as well) in general while configuring the trainer (beyond the use case of using the example script)

- The configurable sample batch size should be greater than or equal to the configurable training batch size
- The configurable sample batch size must be divisible by the configurable train batch size
- The configurable sample batch size must be divisible by both the configurable gradient accumulation steps and the configurable accelerator processes count
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) should be greater than or equal to the configurable training batch size (`--ddpo_config.train_batch_size=3`)
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by the configurable train batch size (`--ddpo_config.train_batch_size=3`)
- The configurable sample batch size (`--ddpo_config.sample_batch_size=6`) must be divisible by both the configurable gradient accumulation steps (`--ddpo_config.train_gradient_accumulation_steps=1`) and the configurable accelerator processes count

## Setting up the image logging hook function

Expand Down
127 changes: 66 additions & 61 deletions examples/scripts/reward_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from dataclasses import dataclass, field
from typing import Optional

import tyro
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig, HfArgumentParser
from transformers import AutoModelForSequenceClassification, AutoTokenizer, BitsAndBytesConfig

from trl import RewardConfig, RewardTrainer

Expand All @@ -29,60 +30,80 @@

@dataclass
class ScriptArguments:
"""
Hyperparameters to fine-tune a reward model on a given dataset with the `RewardTrainer`.
"""

model_name: Optional[str] = field(default="facebook/opt-350m", metadata={"help": "the model name"})
dataset_name: Optional[str] = field(default="Anthropic/hh-rlhf", metadata={"help": "the dataset name"})
dataset_text_field: Optional[str] = field(default="text", metadata={"help": "the text field of the dataset"})
log_with: Optional[str] = field(default=None, metadata={"help": "use 'wandb' to log with wandb"})
logging_steps: Optional[int] = field(default=500, metadata={"help": "the number of update steps between two logs"})
eval_split: Optional[str] = field(
default="none", metadata={"help": "the dataset split to evaluate on; default to 'none' (no evaluation)"}
model_name: str = "facebook/opt-350m"
"""the model name"""
dataset_name: str = "Anthropic/hh-rlhf"
"""the dataset name"""
dataset_text_field: str = "text"
"""the text field of the dataset"""
eval_split: str = "none"
"""the dataset split to evaluate on; default to 'none' (no evaluation)"""
load_in_8bit: bool = False
"""load the model in 8 bits precision"""
load_in_4bit: bool = False
"""load the model in 4 bits precision"""
trust_remote_code: bool = True
"""Enable `trust_remote_code`"""
reward_config: RewardConfig = field(
default_factory=lambda: RewardConfig(
output_dir="output",
per_device_train_batch_size=64,
num_train_epochs=1,
gradient_accumulation_steps=16,
gradient_checkpointing=True,
learning_rate=1.41e-5,
report_to="tensorboard",
remove_unused_columns=False,
optim="adamw_torch",
logging_steps=500,
evaluation_strategy="no",
max_length=512,
)
)
learning_rate: Optional[float] = field(default=1.41e-5, metadata={"help": "the learning rate"})
batch_size: Optional[int] = field(default=64, metadata={"help": "the batch size"})
num_train_epochs: Optional[int] = field(default=1, metadata={"help": "the number of training epochs"})
seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"})
gradient_accumulation_steps: Optional[int] = field(
default=16, metadata={"help": "the number of gradient accumulation steps"}
use_peft: bool = False
"""whether to use peft"""
peft_config: Optional[LoraConfig] = field(
default_factory=lambda: LoraConfig(
r=16,
lora_alpha=16,
bias="none",
task_type="CAUSAL_LM",
task_type="SEQ_CLS",
modules_to_save=["scores"],
),
)
gradient_checkpointing: Optional[bool] = field(default=True, metadata={"help": "Enable gradient checkpointing"})
load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
use_peft: Optional[bool] = field(default=False, metadata={"help": "Wether to use PEFT or not to train adapters"})
trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
load_in_8bit: bool = False
"""load the model in 8 bits precision"""
load_in_4bit: bool = False
"""load the model in 4 bits precision"""


parser = HfArgumentParser(ScriptArguments)
script_args = parser.parse_args_into_dataclasses()[0]
args = tyro.cli(ScriptArguments)
args.reward_config.evaluation_strategy = "steps" if args.eval_split != "none" else "no"


# Step 1: Load the model
if script_args.load_in_8bit and script_args.load_in_4bit:
if args.load_in_8bit and args.load_in_4bit:
raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif script_args.load_in_8bit or script_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
)
elif args.load_in_8bit or args.load_in_4bit:
quantization_config = BitsAndBytesConfig(load_in_8bit=args.load_in_8bit, load_in_4bit=args.load_in_4bit)
# Copy the model to each device
device_map = {"": Accelerator().local_process_index}
else:
device_map = None
quantization_config = None

model = AutoModelForSequenceClassification.from_pretrained(
script_args.model_name,
args.model_name,
quantization_config=quantization_config,
device_map=device_map,
trust_remote_code=script_args.trust_remote_code,
trust_remote_code=args.trust_remote_code,
num_labels=1,
)

# Step 2: Load the dataset and pre-process it
tokenizer = AutoTokenizer.from_pretrained(script_args.model_name)
train_dataset = load_dataset(script_args.dataset_name, split="train")
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
train_dataset = load_dataset(args.dataset_name, split="train")


# Tokenize chosen/rejected pairs of inputs
Expand All @@ -106,60 +127,44 @@ def preprocess_function(examples):
return new_examples


# Preprocess the dataset and filter out examples that are longer than script_args.max_length
# Preprocess the dataset and filter out examples that are longer than args.max_length
train_dataset = train_dataset.map(
preprocess_function,
batched=True,
num_proc=4,
)
train_dataset = train_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= script_args.seq_length
and len(x["input_ids_rejected"]) <= script_args.seq_length
lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length
and len(x["input_ids_rejected"]) <= args.reward_config.max_length
)

if script_args.eval_split == "none":
if args.eval_split == "none":
eval_dataset = None
else:
eval_dataset = load_dataset(script_args.dataset_name, split=script_args.eval_split)
eval_dataset = load_dataset(args.dataset_name, split=args.eval_split)

eval_dataset = eval_dataset.map(
preprocess_function,
batched=True,
num_proc=4,
)
eval_dataset = eval_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= script_args.seq_length
and len(x["input_ids_rejected"]) <= script_args.seq_length
lambda x: len(x["input_ids_chosen"]) <= args.reward_config.max_length
and len(x["input_ids_rejected"]) <= args.reward_config.max_length
)


# Step 3: Define the training arguments
training_args = RewardConfig(
output_dir=script_args.output_dir,
per_device_train_batch_size=script_args.batch_size,
num_train_epochs=script_args.num_train_epochs,
gradient_accumulation_steps=script_args.gradient_accumulation_steps,
gradient_checkpointing=script_args.gradient_checkpointing,
learning_rate=script_args.learning_rate,
report_to="wandb" if script_args.log_with == "wandb" else "tensorboard",
remove_unused_columns=False,
optim="adamw_torch",
logging_steps=script_args.logging_steps,
evaluation_strategy="steps" if script_args.eval_split != "none" else "no",
max_length=script_args.seq_length,
)

# Step 4: Define the LoraConfig
if script_args.use_peft:
peft_config = LoraConfig(r=16, lora_alpha=16, bias="none", task_type="SEQ_CLS", modules_to_save=["scores"])
if args.use_peft:
peft_config = args.peft_config
else:
peft_config = None

# Step 5: Define the Trainer
trainer = RewardTrainer(
model=model,
tokenizer=tokenizer,
args=training_args,
args=args.reward_config,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
peft_config=peft_config,
Expand Down
9 changes: 5 additions & 4 deletions examples/scripts/sentiment_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,10 @@ class ScriptArguments:
score_clip=None,
)
)
query_dataset: str = field(default="imdb", metadata={"help": "the dataset to query"})
use_seq2seq: bool = field(default=False, metadata={"help": "whether to use seq2seq models"})
use_peft: bool = field(default=False, metadata={"help": "whether to use peft"})
use_seq2seq: bool = False
"""whether to use seq2seq models"""
use_peft: bool = False
"""whether to use peft"""
peft_config: Optional[LoraConfig] = field(
default_factory=lambda: LoraConfig(
r=16,
Expand Down Expand Up @@ -111,7 +112,7 @@ def tokenize(sample):


# We retrieve the dataloader by calling the `build_dataset` function.
dataset = build_dataset(args.ppo_config, args.query_dataset)
dataset = build_dataset(args.ppo_config, args.ppo_config.query_dataset)


def collator(data):
Expand Down
Loading

0 comments on commit 95aea7c

Please sign in to comment.