Added error when ref_model and model have same id (#2057)
* Added error check to RLOO, PPOv2, OnlineDPO that ref_policy and policy should have different identities.

* Update online_dpo_trainer.py

Co-authored-by: lewtun <[email protected]>

* style

* extend to other trainers

* bco as well

* case models are strings

* add tests

* style

---------

Co-authored-by: Quentin Gallouédec <[email protected]>
Co-authored-by: lewtun <[email protected]>
Co-authored-by: Quentin Gallouédec <[email protected]>
4 people authored Sep 17, 2024
1 parent 41fe228 commit e74dbf2
Showing 10 changed files with 119 additions and 1 deletion.
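
In practice, the new check means a caller can no longer reuse one model object for both roles. Below is a minimal sketch of the intended usage after this change, using `DPOTrainer` as the example; the checkpoint name is a placeholder, and `copy.deepcopy` is just one way to obtain a distinct reference model (with PEFT adapters you would pass `ref_model=None` instead):

```python
import copy

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import DPOConfig, DPOTrainer

model_id = "my-org/my-sft-checkpoint"  # placeholder checkpoint name
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
train_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

training_args = DPOConfig(output_dir="dpo-out", report_to="none")

# ref_model=model would now raise ValueError; pass a distinct copy (or None with PEFT).
ref_model = copy.deepcopy(model)

trainer = DPOTrainer(
    model=model,
    ref_model=ref_model,
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)
```

TRL's `create_reference_model` helper is another way to obtain such a copy.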
20 changes: 20 additions & 0 deletions tests/test_bco_trainer.py
@@ -103,6 +103,26 @@ def test_bco_trainer(self, name, pre_compute, eval_dataset):
            if param.sum() != 0:
                self.assertFalse(torch.equal(param.cpu(), new_param.cpu()))

    def test_bco_trainer_with_ref_model_is_model(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = BCOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=3,
                report_to="none",
            )

            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")

            with self.assertRaises(ValueError):
                BCOTrainer(
                    model=self.model,
                    ref_model=self.model,  # ref_model can't be the same as model
                    args=training_args,
                    tokenizer=self.tokenizer,
                    train_dataset=dummy_dataset["train"],
                )

    def test_tokenize_and_process_tokens(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = BCOConfig(
22 changes: 21 additions & 1 deletion tests/test_dpo_trainer.py
@@ -327,6 +327,26 @@ def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha, _):
            if param.sum() != 0:
                assert not torch.equal(param, new_param)

    def test_dpo_trainer_with_ref_model_is_model(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = DPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=3,
                report_to="none",
            )

            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference")

            with self.assertRaises(ValueError):
                DPOTrainer(
                    model=self.model,
                    ref_model=self.model,  # ref_model can't be the same as model
                    args=training_args,
                    tokenizer=self.tokenizer,
                    train_dataset=dummy_dataset["train"],
                )

    @require_peft
    def test_dpo_trainer_without_providing_ref_model_with_lora(self):
        from peft import LoraConfig
@@ -473,7 +493,7 @@ def test_tr_dpo_trainer(self):

            trainer = DPOTrainer(
                model=self.model,
-               ref_model=self.model,
+               ref_model=self.ref_model,
                beta=0.1,
                args=training_args,
                tokenizer=self.tokenizer,
20 changes: 20 additions & 0 deletions tests/test_kto_trainer.py
@@ -101,6 +101,26 @@ def test_kto_trainer(self, name, loss_type, pre_compute, eval_dataset):
            if param.sum() != 0:
                self.assertFalse(torch.equal(param, new_param))

    def test_kto_trainer_with_ref_model_is_model(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = KTOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=3,
                report_to="none",
            )

            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference")

            with self.assertRaises(ValueError):
                KTOTrainer(
                    model=self.model,
                    ref_model=self.model,  # ref_model can't be the same as model
                    args=training_args,
                    tokenizer=self.tokenizer,
                    train_dataset=dummy_dataset["train"],
                )

    def test_tokenize_and_process_tokens(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = KTOConfig(
20 changes: 20 additions & 0 deletions tests/test_online_dpo_trainer.py
@@ -86,6 +86,26 @@ def test_training_with_ref_model(self):
            # Check if training loss is available
            self.assertIn("train_loss", trainer.state.log_history[-1])

    def test_ref_model_is_model(self):
        with tempfile.TemporaryDirectory() as tmp_dir:
            training_args = OnlineDPOConfig(
                output_dir=tmp_dir,
                per_device_train_batch_size=2,
                max_steps=3,
                report_to="none",
            )

            dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only")

            with self.assertRaises(ValueError):
                OnlineDPOTrainer(
                    model=self.model,
                    ref_model=self.model,  # ref_model can't be the same as model
                    args=training_args,
                    tokenizer=self.tokenizer,
                    train_dataset=dummy_dataset["train"],
                )

    @require_peft
    def test_training_with_peft(self):
        lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM")
6 changes: 6 additions & 0 deletions trl/trainer/bco_trainer.py
@@ -336,6 +336,12 @@ def __init__(
        if type(args) is TrainingArguments:
            raise ValueError("Please use `BCOConfig` instead of `TrainingArguments`.")

        if ref_model is model:
            raise ValueError(
                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                "same as `model`, you must pass a copy of it, or `None` if you use peft."
            )

        if args.model_init_kwargs is None:
            model_init_kwargs = {}
        elif not isinstance(model, str):
6 changes: 6 additions & 0 deletions trl/trainer/dpo_trainer.py
@@ -452,6 +452,12 @@ def __init__(
        reference_free: bool = False,
        force_use_ref_model: bool = False,
    ):
        if not isinstance(model, str) and ref_model is model:
            raise ValueError(
                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                "same as `model`, you must pass a copy of it, or `None` if you use peft."
            )

        if model_init_kwargs is not None:
            warnings.warn(
                "You passed `model_init_kwargs` to the DPOTrainer, the value you passed will override the one in the `DPOConfig`."
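
Note the `not isinstance(model, str)` guard above: the identity check only applies to model objects. When both arguments are checkpoint names, each string is loaded into its own instance, so reusing the same name remains valid. A hedged sketch, reusing `training_args`, `tokenizer`, and `train_dataset` from the earlier example and a placeholder checkpoint name:

```python
# Same checkpoint name for both roles is still allowed: each string is loaded
# separately, so `ref_model is model` can never be True here.
trainer = DPOTrainer(
    model="my-org/my-sft-checkpoint",      # placeholder name, loaded fresh
    ref_model="my-org/my-sft-checkpoint",  # placeholder name, loaded independently
    args=training_args,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)
```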
6 changes: 6 additions & 0 deletions trl/trainer/kto_trainer.py
@@ -319,6 +319,12 @@ def __init__(
        if type(args) is TrainingArguments:
            raise ValueError("Please use `KTOConfig` instead of `TrainingArguments`.")

        if not isinstance(model, str) and ref_model is model:
            raise ValueError(
                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                "same as `model`, you must pass a copy of it, or `None` if you use peft."
            )

        if args.model_init_kwargs is None:
            model_init_kwargs = {}
        elif not isinstance(model, str):
8 changes: 8 additions & 0 deletions trl/trainer/online_dpo_trainer.py
@@ -128,6 +128,14 @@ def __init__(
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ) -> None:
        if ref_model is model:
            raise ValueError(
                "`model` and `ref_model` cannot be the same object. If you want `ref_model` to be the "
                "same as `model`, either omit the `ref_model` argument or pass `None`."
            )

        self.ref_model = ref_model

        if reward_model is not None and judge is not None:
            warnings.warn(
                "Both `reward_model` and `judge` are provided. Please provide only one of them. "
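
`OnlineDPOTrainer` words its error differently: the supported path is simply to omit `ref_model` (or pass `None`), in which case the trainer manages its own reference copy. A minimal sketch under that assumption, with `reward_model`, `online_dpo_args`, and the other objects assumed to be defined elsewhere:

```python
# Simplest fix for online DPO: drop ref_model entirely (or pass ref_model=None).
trainer = OnlineDPOTrainer(
    model=model,
    reward_model=reward_model,  # assumed defined elsewhere
    args=online_dpo_args,       # an OnlineDPOConfig, assumed defined elsewhere
    tokenizer=tokenizer,
    train_dataset=train_dataset,
)
```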
6 changes: 6 additions & 0 deletions trl/trainer/ppov2_trainer.py
@@ -97,6 +97,12 @@ def __init__(
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        callbacks: Optional[List[TrainerCallback]] = None,
    ) -> None:
        if ref_policy is policy:
            raise ValueError(
                "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
                "same as `policy`, you must pass a copy of it, or `None` if you use peft."
            )

        self.args = config
        args = config
        self.tokenizer = tokenizer
6 changes: 6 additions & 0 deletions trl/trainer/rloo_trainer.py
@@ -78,6 +78,12 @@ def __init__(
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        callbacks: Optional[List[TrainerCallback]] = None,
    ) -> None:
        if ref_policy is policy:
            raise ValueError(
                "`policy` and `ref_policy` cannot be the same object. If you want `ref_policy` to be the "
                "same as `policy`, you must pass a copy of it, or `None` if you use peft."
            )

        self.args = config
        args = config
        self.tokenizer = tokenizer
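
The same rule carries over to the policy-based trainers (`PPOv2Trainer`, `RLOOTrainer`), where the arguments are named `policy` and `ref_policy`. A minimal sketch of one way to satisfy the check, assuming a placeholder checkpoint name; two independent `from_pretrained` calls yield two distinct objects:

```python
from transformers import AutoModelForCausalLM

# Loading the checkpoint twice gives distinct objects, so `ref_policy is policy` is False.
# Passing the same object for both would now raise ValueError in PPOv2Trainer / RLOOTrainer.
policy = AutoModelForCausalLM.from_pretrained("my-org/my-sft-checkpoint")      # placeholder
ref_policy = AutoModelForCausalLM.from_pretrained("my-org/my-sft-checkpoint")  # independent copy
```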
