PEFT support for Online DPO (#2041)
* Promote `PPOv2Trainer` and `PPOv2Config` to top-level import

* Deprecate `PPOTrainer` and `PPOConfig`

* changes

* Revert "Promote `PPOv2Trainer` and `PPOv2Config` to top-level import"

This reverts commit 96ae02a.

* Revert "Deprecate `PPOTrainer` and `PPOConfig`"

This reverts commit 65990de.

* peft

* peft

* try to simplify

* revert utils changes

* update dpo script

* peft

* style

* revert gitignore

* test_online_dpo_peft

* ref model

* peft example command

* typo

* remove param.requires_grad = False for the reward model

* make `model` required arg

* update example script

* update xpo trainer

* Update examples/scripts/dpo_online.py

Co-authored-by: lewtun <[email protected]>

* Update examples/scripts/dpo_online.py

Co-authored-by: lewtun <[email protected]>

* merge and unload

---------

Co-authored-by: lewtun <[email protected]>
qgallouedec and lewtun authored Sep 13, 2024
1 parent 88bede6 commit ebc85b2
Showing 4 changed files with 209 additions and 46 deletions.
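A minimal, hedged sketch of the API surface this commit adds, assembled from the example script and tests in the diffs below. Model, reward-model, and dataset names come from the example docstring; the LoRA values mirror the new tests; the dataset split name, config defaults, and any preprocessing (e.g. the chat-template mapping the script applies) are assumptions, not the definitive recipe.

    from datasets import load_dataset
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
    from trl import OnlineDPOConfig, OnlineDPOTrainer

    model_id = "trl-lib/pythia-1b-deduped-tldr-sft"
    model = AutoModelForCausalLM.from_pretrained(model_id)
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        "trl-lib/pythia-1b-deduped-tldr-rm", num_labels=1
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
    dataset = load_dataset("trl-lib/tldr")  # split names below are an assumption

    training_args = OnlineDPOConfig(
        output_dir="pythia-1b-tldr-online-dpo",
        missing_eos_penalty=1.0,  # as in the CLI example in the script's docstring
    )
    trainer = OnlineDPOTrainer(
        model=model,                 # `model` is now a required argument
        # ref_model is optional here: with PEFT the frozen base weights can
        # serve as the reference policy, so no second full model is needed
        reward_model=reward_model,
        args=training_args,
        train_dataset=dataset["train"],
        tokenizer=tokenizer,
        peft_config=LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM"),
    )
    trainer.train()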
37 changes: 27 additions & 10 deletions examples/scripts/dpo_online.py
@@ -21,13 +21,23 @@
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-7 \
--output_dir pythia-1b-tldr-online-dpo \
--per_device_train_batch_size 4 \
--gradient_accumulation_steps 32 \
--num_train_epochs 3 \
--max_new_tokens 53 \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 16 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0
With LoRA:
python examples/scripts/dpo_online.py \
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \
--dataset_name trl-lib/tldr \
--learning_rate 5.0e-6 \
--output_dir pythia-1b-tldr-online-dpo \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 8 \
--warmup_ratio 0.1 \
--missing_eos_penalty 1.0 \
--push_to_hub \
--use_peft
"""

import torch
@@ -40,10 +50,12 @@
OnlineDPOConfig,
OnlineDPOTrainer,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
maybe_apply_chat_template,
LogCompletionsCallback,
)

from trl.commands.cli_utils import TrlParser
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE

@@ -70,19 +82,24 @@
model = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
)
ref_model = AutoModelForCausalLM.from_pretrained(
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs
)

reward_model = AutoModelForSequenceClassification.from_pretrained(
training_args.reward_model_path, num_labels=1, trust_remote_code=model_config.trust_remote_code
training_args.reward_model_path,
num_labels=1,
trust_remote_code=model_config.trust_remote_code,
**model_kwargs,
)

tokenizer = AutoTokenizer.from_pretrained(
model_config.model_name_or_path,
padding_side="left",
trust_remote_code=model_config.trust_remote_code,
**model_kwargs,
)
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

dataset = load_dataset(args.dataset_name)

@@ -93,12 +110,12 @@

trainer = OnlineDPOTrainer(
model=model,
ref_model=ref_model,
reward_model=reward_model,
args=training_args,
train_dataset=dataset[args.dataset_train_split],
eval_dataset=dataset[args.dataset_test_split],
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
)
generation_config = GenerationConfig(
max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature
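The updated script feeds `get_peft_config(model_config)` into the trainer instead of a hand-built LoRA config. A rough, hedged sketch of how the `--use_peft` / `--lora_*` CLI flags become that config follows; the exact ModelConfig field names are assumptions based on trl's get_peft_config and may differ between versions, so this is an illustration rather than the library's implementation.

    from peft import LoraConfig

    def sketch_get_peft_config(model_config):
        # Without --use_peft, return None and the trainer falls back to the
        # full-model path (separate ref_model, all parameters trainable).
        if not model_config.use_peft:
            return None
        # With --use_peft, build a LoraConfig from the CLI-provided fields.
        return LoraConfig(
            task_type=model_config.lora_task_type,   # "CAUSAL_LM" by default
            r=model_config.lora_r,
            lora_alpha=model_config.lora_alpha,
            lora_dropout=model_config.lora_dropout,
            target_modules=model_config.lora_target_modules,
            modules_to_save=model_config.lora_modules_to_save,
            bias="none",
        )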
127 changes: 122 additions & 5 deletions tests/test_online_dpo_trainer.py
@@ -16,14 +16,21 @@

from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import OnlineDPOConfig, OnlineDPOTrainer


if is_peft_available():
from peft import LoraConfig, get_peft_model


class TestOnlineDPOTrainer(unittest.TestCase):
def setUp(self):
self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab"
self.model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id)
self.reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1)
self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
self.tokenizer.pad_token = self.tokenizer.eos_token
@@ -67,22 +74,19 @@ def setUp(self):
# fmt: on
self.dummy_dataset = Dataset.from_dict(dummy_dataset_dict)

def test_online_dpo_trainer_training(self):
def test_training(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = OnlineDPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
remove_unused_columns=False,
gradient_accumulation_steps=1,
learning_rate=9e-1,
learning_rate=5.0e-7,
eval_strategy="steps",
report_to="none",
)

trainer = OnlineDPOTrainer(
model=self.model,
ref_model=self.model,
reward_model=self.reward_model,
args=training_args,
tokenizer=self.tokenizer,
@@ -94,3 +98,116 @@ def test_online_dpo_trainer_training(self):

# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])

def test_training_with_ref_model(self):
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = OnlineDPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
learning_rate=5.0e-7,
eval_strategy="steps",
report_to="none",
)

trainer = OnlineDPOTrainer(
model=self.model,
ref_model=self.ref_model,
reward_model=self.reward_model,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=self.dummy_dataset,
eval_dataset=self.dummy_dataset,
)

trainer.train()

# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])

@require_peft
def test_training_with_peft(self):
lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = OnlineDPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
learning_rate=5.0e-7,
eval_strategy="steps",
report_to="none",
)

trainer = OnlineDPOTrainer(
model=self.model,
reward_model=self.reward_model,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=self.dummy_dataset,
eval_dataset=self.dummy_dataset,
peft_config=lora_config,
)

trainer.train()

# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])

@require_peft
def test_training_with_peft_and_ref_model(self):
lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = OnlineDPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
learning_rate=5.0e-7,
eval_strategy="steps",
report_to="none",
)

trainer = OnlineDPOTrainer(
model=self.model,
ref_model=self.ref_model,
reward_model=self.reward_model,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=self.dummy_dataset,
eval_dataset=self.dummy_dataset,
peft_config=lora_config,
)

trainer.train()

# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])

def test_training_with_peft_model_and_peft_config(self):
model_lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM")
model = get_peft_model(self.model, model_lora_config)
# we want only the "train adapter" to be trained
lora_train_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
with tempfile.TemporaryDirectory() as tmp_dir:
training_args = OnlineDPOConfig(
output_dir=tmp_dir,
per_device_train_batch_size=2,
max_steps=3,
learning_rate=5.0e-7,
eval_strategy="steps",
report_to="none",
)

trainer = OnlineDPOTrainer(
model=model,
reward_model=self.reward_model,
args=training_args,
tokenizer=self.tokenizer,
train_dataset=self.dummy_dataset,
eval_dataset=self.dummy_dataset,
peft_config=lora_train_config,
)

trainer.train()

# Check if training loss is available
self.assertIn("train_loss", trainer.state.log_history[-1])
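A hedged reading of the "merge and unload" commit bullet, consistent with test_training_with_peft_model_and_peft_config above and its "train adapter" comment: when the policy passed in is already a PeftModel and a fresh peft_config is supplied, the existing adapters are presumably merged into the base weights first so that only the new adapter carries gradients. This is a standalone sketch of that pattern, not the trainer's exact internals.

    from peft import PeftModel, get_peft_model

    def sketch_prepare_policy(model, peft_config):
        if isinstance(model, PeftModel):
            model = model.merge_and_unload()       # fold existing adapters into the base weights
        return get_peft_model(model, peft_config)  # attach the new, trainable "train" adapter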
