Skip to content

Commit

Permalink
ds, raw_dataset etc -> dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
qgallouedec committed Sep 19, 2024
1 parent 0d2bee5 commit ca57ef2
Show file tree
Hide file tree
Showing 17 changed files with 72 additions and 95 deletions.
4 changes: 2 additions & 2 deletions docs/source/detoxifying_a_lm.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,13 @@ And its `continuation` value:

We want to increase the chance for the model to generate toxic prompts so we get more learning signal. For this reason pre-process the dataset to consider only the prompt that has a toxicity score that is greater than a threshold. We can do this in a few lines of code:
```python
ds = load_dataset("allenai/real-toxicity-prompts", split="train")
train_dataset = load_dataset("allenai/real-toxicity-prompts", split="train")

def filter_fn(sample):
toxicity = sample["prompt"]["toxicity"]
return toxicity is not None and toxicity > 0.3

ds = ds.filter(filter_fn, batched=False)
train_dataset = dataset.filter(filter_fn, batched=False)
```

### Reward function
Expand Down
8 changes: 4 additions & 4 deletions examples/datasets/tokenize_ds.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

@dataclass
class ScriptArguments:
dataset: str = field(
dataset_name: str = field(
default="trl-internal-testing/hh-rlhf-helpful-base-trl-style", metadata={"help": "The dataset to load"}
)
model: str = field(default="gpt2", metadata={"help": "The model to use for tokenization"})
Expand All @@ -40,7 +40,7 @@ class ScriptArguments:

if __name__ == "__main__":
args = HfArgumentParser(ScriptArguments).parse_args_into_dataclasses()[0]
ds = load_dataset(args.dataset)
dataset = load_dataset(args.dataset_name)
tokenizer = AutoTokenizer.from_pretrained(args.model)
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
Expand All @@ -50,5 +50,5 @@ def process(row):
row["rejected"] = tokenizer.apply_chat_template(row["rejected"], tokenize=False)
return row

ds = ds.map(process, num_proc=args.dataset_num_proc)
print(ds["train"][0]["chosen"])
dataset = dataset.map(process, num_proc=args.dataset_num_proc)
print(dataset["train"][0]["chosen"])
24 changes: 12 additions & 12 deletions examples/scripts/bco.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,38 +114,38 @@ def get_model_response(example, llm_name: str):

dataset = load_dataset("openbmb/UltraFeedback")["train"]

ds = dataset.filter(lambda example: llm_name in example["models"], batched=False, num_proc=num_proc)
ds = ds.filter(
dataset = dataset.filter(lambda example: llm_name in example["models"], batched=False, num_proc=num_proc)
dataset = dataset.filter(
lambda example: len(example["models"]) == len(example["completions"]), batched=False, num_proc=num_proc
)

METRIC = "helpfulness"

ds = ds.map(
dataset = dataset.map(
get_model_rating,
batched=False,
fn_kwargs={"metric": METRIC, "llm_name": llm_name},
num_proc=num_proc,
)

ds = ds.map(
dataset = dataset.map(
get_model_response,
batched=False,
fn_kwargs={"llm_name": llm_name},
num_proc=num_proc,
)

ds = ds.select_columns(["source", "instruction", "response", "helpfulness"])
dataset = dataset.select_columns(["source", "instruction", "response", "helpfulness"])

ds = ds.rename_columns({"instruction": "prompt", "response": "completion"})
ds = ds.map(lambda example: {"label": example["helpfulness"] >= 5}, batched=False, num_proc=num_proc)
dataset = dataset.rename_columns({"instruction": "prompt", "response": "completion"})
dataset = dataset.map(lambda example: {"label": example["helpfulness"] >= 5}, batched=False, num_proc=num_proc)

ds = ds.map(
dataset = dataset.map(
lambda example: {"prompt": [{"role": "user", "content": example["prompt"]}]},
batched=False,
num_proc=num_proc,
)
dataset = ds.train_test_split(test_size=0.05, seed=42)
dataset = dataset.train_test_split(test_size=0.05, seed=42)

return dataset

Expand Down Expand Up @@ -209,7 +209,7 @@ def format_dataset(example):
with PartialState().local_main_process_first():
# Load the dataset
dataset = build_helpfulness_dataset(script_args.llm_name, num_proc=bco_args.dataset_num_proc)
formatted_dataset = dataset.map(format_dataset, batched=False, num_proc=bco_args.dataset_num_proc)
dataset = dataset.map(format_dataset, batched=False, num_proc=bco_args.dataset_num_proc)

accelerator = Accelerator()
embedding_model = AutoModel.from_pretrained(
Expand All @@ -233,8 +233,8 @@ def format_dataset(example):
model,
ref_model,
args=bco_args,
train_dataset=formatted_dataset["train"],
eval_dataset=formatted_dataset["test"],
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
tokenizer=tokenizer,
peft_config=get_peft_config(model_args),
embedding_func=embedding_func,
Expand Down
13 changes: 5 additions & 8 deletions examples/scripts/cpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

@dataclass
class ScriptArguments:
dataset: str = field(
dataset_name: str = field(
default="trl-internal-testing/hh-rlhf-helpful-base-trl-style",
metadata={"help": "The name of the dataset to use."},
)
Expand All @@ -89,7 +89,7 @@ class ScriptArguments:
################
# Dataset
################
ds = load_dataset(args.dataset)
dataset = load_dataset(args.dataset_name)
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

Expand All @@ -101,19 +101,16 @@ def process(row):
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=cpo_args.dataset_num_proc)

train_dataset = ds["train"]
eval_dataset = ds["test"]
dataset = dataset.map(process, num_proc=cpo_args.dataset_num_proc)

################
# Training
################
trainer = CPOTrainer(
model,
args=cpo_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
)
Expand Down
13 changes: 5 additions & 8 deletions examples/scripts/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,26 +112,23 @@
################
# Dataset
################
ds = load_dataset(args.dataset_name)
dataset = load_dataset(args.dataset_name)

with PartialState().local_main_process_first():
ds = ds.map(maybe_extract_prompt, num_proc=training_args.dataset_num_proc)
ds = ds.map(
dataset = dataset.map(maybe_extract_prompt, num_proc=training_args.dataset_num_proc)
dataset = dataset.map(
maybe_apply_chat_template, num_proc=training_args.dataset_num_proc, fn_kwargs={"tokenizer": tokenizer}
)

train_dataset = ds[args.dataset_train_split]
eval_dataset = ds[args.dataset_test_split]

##########
# Training
################
trainer = DPOTrainer(
model,
ref_model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
train_dataset=dataset[args.dataset_train_split],
eval_dataset=dataset[args.dataset_test_split],
tokenizer=tokenizer,
peft_config=peft_config,
)
Expand Down
11 changes: 4 additions & 7 deletions examples/scripts/dpo_visual.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@
################
# Dataset
################
ds = load_dataset(args.dataset_name)
dataset = load_dataset(args.dataset_name)

def process(row):
row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)
Expand All @@ -116,10 +116,7 @@ def process(row):
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_proc=training_args.dataset_num_proc)

train_dataset = ds[args.dataset_train_split]
eval_dataset = ds[args.dataset_test_split]
dataset = dataset.map(process, num_proc=training_args.dataset_num_proc)

################
# Training
Expand All @@ -128,8 +125,8 @@ def process(row):
model,
ref_model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
train_dataset=dataset[args.dataset_train_split],
eval_dataset=dataset[args.dataset_test_split],
tokenizer=processor,
peft_config=peft_config,
)
Expand Down
8 changes: 4 additions & 4 deletions examples/scripts/evals/judge_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,13 @@ class ScriptArguments:
args = parser.parse_args_into_dataclasses()[0]

# Load the dataset
raw_dataset = load_dataset("trl-lib/tldr", split="validation")
dataset = load_dataset("trl-lib/tldr", split="validation")
if args.num_examples is not None:
raw_dataset = raw_dataset.select(range(args.num_examples))
dataset = dataset.select(range(args.num_examples))

# Extract the prompts and reference completions
prompts = raw_dataset["prompt"]
reference_completions = raw_dataset["completion"]
prompts = dataset["prompt"]
reference_completions = dataset["completion"]

# Generate the model completions
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=200) # very generous max token length
Expand Down
11 changes: 4 additions & 7 deletions examples/scripts/gkd.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,28 +103,25 @@
################
# Dataset
################
raw_datasets = load_dataset(args.dataset_name)
dataset = load_dataset(args.dataset_name)

with PartialState().local_main_process_first():
raw_datasets = raw_datasets.map(
dataset = dataset.map(
lambda x: {
"prompt": tokenizer.apply_chat_template(x["prompt"], tokenize=False, add_generation_prompt=True)
},
num_proc=training_args.dataset_num_proc,
)

train_dataset = raw_datasets[args.dataset_train_split]
eval_dataset = raw_datasets[args.dataset_test_split]

################
# Training
################
trainer = GKDTrainer(
model=model_config.model_name_or_path,
teacher_model=training_args.teacher_model_name_or_path,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
train_dataset=dataset[args.dataset_train_split],
eval_dataset=dataset[args.dataset_test_split],
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
)
Expand Down
6 changes: 3 additions & 3 deletions examples/scripts/kto.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,15 @@ def format_dataset(example):
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
formatted_dataset = dataset.map(format_dataset, num_proc=kto_args.dataset_num_proc)
dataset = dataset.map(format_dataset, num_proc=kto_args.dataset_num_proc)

# Initialize the KTO trainer
kto_trainer = KTOTrainer(
model,
ref_model,
args=kto_args,
train_dataset=formatted_dataset["train"],
eval_dataset=formatted_dataset["test"],
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
tokenizer=tokenizer,
peft_config=get_peft_config(model_args),
)
Expand Down
13 changes: 5 additions & 8 deletions examples/scripts/orpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

@dataclass
class ScriptArguments:
dataset: str = field(
dataset_name: str = field(
default="trl-internal-testing/hh-rlhf-helpful-base-trl-style",
metadata={"help": "The name of the dataset to use."},
)
Expand All @@ -89,7 +89,7 @@ class ScriptArguments:
################
# Dataset
################
ds = load_dataset(args.dataset)
dataset = load_dataset(args.dataset_name)
if tokenizer.chat_template is None:
tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE

Expand All @@ -102,19 +102,16 @@ def process(row):
# Compute that only on the main process for faster data processing.
# see: https://github.com/huggingface/trl/pull/1255
with PartialState().local_main_process_first():
ds = ds.map(process, num_prc=orpo_args.dataset_num_proc)

train_dataset = ds["train"]
eval_dataset = ds["test"]
dataset = dataset.map(process, num_prc=orpo_args.dataset_num_proc)

################
# Training
################
trainer = ORPOTrainer(
model,
args=orpo_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
train_dataset=dataset["train"],
eval_dataset=dataset["test"],
tokenizer=tokenizer,
peft_config=get_peft_config(model_config),
)
Expand Down
6 changes: 3 additions & 3 deletions examples/scripts/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,10 @@
################
# Dataset
################
raw_datasets = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
dataset = load_dataset("trl-internal-testing/descriptiveness-sentiment-trl-style", split="descriptiveness")
eval_samples = 20
train_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples))
eval_dataset = raw_datasets.select(range(len(raw_datasets) - eval_samples, len(raw_datasets)))
train_dataset = dataset.select(range(len(dataset) - eval_samples))
eval_dataset = dataset.select(range(len(dataset) - eval_samples, len(dataset)))
dataset_text_field = "prompt"

def prepare_dataset(dataset, tokenizer):
Expand Down
6 changes: 3 additions & 3 deletions examples/scripts/ppo/ppo_tldr.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@
################
# Dataset
################
raw_datasets = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style")
train_dataset = raw_datasets["train"]
eval_dataset = raw_datasets["validation"]
dataset = load_dataset("trl-internal-testing/tldr-preference-sft-trl-style")
train_dataset = dataset["train"]
eval_dataset = dataset["validation"]

def prepare_dataset(dataset, tokenizer):
"""pre-tokenize the dataset before training; only collate during training"""
Expand Down
Loading

0 comments on commit ca57ef2

Please sign in to comment.