Skip to content

Commit

Permalink
Standardizing datasets for testing (#2065)
Browse files: browse the repository at this point in the history
* zen dataset

* Update dataset test bco

* some tests

* Simple chat template

* bco

* xpo

* kto

* gkd

* trainer_args

* sft

* online dpo

* orpo

* zen script
  • Loading branch information
qgallouedec authored Sep 14, 2024
1 parent f6c6643 commit 40f0522
Show file tree
Hide file tree
Showing 23 changed files with 912 additions and 741 deletions.
4 changes: 3 additions & 1 deletion examples/datasets/tokenize_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from datasets import load_dataset
from transformers import AutoTokenizer, HfArgumentParser

from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


"""
python -i examples/datasets/tokenize_ds.py --model HuggingFaceH4/zephyr-7b-beta
Expand All @@ -41,7 +43,7 @@ class ScriptArguments:
ds = load_dataset(args.dataset)
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\n\n'}}{% endfor %}{{ eos_token }}"
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
Expand Down
583 changes: 583 additions & 0 deletions examples/datasets/zen.py

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion examples/scripts/cpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import CPOConfig, CPOTrainer, ModelConfig, get_peft_config
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


@dataclass
Expand Down Expand Up @@ -90,7 +91,7 @@ class ScriptArguments:
################
ds = load_dataset(args.dataset)
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

def process(row):
row["chosen"] = tokenizer.apply_chat_template(row["chosen"], tokenize=False)
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
"""

from trl.commands.cli_utils import DPOScriptArguments, TrlParser
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
Expand Down Expand Up @@ -102,7 +102,7 @@
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
if args.ignore_bias_buffers:
# torch distributed hack
model._ddp_params_and_buffers_to_ignore = [
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/dpo_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
)

from trl.commands.cli_utils import TrlParser
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

if __name__ == "__main__":
parser = TrlParser((DPOScriptArguments, OnlineDPOConfig, ModelConfig))
Expand Down Expand Up @@ -97,7 +97,7 @@
**model_kwargs,
)
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token

Expand Down
3 changes: 2 additions & 1 deletion examples/scripts/orpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser

from trl import ModelConfig, ORPOConfig, ORPOTrainer, get_peft_config
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


@dataclass
Expand Down Expand Up @@ -90,7 +91,7 @@ class ScriptArguments:
################
ds = load_dataset(args.dataset)
if tokenizer.chat_template is None:
tokenizer.chat_template = "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

def process(row):
row["prompt"] = tokenizer.apply_chat_template(row["chosen"][:-1], tokenize=False)
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)

from trl import ModelConfig, PPOv2Config, PPOv2Trainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


"""
Expand Down Expand Up @@ -71,7 +71,7 @@
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
value_model = AutoModelForSequenceClassification.from_pretrained(
config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
)
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/ppo/ppo_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
)

from trl import ModelConfig, PPOv2Config, PPOv2Trainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


"""
Expand Down Expand Up @@ -73,7 +73,7 @@
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
value_model = AutoModelForSequenceClassification.from_pretrained(
config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
)
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/rloo/rloo.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from trl import ModelConfig
from trl.trainer.rloo_trainer import RLOOConfig, RLOOTrainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


"""
Expand Down Expand Up @@ -75,7 +75,7 @@
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
reward_model = AutoModelForSequenceClassification.from_pretrained(
config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
)
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/rloo/rloo_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from trl import ModelConfig
from trl.trainer.rloo_trainer import RLOOConfig, RLOOTrainer
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


"""
Expand Down Expand Up @@ -76,7 +76,7 @@
)
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
reward_model = AutoModelForSequenceClassification.from_pretrained(
config.reward_model_path, trust_remote_code=model_config.trust_remote_code, num_labels=1
)
Expand Down
4 changes: 2 additions & 2 deletions examples/scripts/xpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
LogCompletionsCallback,
)
from trl.commands.cli_utils import TrlParser
from trl.trainer.utils import SIMPLE_QUERY_CHAT_TEMPLATE
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


if __name__ == "__main__":
Expand Down Expand Up @@ -85,7 +85,7 @@
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_QUERY_CHAT_TEMPLATE
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

dataset = load_dataset(args.dataset_name)

Expand Down
Loading

0 comments on commit 40f0522

Please sign in to comment.