From dc2bd07408439691d7b90001de7351b8ebede1d8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 16 Sep 2024 13:46:52 +0200 Subject: [PATCH] Nash md (#1853) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * initial skeleton * initial config and class * move TrainerCallback to callbacks.py * initial trainer mockup * formatting * add back header * script with reward model * call ref policy forward with torch no_grad * fix api * clean up the configs * use the new API * fix typo * get get_reward without grads * remove unused no_grad calls * fix formatting * initial GeometricMixtureWrapper * Update trl/models/modeling_base.py Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> * undo changes to callback * GenerationMixin needs generation_config * calculate score with model and mixture model outputs * fix scores and mixture_scores tensors * undo * use interleaved version to calcuate chosen-rejected * Revert "use interleaved version to calcuate chosen-rejected" This reverts commit 4a63a60971a7db173d10771548f17f650d955c2a. * fix mixture scores * Fix global step * use mixture_coeff * record scores_margin only * fix del * First version of Nash MD trainer * undo * fix formatting * fix toc * initial refactorin * mixin fixes * fix refactoring * cleanup comments * add log_stats * add test * initial docs * fix logs * fix missing_eos_penalty * fix output_dir * add peft_config to docs and super * undo init changes * Update docs/source/_toctree.yml Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * Update trl/trainer/nash_md_config.py Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * add dataset format * add authors * add dynamic parameter callback * update test * fix comments * test GeometricMixtureWrapper * header * formatting * formatting * add paper and abstract * Update docs/source/nash_md_trainer.md Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> * DynamicParameterCallback * drop callback in favor of getter * revert kto config change * revert kto config change * fix contribution * `coeff` to `coef` * log dynamic coefs * Update docs/source/nash_md_trainer.md * Update docs/source/nash_md_trainer.md * fix tests * use self.ref_model * one-line --------- Co-authored-by: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Co-authored-by: Daniil Tiapkin Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: Quentin Gallouédec --- docs/source/_toctree.yml | 2 + docs/source/dataset_formats.mdx | 1 + docs/source/nash_md_trainer.md | 54 +++ docs/source/xpo_trainer.mdx | 4 +- examples/scripts/nash_md.py | 132 +++++++ ...test_modeling_geometric_mixture_wrapper.py | 65 +++ tests/test_nash_md_trainer.py | 155 ++++++++ tests/test_xpo_trainer.py | 99 ++++- trl/__init__.py | 5 + trl/models/__init__.py | 4 +- trl/models/modeling_base.py | 57 ++- trl/trainer/__init__.py | 4 + trl/trainer/nash_md_config.py | 34 ++ trl/trainer/nash_md_trainer.py | 373 ++++++++++++++++++ trl/trainer/online_dpo_config.py | 8 +- trl/trainer/online_dpo_trainer.py | 24 +- trl/trainer/xpo_config.py | 7 +- trl/trainer/xpo_trainer.py | 67 ++-- 18 files changed, 1053 insertions(+), 42 deletions(-) create mode 100644 docs/source/nash_md_trainer.md create mode 100644 examples/scripts/nash_md.py create mode 100644 tests/test_modeling_geometric_mixture_wrapper.py create mode 100644 tests/test_nash_md_trainer.py create 
mode 100644 trl/trainer/nash_md_config.py create mode 100644 trl/trainer/nash_md_trainer.py diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 61b6604d01..8c290e32bb 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -36,6 +36,8 @@ title: Online DPO - local: xpo_trainer title: XPO + - local: nash_md_trainer + title: Nash MD - local: orpo_trainer title: ORPO - local: kto_trainer diff --git a/docs/source/dataset_formats.mdx b/docs/source/dataset_formats.mdx index 7ffeb9c23a..9b9dda92e8 100644 --- a/docs/source/dataset_formats.mdx +++ b/docs/source/dataset_formats.mdx @@ -199,6 +199,7 @@ Choosing the right dataset format depends on the task you are working on and the | [`DPOTrainer`] | Preference (explicit prompt) | | [`IterativeSFTTrainer`] | Unpaired preference | | [`KTOTrainer`] | Unpaired preference | +| [`NashMDTrainer`] | Prompt-only | | [`OnlineDPOTrainer`] | Prompt-only | | [`ORPOTrainer`] | Preference (explicit prompt) | | [`PPOv2Trainer`] | Tokenized language modeling | diff --git a/docs/source/nash_md_trainer.md b/docs/source/nash_md_trainer.md new file mode 100644 index 0000000000..5ff03d1e20 --- /dev/null +++ b/docs/source/nash_md_trainer.md @@ -0,0 +1,54 @@ +# Nash MD Trainer + +## Overview +Nash-MD was proposed in the paper [Nash Learning from Human Feedback](https://huggingface.co/papers/2312.00886) by Rémi Munos, [Michal Valko](https://huggingface.co/misovalko), Daniele Calandriello, Mohammad Gheshlaghi Azar, Mark Rowland, Daniel Guo, Yunhao Tang, Matthieu Geist, Thomas Mésnard, and Andrea Michi. + +The abstract from the paper is the following: + +> Reinforcement learning from human feedback (RLHF) has emerged as the main paradigm for aligning large language models (LLMs) with human preferences. Typically, RLHF involves the initial step of learning a reward model from human feedback, often expressed as preferences between pairs of text generations produced by a pre-trained LLM. Subsequently, the LLM's policy is fine-tuned by optimizing it to maximize the reward model through a reinforcement learning algorithm. However, an inherent limitation of current reward models is their inability to fully represent the richness of human preferences and their dependency on the sampling distribution. In this study, we introduce an alternative pipeline for the fine-tuning of LLMs using pairwise human feedback. Our approach entails the initial learning of a preference model, which is conditioned on two inputs given a prompt, followed by the pursuit of a policy that consistently generates responses preferred over those generated by any competing policy, thus defining the Nash equilibrium of this preference model. We term this approach Nash learning from human feedback (NLHF). In the context of a tabular policy representation, we present a novel algorithmic solution, Nash-MD, founded on the principles of mirror descent. This algorithm produces a sequence of policies, with the last iteration converging to the regularized Nash equilibrium. Additionally, we explore parametric representations of policies and introduce gradient descent algorithms for deep-learning architectures. To demonstrate the effectiveness of our approach, we present experimental results involving the fine-tuning of a LLM for a text summarization task. We believe NLHF offers a compelling avenue for preference learning and policy optimization with the potential of advancing the field of aligning LLMs with human preferences. 
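+
+In practice, at each training step the trainer generates one completion from the policy being trained and one completion from a *geometric mixture* of the policy and the reference policy, scores both with the reward model, and then optimizes a REINFORCE-style score term (how often the model completion is preferred to the mixture completion, as estimated by the reward model) together with a KL penalty that keeps the policy close to the reference model. The snippet below is a minimal, illustrative sketch of how the mixture log-probabilities are formed, mirroring the `GeometricMixtureWrapper` introduced in this PR; the `geometric_mixture_logits` helper name and the tensor shapes are made up for the example.
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def geometric_mixture_logits(model_logits, ref_logits, mixture_coef=0.5):
+    # Log-probabilities of the geometric mixture pi_mix, proportional to
+    # pi_ref**mixture_coef * pi_model**(1 - mixture_coef), obtained by mixing the
+    # two models' logits and renormalizing with a log-softmax.
+    return F.log_softmax(mixture_coef * ref_logits + (1 - mixture_coef) * model_logits, dim=-1)
+
+
+# Toy example: batch of 1, sequence length 5, vocabulary of 10 tokens.
+model_logits = torch.randn(1, 5, 10)
+ref_logits = torch.randn(1, 5, 10)
+print(geometric_mixture_logits(model_logits, ref_logits).shape)  # torch.Size([1, 5, 10])
+```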
+ + +This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif), [Daniil Tiapkin](https://huggingface.co/dtiapkin), [Pierre Ménard](https://huggingface.co/menardprr), Daniele Calandriello, and [Quentin Gallouédec](https://huggingface.co/qgallouedec). + +## Get started + +To verify that the Nash MD script runs end to end, you can train a Nash MD model with a dummy reward model using the following command: + +```bash +python examples/scripts/nash_md.py \ + --model_name_or_path EleutherAI/pythia-14m \ + --reward_model_path EleutherAI/pythia-14m \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-14m-tldr-nash-md \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 32 \ + --num_train_epochs 3 \ + --max_new_tokens 64 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 +``` + +## Explanation of the logged metrics + +The logged metrics are as follows: + +* `loss/score`: The mean REINFORCE score loss. +* `loss/kl_div`: The mean KL divergence loss. +* `objective/entropy`: The mean entropy of the model completions. +* `rewards/accuracies`: The accuracies of Nash MD's implicit reward model. +* `rewards/chosen`: The mean scores (according to the reward model) of the model completions. +* `rewards/rejected`: The mean scores (according to the reward model) of the mixture completions. +* `rewards/margins`: The mean reward margin (according to the reward model) between the model and mixture completions. +* `logps/chosen`: The mean log probabilities of the model completions under the model. +* `logps/rejected`: The mean log probabilities of the model completions under the reference model. +* `val/model_contain_eos_token`: The number of times the model's output contains the EOS token. +* `val/ref_contain_eos_token`: The number of times the mixture's output contains the EOS token. + +## NashMDTrainer + +[[autodoc]] NashMDTrainer + +## NashMDConfig + +[[autodoc]] NashMDConfig diff --git a/docs/source/xpo_trainer.mdx b/docs/source/xpo_trainer.mdx index 13147d5890..2c55fd226c 100644 --- a/docs/source/xpo_trainer.mdx +++ b/docs/source/xpo_trainer.mdx @@ -20,11 +20,11 @@ python examples/scripts/xpo.py \ --reward_model_path EleutherAI/pythia-14m \ --dataset_name trl-lib/tldr \ --learning_rate 5.0e-7 \ - --output_dir pythia-1b-tldr-xpo \ + --output_dir pythia-14m-tldr-xpo \ --per_device_train_batch_size 4 \ --gradient_accumulation_steps 32 \ --num_train_epochs 3 \ - --max_new_tokens 53 \ + --max_new_tokens 64 \ --warmup_ratio 0.1 \ --missing_eos_penalty 1.0 ``` diff --git a/examples/scripts/nash_md.py b/examples/scripts/nash_md.py new file mode 100644 index 0000000000..6ad8068db9 --- /dev/null +++ b/examples/scripts/nash_md.py @@ -0,0 +1,132 @@ +# flake8: noqa + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +python examples/scripts/nash_md.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-1b-tldr-nash-md \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 32 \ + --num_train_epochs 3 \ + --max_new_tokens 64 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 \ + --push_to_hub + + +accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ + examples/scripts/nash_md.py \ + --model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ + --reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ + --dataset_name trl-lib/tldr \ + --learning_rate 5.0e-7 \ + --output_dir pythia-1b-tldr-nash-md \ + --per_device_train_batch_size 4 \ + --gradient_accumulation_steps 32 \ + --num_train_epochs 3 \ + --max_new_tokens 64 \ + --warmup_ratio 0.1 \ + --missing_eos_penalty 1.0 \ + --push_to_hub +""" + +import torch +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig +from accelerate import PartialState +from trl import ( + DPOScriptArguments, + ModelConfig, + NashMDConfig, + NashMDTrainer, + get_kbit_device_map, + get_quantization_config, + maybe_apply_chat_template, + LogCompletionsCallback, +) +from trl.commands.cli_utils import TrlParser +from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE + + +if __name__ == "__main__": + parser = TrlParser((DPOScriptArguments, NashMDConfig, ModelConfig)) + args, training_args, model_config = parser.parse_args_and_config() + args.gradient_checkpointing_kwargs = {"use_reentrant": True} + + torch_dtype = ( + model_config.torch_dtype + if model_config.torch_dtype in ["auto", None] + else getattr(torch, model_config.torch_dtype) + ) + quantization_config = get_quantization_config(model_config) + model_kwargs = dict( + revision=model_config.model_revision, + attn_implementation=model_config.attn_implementation, + torch_dtype=torch_dtype, + use_cache=False if training_args.gradient_checkpointing else True, + device_map=get_kbit_device_map() if quantization_config is not None else None, + quantization_config=quantization_config, + ) + + model = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + ) + ref_model = AutoModelForCausalLM.from_pretrained( + model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs + ) + reward_model = AutoModelForSequenceClassification.from_pretrained( + training_args.reward_model_path, num_labels=1, trust_remote_code=model_config.trust_remote_code + ) + tokenizer = AutoTokenizer.from_pretrained( + model_config.model_name_or_path, + padding_side="left", + trust_remote_code=model_config.trust_remote_code, + ) + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token + if tokenizer.chat_template is None: + tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE + + dataset = load_dataset(args.dataset_name) + + with PartialState().local_main_process_first(): + dataset = dataset.map( + maybe_apply_chat_template, num_proc=training_args.dataset_num_proc, fn_kwargs={"tokenizer": tokenizer} + ) + + trainer = NashMDTrainer( + model=model, + ref_model=ref_model, + reward_model=reward_model, + args=training_args, + train_dataset=dataset[args.dataset_train_split], + eval_dataset=dataset[args.dataset_test_split], + 
tokenizer=tokenizer, + ) + generation_config = GenerationConfig( + max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature + ) + completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) + trainer.add_callback(completions_callback) + # train the model + trainer.train() + + # save the model + trainer.save_model(training_args.output_dir) diff --git a/tests/test_modeling_geometric_mixture_wrapper.py b/tests/test_modeling_geometric_mixture_wrapper.py new file mode 100644 index 0000000000..227e1019da --- /dev/null +++ b/tests/test_modeling_geometric_mixture_wrapper.py @@ -0,0 +1,65 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch +from transformers import AutoModelForCausalLM, GenerationConfig + +from trl.models.modeling_base import GeometricMixtureWrapper, create_reference_model + + +class TestGeometricMixtureWrapper(unittest.TestCase): + def setUp(self): + self.model = AutoModelForCausalLM.from_pretrained("gpt2") + self.ref_model = create_reference_model(self.model) + self.generation_config = GenerationConfig.from_pretrained("gpt2") + self.mixture_coef = 0.5 + self.wrapper = GeometricMixtureWrapper( + self.model, self.ref_model, self.generation_config, mixture_coef=self.mixture_coef + ) + + def test_forward(self): + input_ids = torch.tensor([[1, 2, 3, 4, 5]]) + attention_mask = torch.ones_like(input_ids) + + output = self.wrapper(input_ids=input_ids, attention_mask=attention_mask) + + self.assertIsNotNone(output) + self.assertTrue(hasattr(output, "logits")) + self.assertEqual(output.logits.shape, (1, 5, self.model.config.vocab_size)) + + def test_mixture_coefficient(self): + input_ids = torch.tensor([[1, 2, 3, 4, 5]]) + attention_mask = torch.ones_like(input_ids) + + with torch.no_grad(): + model_output = self.model(input_ids=input_ids, attention_mask=attention_mask) + ref_model_output = self.ref_model(input_ids=input_ids, attention_mask=attention_mask) + wrapper_output = self.wrapper(input_ids=input_ids, attention_mask=attention_mask) + + expected_logits = torch.nn.functional.log_softmax( + self.mixture_coef * ref_model_output.logits + (1 - self.mixture_coef) * model_output.logits, dim=-1 + ) + + self.assertTrue(torch.allclose(wrapper_output.logits, expected_logits, atol=1e-5)) + + def test_prepare_inputs_for_generation(self): + input_ids = torch.tensor([[1, 2, 3, 4, 5]]) + attention_mask = torch.ones_like(input_ids) + + inputs = self.wrapper.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask, use_cache=True) + + self.assertIn("input_ids", inputs) + self.assertIn("attention_mask", inputs) + self.assertFalse(inputs.get("use_cache", False)) diff --git a/tests/test_nash_md_trainer.py b/tests/test_nash_md_trainer.py new file mode 100644 index 0000000000..a7e9f685fa --- /dev/null +++ b/tests/test_nash_md_trainer.py @@ -0,0 +1,155 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import tempfile +import unittest + +from datasets import load_dataset +from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers.testing_utils import require_peft +from transformers.utils import is_peft_available + +from trl import NashMDConfig, NashMDTrainer + + +if is_peft_available(): + from peft import LoraConfig, get_peft_model + + +class TestNashMDTrainer(unittest.TestCase): + def setUp(self): + self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" + self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1) + self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) + self.tokenizer.pad_token = self.tokenizer.eos_token + + def test_nash_md_trainer_training(self): + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = NashMDConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + remove_unused_columns=False, + gradient_accumulation_steps=1, + learning_rate=9e-1, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = NashMDTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_peft + def test_training_with_peft(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = NashMDConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = NashMDTrainer( + model=self.model, + reward_model=self.reward_model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_peft + def test_training_with_peft_and_ref_model(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = NashMDConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = 
NashMDTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + def test_training_with_peft_model_and_peft_config(self): + model_lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM") + model = get_peft_model(self.model, model_lora_config) + # we want only the "train adapter" to be trained + lora_train_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = NashMDConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = NashMDTrainer( + model=model, + reward_model=self.reward_model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_train_config, + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) diff --git a/tests/test_xpo_trainer.py b/tests/test_xpo_trainer.py index fc85c8f6f1..7b098e88de 100644 --- a/tests/test_xpo_trainer.py +++ b/tests/test_xpo_trainer.py @@ -16,14 +16,21 @@ from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer +from transformers.testing_utils import require_peft +from transformers.utils import is_peft_available from trl import XPOConfig, XPOTrainer +if is_peft_available(): + from peft import LoraConfig, get_peft_model + + class TestXPOTrainer(unittest.TestCase): def setUp(self): self.model_id = "trl-internal-testing/dummy-GPT2-correct-vocab" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) + self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) self.reward_model = AutoModelForSequenceClassification.from_pretrained("EleutherAI/pythia-14m", num_labels=1) self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) self.tokenizer.pad_token = self.tokenizer.eos_token @@ -44,12 +51,102 @@ def test_xpo_trainer_training(self): trainer = XPOTrainer( model=self.model, - ref_model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_peft + def test_training_with_peft(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = XPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = XPOTrainer( + model=self.model, + reward_model=self.reward_model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset["train"], + 
eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + @require_peft + def test_training_with_peft_and_ref_model(self): + lora_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = XPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = XPOTrainer( + model=self.model, + ref_model=self.ref_model, + reward_model=self.reward_model, + args=training_args, + tokenizer=self.tokenizer, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + peft_config=lora_config, + ) + + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) + + def test_training_with_peft_model_and_peft_config(self): + model_lora_config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.1, bias="none", task_type="CAUSAL_LM") + model = get_peft_model(self.model, model_lora_config) + # we want only the "train adapter" to be trained + lora_train_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM") + with tempfile.TemporaryDirectory() as tmp_dir: + training_args = XPOConfig( + output_dir=tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") + + trainer = XPOTrainer( + model=model, reward_model=self.reward_model, args=training_args, tokenizer=self.tokenizer, train_dataset=dummy_dataset["train"], eval_dataset=dummy_dataset["test"], + peft_config=lora_train_config, ) trainer.train() diff --git a/trl/__init__.py b/trl/__init__.py index 674ac35158..3765aa295b 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -19,6 +19,7 @@ from typing import TYPE_CHECKING from .import_utils import _LazyModule, is_diffusers_available, OptionalDependencyNotAvailable + _import_structure = { "core": [ "set_seed", @@ -64,6 +65,8 @@ "BCOConfig", "BCOTrainer", "ModelConfig", + "NashMDConfig", + "NashMDTrainer", "OnlineDPOConfig", "OnlineDPOTrainer", "XPOConfig", @@ -169,6 +172,8 @@ BCOConfig, BCOTrainer, ModelConfig, + NashMDConfig, + NashMDTrainer, OnlineDPOConfig, OnlineDPOTrainer, XPOConfig, diff --git a/trl/models/__init__.py b/trl/models/__init__.py index 8a746ede60..7282fa1a6f 100644 --- a/trl/models/__init__.py +++ b/trl/models/__init__.py @@ -20,7 +20,7 @@ _import_structure = { - "modeling_base": ["PreTrainedModelWrapper", "create_reference_model"], + "modeling_base": ["PreTrainedModelWrapper", "create_reference_model", "GeometricMixtureWrapper"], "modeling_value_head": [ "AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead", @@ -42,7 +42,7 @@ ] if TYPE_CHECKING: - from .modeling_base import PreTrainedModelWrapper, create_reference_model + from .modeling_base import PreTrainedModelWrapper, create_reference_model, GeometricMixtureWrapper from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead from .utils import setup_chat_format, SUPPORTED_ARCHITECTURES diff --git a/trl/models/modeling_base.py b/trl/models/modeling_base.py index c55e854796..f9b94e3dcd 100644 --- 
a/trl/models/modeling_base.py +++ b/trl/models/modeling_base.py @@ -28,7 +28,7 @@ RepositoryNotFoundError, ) from safetensors.torch import load_file as safe_load_file -from transformers import PreTrainedModel +from transformers import GenerationMixin, PreTrainedModel from ..import_utils import is_npu_available, is_peft_available, is_transformers_greater_than, is_xpu_available @@ -676,3 +676,58 @@ def create_reference_model( logging.warning("Pattern passed or found, but no layers matched in the model. Check for a typo.") return ref_model.eval() + + +class GeometricMixtureWrapper(GenerationMixin): + r""" + Generation wrapper that samples from the geometric mixture of the token distributions of a model and a reference model. + + Args: + model (`PreTrainedModel`): The model to be wrapped. + ref_model (`PreTrainedModel`): The reference model. + generation_config (`GenerationConfig`): The generation config. + mixture_coef (`float`, *optional*, defaults to `0.5`): The mixture coefficient. + """ + + main_input_name = "input_ids" + _supports_cache_class = False + _supports_static_cache = False + + def __init__(self, model, ref_model, generation_config, mixture_coef=0.5, device=None): + super().__init__() + + self.model = model.eval() + self.config = model.config + self.ref_model = ref_model.eval() + self.generation_config = generation_config + self.mixture_coef = mixture_coef + self.device = device + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @torch.no_grad() + def forward(self, *args, **kwargs): + model_outputs = self.model(*args, **kwargs) + model_logits = model_outputs.logits + ref_model_logits = self.ref_model(*args, **kwargs).logits + + model_outputs.logits = torch.nn.functional.log_softmax( + self.mixture_coef * ref_model_logits + (1 - self.mixture_coef) * model_logits, dim=-1 + ) + + return model_outputs + + def prepare_inputs_for_generation(self, *args, **kwargs): + # turn off cache during generation + kwargs["use_cache"] = False + model_inputs = self.model.prepare_inputs_for_generation(*args, **kwargs) + _ = self.ref_model.prepare_inputs_for_generation(*args, **kwargs) + + return model_inputs + + def _validate_model_class(self): + self.model._validate_model_class() + + def _validate_model_kwargs(self, model_kwargs): + return self.model._validate_model_kwargs(model_kwargs) diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py index 9282cead8c..2608eb6294 100644 --- a/trl/trainer/__init__.py +++ b/trl/trainer/__init__.py @@ -42,6 +42,8 @@ "bco_config": ["BCOConfig"], "bco_trainer": ["BCOTrainer"], "model_config": ["ModelConfig"], + "nash_md_config": ["NashMDConfig"], + "nash_md_trainer": ["NashMDTrainer"], "online_dpo_config": ["OnlineDPOConfig"], "online_dpo_trainer": ["OnlineDPOTrainer"], "xpo_config": ["XPOConfig"], @@ -114,6 +116,8 @@ from .bco_config import BCOConfig from .bco_trainer import BCOTrainer from .model_config import ModelConfig + from .nash_md_config import NashMDConfig + from .nash_md_trainer import NashMDTrainer from .online_dpo_config import OnlineDPOConfig from .online_dpo_trainer import OnlineDPOTrainer from .xpo_config import XPOConfig diff --git a/trl/trainer/nash_md_config.py b/trl/trainer/nash_md_config.py new file mode 100644 index 0000000000..7a5a7c2ff5 --- /dev/null +++ b/trl/trainer/nash_md_config.py @@ -0,0 +1,34 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Union + +from trl.trainer.online_dpo_config import OnlineDPOConfig + + +@dataclass +class NashMDConfig(OnlineDPOConfig): + r""" + Configuration class for the [`NashMDTrainer`]. + + Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following: + + Parameters: + mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`): + Logit mixture coefficient for the model and reference model. + If a list of floats is provided, the mixture coefficient is selected for each new epoch, and the last coefficient is used for the rest of the epochs. + """ + + mixture_coef: Union[float, List[float]] = 0.5 diff --git a/trl/trainer/nash_md_trainer.py b/trl/trainer/nash_md_trainer.py new file mode 100644 index 0000000000..b0118c237f --- /dev/null +++ b/trl/trainer/nash_md_trainer.py @@ -0,0 +1,373 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from datasets import Dataset, IterableDataset +from transformers import PreTrainedTokenizerBase, TrainerCallback +from transformers.modeling_utils import PreTrainedModel +from transformers.trainer_utils import EvalPrediction +from transformers.training_args import OptimizerNames +from transformers.utils import is_apex_available + +from ..models.modeling_base import GeometricMixtureWrapper +from ..models.utils import unwrap_model_for_generation +from .nash_md_config import NashMDConfig +from .online_dpo_trainer import OnlineDPOTrainer +from .utils import empty_cache, get_reward, truncate_right + + +if is_apex_available(): + from apex import amp + + +class NashMDTrainer(OnlineDPOTrainer): + r""" + Initialize NashMDTrainer as a subclass of [`OnlineDPOTrainer`]. + + Args: + model (`transformers.PreTrainedModel`): + The model to train, preferably an `AutoModelForCausalLM`. + ref_model (`PreTrainedModelWrapper`): + Hugging Face transformer model with a causal language modeling head. Used for implicit reward computation and loss. If no + reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized. + reward_model (`transformers.PreTrainedModel`): + The reward model to score completions with, preferably an `AutoModelForSequenceClassification`. + judge (`BasePairwiseJudge`): + The judge to use for pairwise comparison of model completions. 
+ args (`NashMDConfig`): + The NashMD config arguments to use for training. + data_collator (`transformers.DataCollator`): + The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used + which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences. + train_dataset (`datasets.Dataset`): + The dataset to use for training. + eval_dataset (`datasets.Dataset`): + The dataset to use for evaluation. + tokenizer (`transformers.PreTrainedTokenizerBase`): + The tokenizer to use for training. This argument is required if you want to use the default data collator. + peft_config (`Dict`): + The peft config to use for training. + compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): + The function to use to compute the metrics. Must take a `EvalPrediction` and return + a dictionary string to metric values. + callbacks (`List[transformers.TrainerCallback]`): + The callbacks to use for training. + optimizers (`Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`): + The optimizer and scheduler to use for training. + preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`): + The function to use to preprocess the logits before computing the metrics. + """ + + _tag_names = ["trl", "nash-md"] + + def __init__( + self, + model: Union[PreTrainedModel, nn.Module] = None, + ref_model: Union[PreTrainedModel, nn.Module] = None, + reward_model: Optional[nn.Module] = None, + args: Optional[NashMDConfig] = None, + data_collator: Optional[Callable] = None, + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, + eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, + tokenizer: Optional[PreTrainedTokenizerBase] = None, + peft_config: Optional[Dict] = None, + compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, + callbacks: Optional[List[TrainerCallback]] = None, + optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), + preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, + ) -> None: + super().__init__( + model=model, + ref_model=ref_model, + reward_model=reward_model, + args=args, + data_collator=data_collator, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tokenizer=tokenizer, + peft_config=peft_config, + compute_metrics=compute_metrics, + callbacks=callbacks, + optimizers=optimizers, + preprocess_logits_for_metrics=preprocess_logits_for_metrics, + ) + + self._mixture_coef = self.args.mixture_coef + + # Overwrite the stats dictionary to include NashMD specific statistics + self.stats = { + "logps/chosen": [], + "logps/rejected": [], + "rewards/chosen": [], + "rewards/rejected": [], + "loss/score": [], + "loss/kl_div": [], + "objective/entropy": [], + "rewards/margins": [], + "rewards/accuracies": [], + "val/model_contain_eos_token": [], + "val/ref_contain_eos_token": [], + "beta": [], + "mixture_coef": [], + } + + @property + def mixture_coef(self): + if isinstance(self._mixture_coef, list): + epoch = self.state.epoch + return self._mixture_coef[epoch] if epoch < len(self._mixture_coef) else self._mixture_coef[-1] + else: + return self._mixture_coef + + def _generate_completions(self, model, prompts): + with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: + model_output = unwrapped_model.generate( + input_ids=prompts["input_ids"], + attention_mask=prompts["attention_mask"], 
+ generation_config=self.generation_config, + ) + + ref_model = model if self.ref_model is None else self.ref_model + with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model: + mixture_model = GeometricMixtureWrapper( + model=unwrapped_model, + ref_model=unwrapped_ref_model, + generation_config=self.generation_config, + mixture_coef=self.mixture_coef, + device=self.accelerator.device, + ) + + mixture_output = mixture_model.generate( + input_ids=prompts["input_ids"], + attention_mask=prompts["attention_mask"], + generation_config=self.generation_config, + ) + + return model_output, mixture_output + + def _process_completions(self, model_output, mixture_output, prompts): + context_length = prompts["input_ids"].shape[1] + + # Process model completions + model_completion_ids = model_output[:, context_length:] + model_completion_ids, model_completion_mask = truncate_right( + model_completion_ids, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id + ) + model_data = { + "input_ids": torch.cat((prompts["input_ids"], model_completion_ids), dim=1), + "attention_mask": torch.cat((prompts["attention_mask"], model_completion_mask), dim=1), + } + + # Process reference model completions + mixture_completion_ids = mixture_output[:, context_length:] + mixture_completion_ids, mixture_completion_mask = truncate_right( + mixture_completion_ids, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id + ) + mixture_data = { + "input_ids": torch.cat((prompts["input_ids"], mixture_completion_ids), dim=1), + "attention_mask": torch.cat((prompts["attention_mask"], mixture_completion_mask), dim=1), + } + + return model_data, mixture_data + + def _compute_rewards(self, model_data, mixture_data, context_length): + with torch.no_grad(): + _, model_scores, _ = get_reward( + self.reward_model, model_data["input_ids"], self.tokenizer.pad_token_id, context_length + ) + _, mixture_scores, _ = get_reward( + self.reward_model, mixture_data["input_ids"], self.tokenizer.pad_token_id, context_length + ) + + # Apply EOS penalty if needed + if self.args.missing_eos_penalty is not None: + model_contain_eos = torch.any(model_data["input_ids"] == self.tokenizer.eos_token_id, dim=-1) + mixture_contain_eos = torch.any(mixture_data["input_ids"] == self.tokenizer.eos_token_id, dim=-1) + model_scores[~model_contain_eos] -= self.args.missing_eos_penalty + mixture_scores[~mixture_contain_eos] -= self.args.missing_eos_penalty + + return model_scores, mixture_scores + + def _compute_logprobs(self, model, model_data, context_length): + def compute_logprobs_for_data(m, data): + output = m(data["input_ids"], attention_mask=data["attention_mask"]) + logits = output.logits[:, context_length - 1 : -1] + logprobs = F.log_softmax(logits, dim=-1) + token_logprobs = torch.gather(logprobs, 2, data["input_ids"][:, context_length:].unsqueeze(-1)).squeeze(-1) + return token_logprobs + + # Compute logprobs for model completions under the model + model_logprobs_model_data = compute_logprobs_for_data(model, model_data) + + # Compute logprobs of model completions under the reference model + with torch.no_grad(): + if self.ref_model is None: + with model.disable_adapter(): + ref_logprobs_model_data = compute_logprobs_for_data(model, model_data) + else: + ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data) + + # Mask padding tokens + model_padding_mask = model_data["attention_mask"][:, context_length:] == 0 + model_logprobs_model_data = 
model_logprobs_model_data.masked_fill(model_padding_mask, 0.0) + ref_logprobs_model_data = ref_logprobs_model_data.masked_fill(model_padding_mask, 0.0) + + return (model_logprobs_model_data, ref_logprobs_model_data) + + def _compute_losses( + self, + model_logprobs_model_data, + ref_logprobs_model_data, + model_data_scores, + mixture_data_scores, + ): + # Compute log probs + model_logprobs_model_data_sum = model_logprobs_model_data.sum(1) + ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1) + + # probability of the model data vs the mixture data + probability = F.sigmoid(model_data_scores - mixture_data_scores) + + # reinforce score where 0.5 is a control variate + score = (probability - 0.5) * model_logprobs_model_data_sum + + # kl divergence + kl_div = model_logprobs_model_data_sum - ref_logprobs_model_data_sum + + # final loss + loss = self.beta * kl_div - score + + return loss.mean(), score, kl_div + + def _log_statistics( + self, + model_data, + mixture_data, + model_logprobs_model_data, + ref_logprobs_model_data, + model_scores, + mixture_scores, + score, + kl_div, + context_length, + ): + # Helper function to gather and compute mean + def gather_mean(tensor): + return self.accelerator.gather(tensor).mean().item() + + # Log score + self.stats["loss/score"].append(gather_mean(score)) + # Log KL divergence + self.stats["loss/kl_div"].append(gather_mean(kl_div)) + + # Log logprobs + model_logprobs_model_data_sum = model_logprobs_model_data.sum(1) + ref_logprobs_model_data_sum = ref_logprobs_model_data.sum(1) + + self.stats["logps/chosen"].append(gather_mean(model_logprobs_model_data_sum)) + self.stats["logps/rejected"].append(gather_mean(ref_logprobs_model_data_sum)) + + # Log rewards + self.stats["rewards/chosen"].append(gather_mean(model_scores)) + self.stats["rewards/rejected"].append(gather_mean(mixture_scores)) + + # Calculate entropy for model data + entropy_model_data = -model_logprobs_model_data.sum(1) + self.stats["objective/entropy"].append(gather_mean(entropy_model_data)) + + # Calculate margins + margin = model_scores - mixture_scores + self.stats["rewards/margins"].append(gather_mean(margin)) + + # Calculate accuracy + accuracy = (margin > 0).float() + self.stats["rewards/accuracies"].append(gather_mean(accuracy)) + + # Log EOS token statistics + model_eos = (model_data["input_ids"][:, context_length:] == self.tokenizer.eos_token_id).any(dim=1) + mixture_eos = (mixture_data["input_ids"][:, context_length:] == self.tokenizer.eos_token_id).any(dim=1) + self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float())) + self.stats["val/ref_contain_eos_token"].append(gather_mean(mixture_eos.float())) + + # Log beta and mixture coef + self.stats["beta"].append(self.beta) + self.stats["mixture_coef"].append(self.mixture_coef) + + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: + model.train() + + # need the prompt_ only + inputs = self._prepare_inputs(inputs) + context_length = inputs["prompt_input_ids"].shape[1] + prompts = { + "input_ids": inputs["prompt_input_ids"], + "attention_mask": inputs["prompt_attention_mask"], + } + del inputs + + # Sample completions from both the model and the reference model + model_output, mixture_output = self._generate_completions(model, prompts) + + # Process model completions + model_data, mixture_data = self._process_completions(model_output, mixture_output, prompts) + + # Compute rewards + model_data_scores, mixture_data_scores = self._compute_rewards(model_data, 
mixture_data, context_length) + + # Compute logprobs + model_logprobs_model_data, ref_logprobs_model_data = self._compute_logprobs(model, model_data, context_length) + + # Compute loss + loss, score, kl_div = self._compute_losses( + model_logprobs_model_data, ref_logprobs_model_data, model_data_scores, mixture_data_scores + ) + + # Log everything + self._log_statistics( + model_data, + mixture_data, + model_logprobs_model_data.detach(), + ref_logprobs_model_data, + model_data_scores, + mixture_data_scores, + score.detach(), + kl_div.detach(), + context_length, + ) + + if ( + self.args.torch_empty_cache_steps is not None + and self.state.global_step % self.args.torch_empty_cache_steps == 0 + ): + empty_cache() + + kwargs = {} + # For LOMO optimizers you need to explicitly use the learning rate + if self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]: + kwargs["learning_rate"] = self._get_learning_rate() + + if self.args.n_gpu > 1: + loss = loss.mean() # mean() to average on multi-gpu parallel training + + if self.use_apex: + with amp.scale_loss(loss, self.optimizer) as scaled_loss: + scaled_loss.backward() + else: + self.accelerator.backward(loss, **kwargs) + + return loss.detach() / self.args.gradient_accumulation_steps diff --git a/trl/trainer/online_dpo_config.py b/trl/trainer/online_dpo_config.py index d930b852f4..c4554a6a55 100644 --- a/trl/trainer/online_dpo_config.py +++ b/trl/trainer/online_dpo_config.py @@ -13,7 +13,7 @@ # limitations under the License. from dataclasses import dataclass -from typing import Literal, Optional +from typing import List, Literal, Optional, Union from transformers import TrainingArguments @@ -38,10 +38,10 @@ class OnlineDPOConfig(TrainingArguments): Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive value. - beta (`float`, *optional*, defaults to `0.1`): + beta (`float` or `list[float]`, *optional*, defaults to `0.1`): Parameter controlling the deviation from the reference model. Higher β means less deviation from the reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in - the [paper](https://huggingface.co/papers/2310.12036). + the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is selected for each new epoch and the last β is used for the rest of the epochs. loss_type (`str`, *optional*, defaults to `"sigmoid"`): Type of loss to use. Possible values are: @@ -58,7 +58,7 @@ class OnlineDPOConfig(TrainingArguments): max_new_tokens: int = 64 temperature: float = 0.9 missing_eos_penalty: Optional[float] = None - beta: float = 0.1 + beta: Union[float, List[float]] = 0.1 loss_type: Literal["sigmoid", "ipo"] = "sigmoid" dataset_num_proc: Optional[int] = None disable_dropout: bool = True diff --git a/trl/trainer/online_dpo_trainer.py b/trl/trainer/online_dpo_trainer.py index 4e322eea81..e48c7446c6 100644 --- a/trl/trainer/online_dpo_trainer.py +++ b/trl/trainer/online_dpo_trainer.py @@ -94,6 +94,8 @@ class OnlineDPOTrainer(Trainer): The dataset to use for evaluation. tokenizer (`transformers.PreTrainedTokenizerBase`): The tokenizer to use for training. This argument is required if you want to use the default data collator. + peft_config (`Dict`): + The peft config to use for training. compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): The function to use to compute the metrics. 
Must take a `EvalPrediction` and return a dictionary string to metric values. @@ -209,6 +211,7 @@ def __init__( "logps/chosen": [], "logps/rejected": [], "val/contain_eos_token": [], + "beta": [], } self.generation_config = GenerationConfig( @@ -233,6 +236,8 @@ def __init__( preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + self._beta = args.beta + # Placed after the super().__init__ because we need self.is_deepspeed_enabled and self.accelerator if self.is_deepspeed_enabled: if self.reward_model is not None: @@ -246,6 +251,14 @@ def __init__( if self.reward_model is not None: self.reward_model = self.reward_model.to(self.accelerator.device) + @property + def beta(self): + if isinstance(self._beta, list): + epoch = self.state.epoch + return self._beta[epoch] if epoch < len(self._beta) else self._beta[-1] + else: + return self._beta + @staticmethod def tokenize_row(feature, is_encoder_decoder: bool, tokenizer: PreTrainedTokenizerBase) -> Dict[str, Any]: """Tokenize a single row from a DPO specific dataset.""" @@ -421,9 +434,9 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, logits = pi_logratios - ref_logratios if self.args.loss_type == "sigmoid": - losses = -F.logsigmoid(self.args.beta * logits) + losses = -F.logsigmoid(self.beta * logits) elif self.args.loss_type == "ipo": - losses = (logits - 1 / (2 * self.args.beta)) ** 2 + losses = (logits - 1 / (2 * self.beta)) ** 2 else: raise NotImplementedError(f"invalid loss type {self.loss_type}") @@ -437,7 +450,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, kl = logprobs - ref_logprobs mean_kl = kl.sum(1).mean() self.stats["objective/kl"].append(self.accelerator.gather(mean_kl).mean().item()) - non_score_reward = (-self.args.beta * kl).sum(1) + non_score_reward = (-self.beta * kl).sum(1) mean_non_score_reward = non_score_reward.mean() self.stats["objective/non_score_reward"].append(self.accelerator.gather(mean_non_score_reward).mean().item()) rlhf_reward = scores + non_score_reward @@ -446,16 +459,17 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, self.stats["objective/entropy"].append(self.accelerator.gather(mean_entropy).mean().item()) scores_margin = scores[chosen_indices] - scores[rejected_indices] self.stats["objective/scores_margin"].append(self.accelerator.gather(scores_margin.mean()).mean().item()) - chosen_rewards = self.args.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum) + chosen_rewards = self.beta * (chosen_logprobs_sum - chosen_ref_logprobs_sum) gathered_chosen_rewards = self.accelerator.gather(chosen_rewards) self.stats["rewards/chosen"].append(gathered_chosen_rewards.mean().item()) - rejected_rewards = self.args.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum) + rejected_rewards = self.beta * (rejected_logprobs_sum - rejected_ref_logprobs_sum) gathered_rejected_rewards = self.accelerator.gather(rejected_rewards) self.stats["rewards/rejected"].append(gathered_rejected_rewards.mean().item()) margin = gathered_chosen_rewards - gathered_rejected_rewards self.stats["rewards/margins"].append(margin.mean().item()) accuracy = margin > 0 self.stats["rewards/accuracies"].append(accuracy.float().mean().item()) + self.stats["beta"].append(self.beta) if ( self.args.torch_empty_cache_steps is not None diff --git a/trl/trainer/xpo_config.py b/trl/trainer/xpo_config.py index 35c85c796c..86ae06da93 100644 --- a/trl/trainer/xpo_config.py +++ b/trl/trainer/xpo_config.py @@ -13,6 +13,7 @@ # limitations under the 
License. from dataclasses import dataclass +from typing import List, Union from trl.trainer.online_dpo_config import OnlineDPOConfig @@ -25,8 +26,8 @@ class XPOConfig(OnlineDPOConfig): Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following: Parameters: - alpha (`float`, *optional*, defaults to `1e-5`): - Weight of the XPO loss term. + alpha (`float` or `List[float]`, *optional*, defaults to `1e-5`): + Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch and the last alpha is used for the rest of the epochs. """ - alpha: float = 1e-5 + alpha: Union[float, List[float]] = 1e-5 diff --git a/trl/trainer/xpo_trainer.py b/trl/trainer/xpo_trainer.py index 3ac1604566..de2c2f35c0 100644 --- a/trl/trainer/xpo_trainer.py +++ b/trl/trainer/xpo_trainer.py @@ -26,11 +26,7 @@ from ..models.utils import unwrap_model_for_generation from .online_dpo_trainer import OnlineDPOTrainer -from .utils import ( - empty_cache, - get_reward, - truncate_right, -) +from .utils import empty_cache, get_reward, truncate_right from .xpo_config import XPOConfig @@ -63,6 +59,8 @@ class XPOTrainer(OnlineDPOTrainer): The dataset to use for evaluation. tokenizer (`transformers.PreTrainedTokenizerBase`): The tokenizer to use for training. This argument is required if you want to use the default data collator. + peft_config (`Dict`): + The peft config to use for training. compute_metrics (`Callable[[EvalPrediction], Dict]`, *optional*): The function to use to compute the metrics. Must take a `EvalPrediction` and return a dictionary string to metric values. @@ -86,6 +84,7 @@ def __init__( train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, + peft_config: Optional[Dict] = None, compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None, callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), @@ -100,12 +99,15 @@ def __init__( train_dataset=train_dataset, eval_dataset=eval_dataset, tokenizer=tokenizer, + peft_config=peft_config, compute_metrics=compute_metrics, callbacks=callbacks, optimizers=optimizers, preprocess_logits_for_metrics=preprocess_logits_for_metrics, ) + self._alpha = self.args.alpha + # Overwrite the stats dictionary to include XPO specific statistics self.stats = { # Remove "non_score_reward", "rlhf_reward", "scores" @@ -127,9 +129,19 @@ def __init__( # Replace "contain_eos_token" by "model_contain_eos_token" and "ref_contain_eos_token" "val/model_contain_eos_token": [], "val/ref_contain_eos_token": [], + "alpha": [], + "beta": [], } - def _generate_completions(self, model, ref_model, prompts): + @property + def alpha(self): + if isinstance(self._alpha, list): + epoch = self.state.epoch + return self._alpha[epoch] if epoch < len(self._alpha) else self._alpha[-1] + else: + return self._alpha + + def _generate_completions(self, prompts, model): with unwrap_model_for_generation(model, self.accelerator) as unwrapped_model: model_output = unwrapped_model.generate( input_ids=prompts["input_ids"], @@ -137,6 +149,7 @@ def _generate_completions(self, model, ref_model, prompts): generation_config=self.generation_config, ) + ref_model = model if self.ref_model is None else self.ref_model with torch.no_grad(), unwrap_model_for_generation(ref_model, self.accelerator) as unwrapped_ref_model: ref_output = 
unwrapped_ref_model.generate( input_ids=prompts["input_ids"], @@ -172,14 +185,13 @@ def _process_completions(self, model_output, ref_output, prompts): return model_data, ref_data def _compute_rewards(self, model_data, ref_data, context_length): - all_input_ids = torch.cat([model_data["input_ids"], ref_data["input_ids"]], dim=0) - with torch.no_grad(): - _, all_scores, _ = get_reward( - self.reward_model, all_input_ids, self.tokenizer.pad_token_id, context_length + _, model_scores, _ = get_reward( + self.reward_model, model_data["input_ids"], self.tokenizer.pad_token_id, context_length + ) + _, ref_scores, _ = get_reward( + self.reward_model, ref_data["input_ids"], self.tokenizer.pad_token_id, context_length ) - - model_scores, ref_scores = all_scores.chunk(2) # Apply EOS penalty if needed if self.args.missing_eos_penalty is not None: @@ -190,7 +202,7 @@ def _compute_rewards(self, model_data, ref_data, context_length): return model_scores, ref_scores - def _compute_logprobs(self, model, ref_model, model_data, ref_data, context_length): + def _compute_logprobs(self, model, model_data, ref_data, context_length): def compute_logprobs_for_data(m, data): output = m(data["input_ids"], attention_mask=data["attention_mask"]) logits = output.logits[:, context_length - 1 : -1] @@ -205,8 +217,13 @@ def compute_logprobs_for_data(m, data): # Compute logprobs for reference model completions with torch.no_grad(): - ref_logprobs_model_data = compute_logprobs_for_data(ref_model, model_data) - ref_logprobs_ref_data = compute_logprobs_for_data(ref_model, ref_data) + if self.ref_model is None: + with model.disable_adapter(): + ref_logprobs_model_data = compute_logprobs_for_data(model, model_data) + ref_logprobs_ref_data = compute_logprobs_for_data(model, ref_data) + else: + ref_logprobs_model_data = compute_logprobs_for_data(self.ref_model, model_data) + ref_logprobs_ref_data = compute_logprobs_for_data(self.ref_model, ref_data) # Mask padding tokens model_padding_mask = model_data["attention_mask"][:, context_length:] == 0 @@ -244,14 +261,14 @@ def _compute_losses( logits = chosen_log_ratios - rejected_log_ratios if self.args.loss_type == "sigmoid": - dpo_losses = -F.logsigmoid(self.args.beta * logits) + dpo_losses = -F.logsigmoid(self.beta * logits) elif self.args.loss_type == "ipo": - dpo_losses = (logits - 1 / (2 * self.args.beta)) ** 2 + dpo_losses = (logits - 1 / (2 * self.beta)) ** 2 else: raise NotImplementedError(f"invalid loss type {self.args.loss_type}") # Compute XPO specific loss - xpo_losses = self.args.alpha * model_logprobs_ref_data_sum + xpo_losses = self.alpha * model_logprobs_ref_data_sum # Total loss loss = (dpo_losses + xpo_losses).mean() @@ -307,8 +324,8 @@ def gather_mean(tensor): # Log rewards # Compute various statistics - chosen_rewards = chosen_log_ratios * self.args.beta - rejected_rewards = rejected_log_ratios * self.args.beta + chosen_rewards = chosen_log_ratios * self.beta + rejected_rewards = rejected_log_ratios * self.beta self.stats["rewards/chosen"].append(gather_mean(chosen_rewards.mean())) self.stats["rewards/rejected"].append(gather_mean(rejected_rewards.mean())) @@ -338,10 +355,12 @@ def gather_mean(tensor): self.stats["val/model_contain_eos_token"].append(gather_mean(model_eos.float())) self.stats["val/ref_contain_eos_token"].append(gather_mean(ref_eos.float())) + # Log alpha and beta + self.stats["alpha"].append(self.alpha) + self.stats["beta"].append(self.beta) + def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor: 
model.train() - ref_model = self.ref_model - ref_model.eval() # need the prompt_ only inputs = self._prepare_inputs(inputs) @@ -353,7 +372,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, del inputs # Sample completions from both the model and the reference model - model_output, ref_output = self._generate_completions(model, ref_model, prompts) + model_output, ref_output = self._generate_completions(prompts, model) # Process model completions model_data, ref_data = self._process_completions(model_output, ref_output, prompts) @@ -363,7 +382,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, # Compute logprobs model_logprobs_model_data, model_logprobs_ref_data, ref_logprobs_ref_data, ref_logprobs_model_data = ( - self._compute_logprobs(model, ref_model, model_data, ref_data, context_length) + self._compute_logprobs(model, model_data, ref_data, context_length) ) # Compute loss