generated from fastai/nbdev_template
-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial skeleton * initial config and class * move TrainerCallback to callbacks.py * initial trainer mockup * formatting * add back header * script with reward model * call ref policy forward with torch no_grad * fix api * clean up the configs * use the new API * fix typo * get get_reward without grads * remove unused no_grad calls * fix formatting * initial GeometricMixtureWrapper * Update trl/models/modeling_base.py Co-authored-by: Alvaro Bartolome <[email protected]> * undo changes to callback * GenerationMixin needs generation_config * calculate score with model and mixture model outputs * fix scores and mixture_scores tensors * undo * use interleaved version to calcuate chosen-rejected * Revert "use interleaved version to calcuate chosen-rejected" This reverts commit 4a63a60. * fix mixture scores * Fix global step * use mixture_coeff * record scores_margin only * fix del * First version of Nash MD trainer * undo * fix formatting * fix toc * initial refactorin * mixin fixes * fix refactoring * cleanup comments * add log_stats * add test * initial docs * fix logs * fix missing_eos_penalty * fix output_dir * add peft_config to docs and super * undo init changes * Update docs/source/_toctree.yml Co-authored-by: Quentin Gallouédec <[email protected]> * Update trl/trainer/nash_md_config.py Co-authored-by: Quentin Gallouédec <[email protected]> * add dataset format * add authors * add dynamic parameter callback * update test * fix comments * test GeometricMixtureWrapper * header * formatting * formatting * add paper and abstract * Update docs/source/nash_md_trainer.md Co-authored-by: Quentin Gallouédec <[email protected]> * DynamicParameterCallback * drop callback in favor of getter * revert kto config change * revert kto config change * fix contribution * `coeff` to `coef` * log dynamic coefs * Update docs/source/nash_md_trainer.md * Update docs/source/nash_md_trainer.md * fix tests * use self.ref_model * one-line --------- Co-authored-by: Alvaro Bartolome <[email protected]> Co-authored-by: Daniil Tiapkin <[email protected]> Co-authored-by: Quentin Gallouédec <[email protected]> Co-authored-by: Quentin Gallouédec <[email protected]>
- Loading branch information
1 parent
cdafc93
commit dc2bd07
Showing
18 changed files
with
1,053 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Nash MD Trainer | ||
|
||
## Overview | ||
Nash-MD was proposed in the paper [Nash Learning from Human Feedback](https://huggingface.co/papers/2312.00886) by Rémi Munos, [Michal Valko](https://huggingface.co/misovalko), Daniele Calandriello, Mohammad Gheshlaghi Azar, Mark Rowland, Daniel Guo, Yunhao Tang, Matthieu Geist, Thomas Mésnard, and Andrea Michi. | ||
|
||
The abstract from the paper is the following: | ||
|
||
> Reinforcement learning from human feedback (RLHF) has emerged as the main paradigm for aligning large language models (LLMs) with human preferences. Typically, RLHF involves the initial step of learning a reward model from human feedback, often expressed as preferences between pairs of text generations produced by a pre-trained LLM. Subsequently, the LLM's policy is fine-tuned by optimizing it to maximize the reward model through a reinforcement learning algorithm. However, an inherent limitation of current reward models is their inability to fully represent the richness of human preferences and their dependency on the sampling distribution. In this study, we introduce an alternative pipeline for the fine-tuning of LLMs using pairwise human feedback. Our approach entails the initial learning of a preference model, which is conditioned on two inputs given a prompt, followed by the pursuit of a policy that consistently generates responses preferred over those generated by any competing policy, thus defining the Nash equilibrium of this preference model. We term this approach Nash learning from human feedback (NLHF). In the context of a tabular policy representation, we present a novel algorithmic solution, Nash-MD, founded on the principles of mirror descent. This algorithm produces a sequence of policies, with the last iteration converging to the regularized Nash equilibrium. Additionally, we explore parametric representations of policies and introduce gradient descent algorithms for deep-learning architectures. To demonstrate the effectiveness of our approach, we present experimental results involving the fine-tuning of a LLM for a text summarization task. We believe NLHF offers a compelling avenue for preference learning and policy optimization with the potential of advancing the field of aligning LLMs with human preferences. | ||
|
||
|
||
This post-training method was contributed by [Kashif Rasul](https://huggingface.co/kashif) and [Daniil Tiapkin](https://huggingface.co/dtiapkin), [Pierre Ménard](https://huggingface.co/menardprr), Daniele Calandriello and [Quentin Gallouédec](https://huggingface.co/qgallouedec). | ||
|
||
## Get started | ||
|
||
To just run the Nash MD script to make sure this trainer can run, you can run the following command to train a Nash MD model with a dummy reward model. | ||
|
||
```bash | ||
python examples/scripts/nash_md.py \ | ||
--model_name_or_path EleutherAI/pythia-14m \ | ||
--reward_model_path EleutherAI/pythia-14m \ | ||
--dataset_name trl-lib/tldr \ | ||
--learning_rate 5.0e-7 \ | ||
--output_dir pythia-14m-tldr-nash-md \ | ||
--per_device_train_batch_size 4 \ | ||
--gradient_accumulation_steps 32 \ | ||
--num_train_epochs 3 \ | ||
--max_new_tokens 64 \ | ||
--warmup_ratio 0.1 \ | ||
--missing_eos_penalty 1.0 | ||
``` | ||
|
||
## Explanation of the logged metrics | ||
|
||
The logged metrics are as follows: | ||
|
||
* `loss/score`: The mean reinforce score loss. | ||
* `loss/kl_div`: The mean kl divergence loss. | ||
* `objective/entropy`: The mean entropy of the model and reference data. | ||
* `rewards/accuracies`: The accuracies of the Nash MD's implicit reward model. | ||
* `rewards/chosen`: The mean scores (according to the reward model) of the model completions. | ||
* `rewards/rejected`: The mean scores (according to the reward model) of the mixture completions. | ||
* `rewards/margins`: The mean reward margin (according to reward model) between the chosen and mixture completions. | ||
* `logps/chosen`: The mean log probabilities of the chosen completions. | ||
* `logps/rejected`: The mean log probabilities of the reference completions. | ||
* `val/model_contain_eos_token`: The amount of times the model's output contains the eos token. | ||
* `val/ref_contain_eos_token`: The amount of times the mixture's output contains the eos token. | ||
|
||
## NashMDTrainer | ||
|
||
[[autodoc]] NashMDTrainer | ||
|
||
## NashMDConfig | ||
|
||
[[autodoc]] NashMDConfig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
# flake8: noqa | ||
|
||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
""" | ||
Usage: | ||
python examples/scripts/nash_md.py \ | ||
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ | ||
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ | ||
--dataset_name trl-lib/tldr \ | ||
--learning_rate 5.0e-7 \ | ||
--output_dir pythia-1b-tldr-nash-md \ | ||
--per_device_train_batch_size 4 \ | ||
--gradient_accumulation_steps 32 \ | ||
--num_train_epochs 3 \ | ||
--max_new_tokens 64 \ | ||
--warmup_ratio 0.1 \ | ||
--missing_eos_penalty 1.0 \ | ||
--push_to_hub | ||
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \ | ||
examples/scripts/nash_md.py \ | ||
--model_name_or_path trl-lib/pythia-1b-deduped-tldr-sft \ | ||
--reward_model_path trl-lib/pythia-1b-deduped-tldr-rm \ | ||
--dataset_name trl-lib/tldr \ | ||
--learning_rate 5.0e-7 \ | ||
--output_dir pythia-1b-tldr-nash-md \ | ||
--per_device_train_batch_size 4 \ | ||
--gradient_accumulation_steps 32 \ | ||
--num_train_epochs 3 \ | ||
--max_new_tokens 64 \ | ||
--warmup_ratio 0.1 \ | ||
--missing_eos_penalty 1.0 \ | ||
--push_to_hub | ||
""" | ||
|
||
import torch | ||
from datasets import load_dataset | ||
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, GenerationConfig | ||
from accelerate import PartialState | ||
from trl import ( | ||
DPOScriptArguments, | ||
ModelConfig, | ||
NashMDConfig, | ||
NashMDTrainer, | ||
get_kbit_device_map, | ||
get_quantization_config, | ||
maybe_apply_chat_template, | ||
LogCompletionsCallback, | ||
) | ||
from trl.commands.cli_utils import TrlParser | ||
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = TrlParser((DPOScriptArguments, NashMDConfig, ModelConfig)) | ||
args, training_args, model_config = parser.parse_args_and_config() | ||
args.gradient_checkpointing_kwargs = {"use_reentrant": True} | ||
|
||
torch_dtype = ( | ||
model_config.torch_dtype | ||
if model_config.torch_dtype in ["auto", None] | ||
else getattr(torch, model_config.torch_dtype) | ||
) | ||
quantization_config = get_quantization_config(model_config) | ||
model_kwargs = dict( | ||
revision=model_config.model_revision, | ||
attn_implementation=model_config.attn_implementation, | ||
torch_dtype=torch_dtype, | ||
use_cache=False if training_args.gradient_checkpointing else True, | ||
device_map=get_kbit_device_map() if quantization_config is not None else None, | ||
quantization_config=quantization_config, | ||
) | ||
|
||
model = AutoModelForCausalLM.from_pretrained( | ||
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs | ||
) | ||
ref_model = AutoModelForCausalLM.from_pretrained( | ||
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, **model_kwargs | ||
) | ||
reward_model = AutoModelForSequenceClassification.from_pretrained( | ||
training_args.reward_model_path, num_labels=1, trust_remote_code=model_config.trust_remote_code | ||
) | ||
tokenizer = AutoTokenizer.from_pretrained( | ||
model_config.model_name_or_path, | ||
padding_side="left", | ||
trust_remote_code=model_config.trust_remote_code, | ||
) | ||
if tokenizer.pad_token is None: | ||
tokenizer.pad_token = tokenizer.eos_token | ||
if tokenizer.chat_template is None: | ||
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE | ||
|
||
dataset = load_dataset(args.dataset_name) | ||
|
||
with PartialState().local_main_process_first(): | ||
dataset = dataset.map( | ||
maybe_apply_chat_template, num_proc=training_args.dataset_num_proc, fn_kwargs={"tokenizer": tokenizer} | ||
) | ||
|
||
trainer = NashMDTrainer( | ||
model=model, | ||
ref_model=ref_model, | ||
reward_model=reward_model, | ||
args=training_args, | ||
train_dataset=dataset[args.dataset_train_split], | ||
eval_dataset=dataset[args.dataset_test_split], | ||
tokenizer=tokenizer, | ||
) | ||
generation_config = GenerationConfig( | ||
max_new_tokens=training_args.max_new_tokens, do_sample=True, temperature=training_args.temperature | ||
) | ||
completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8) | ||
trainer.add_callback(completions_callback) | ||
# train the model | ||
trainer.train() | ||
|
||
# save the model | ||
trainer.save_model(training_args.output_dir) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright 2024 The HuggingFace Team. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import unittest | ||
|
||
import torch | ||
from transformers import AutoModelForCausalLM, GenerationConfig | ||
|
||
from trl.models.modeling_base import GeometricMixtureWrapper, create_reference_model | ||
|
||
|
||
class TestGeometricMixtureWrapper(unittest.TestCase): | ||
def setUp(self): | ||
self.model = AutoModelForCausalLM.from_pretrained("gpt2") | ||
self.ref_model = create_reference_model(self.model) | ||
self.generation_config = GenerationConfig.from_pretrained("gpt2") | ||
self.mixture_coef = 0.5 | ||
self.wrapper = GeometricMixtureWrapper( | ||
self.model, self.ref_model, self.generation_config, mixture_coef=self.mixture_coef | ||
) | ||
|
||
def test_forward(self): | ||
input_ids = torch.tensor([[1, 2, 3, 4, 5]]) | ||
attention_mask = torch.ones_like(input_ids) | ||
|
||
output = self.wrapper(input_ids=input_ids, attention_mask=attention_mask) | ||
|
||
self.assertIsNotNone(output) | ||
self.assertTrue(hasattr(output, "logits")) | ||
self.assertEqual(output.logits.shape, (1, 5, self.model.config.vocab_size)) | ||
|
||
def test_mixture_coefficient(self): | ||
input_ids = torch.tensor([[1, 2, 3, 4, 5]]) | ||
attention_mask = torch.ones_like(input_ids) | ||
|
||
with torch.no_grad(): | ||
model_output = self.model(input_ids=input_ids, attention_mask=attention_mask) | ||
ref_model_output = self.ref_model(input_ids=input_ids, attention_mask=attention_mask) | ||
wrapper_output = self.wrapper(input_ids=input_ids, attention_mask=attention_mask) | ||
|
||
expected_logits = torch.nn.functional.log_softmax( | ||
self.mixture_coef * ref_model_output.logits + (1 - self.mixture_coef) * model_output.logits, dim=-1 | ||
) | ||
|
||
self.assertTrue(torch.allclose(wrapper_output.logits, expected_logits, atol=1e-5)) | ||
|
||
def test_prepare_inputs_for_generation(self): | ||
input_ids = torch.tensor([[1, 2, 3, 4, 5]]) | ||
attention_mask = torch.ones_like(input_ids) | ||
|
||
inputs = self.wrapper.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask, use_cache=True) | ||
|
||
self.assertIn("input_ids", inputs) | ||
self.assertIn("attention_mask", inputs) | ||
self.assertFalse(inputs.get("use_cache", False)) |
Oops, something went wrong.