Add default Optim to DPO example (#759)
* add optim

* make configurable
Nathan Lambert authored Sep 19, 2023
1 parent 5d30cd4 commit d603e7c
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions examples/dpo.py
@@ -67,6 +67,13 @@ class ScriptArguments:
        },
    )

    # optimizer settings
    warmup_steps: Optional[int] = field(default=150, metadata={"help": "Number of warmup steps for optimizer"})
    optim: Optional[str] = field(
        default="RMSprop",
        metadata={"help": "Optimizer to use. Default is RMSprop, if none " "passed defaults to Transformers trainer."},
    )


def extract_anthropic_prompt(prompt_and_response):
"""Extract the anthropic prompt from a prompt and response pair."""
@@ -131,6 +138,19 @@ def split_prompt_and_responses(sample) -> Dict[str, str]:
eval_dataset = get_hh("test", sanity_check=script_args.sanity_check)

# 4. initialize training arguments:

warmup_steps = script_args.warmup_steps
if script_args.optim == "RMSprop":  # Trainer to match original paper
    optimizer = torch.optim.RMSprop(model.parameters(), lr=script_args.learning_rate)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / (warmup_steps + 1))
    )
    optim = None
else:
    optimizer = None
    scheduler = None
    optim = script_args.optim

training_args = TrainingArguments(
    per_device_train_batch_size=script_args.per_device_train_batch_size,
    max_steps=script_args.max_steps,
@@ -143,6 +163,8 @@ def split_prompt_and_responses(sample) -> Dict[str, str]:
    eval_steps=500,
    output_dir="./test",
    report_to=script_args.report_to,
    optim=optim,
    warmup_steps=warmup_steps,
)

# 5. initialize the DPO trainer
@@ -157,6 +179,7 @@ def split_prompt_and_responses(sample) -> Dict[str, str]:
    max_length=script_args.max_length,
    max_target_length=script_args.max_target_length,
    max_prompt_length=script_args.max_prompt_length,
    optimizers=(optimizer, scheduler),
)

# 6. train
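For reference, a minimal, self-contained sketch of the pattern the commit wires in: a hand-built RMSprop optimizer plus a linear-warmup LambdaLR when optim is "RMSprop" (the in-code comment says this is to match the original paper), otherwise deferring to the Trainer's own optimizer. The toy model and hyperparameter values below are placeholders, not values from examples/dpo.py.

import torch

model = torch.nn.Linear(4, 2)   # stand-in for the policy model
learning_rate = 1e-6            # placeholder
warmup_steps = 150
optim_name = "RMSprop"          # plays the role of script_args.optim

if optim_name == "RMSprop":
    # Build the optimizer/scheduler by hand and let the Trainer use them as-is.
    optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / (warmup_steps + 1))
    )
else:
    # Leave both unset so the Trainer constructs the optimizer named by optim_name.
    optimizer, scheduler = None, None

# The warmup factor ramps linearly from 1/(warmup_steps + 1) toward 1.0:
for step in range(3):
    optimizer.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())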

1 comment on commit d603e7c

@puyuanOT commented on d603e7c Sep 19, 2023


It looks like this commit will cause an error, since optim becomes None when using RMSprop:

ValueError: None is not a valid OptimizerNames, please select one of ['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'a
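One possible way around this (a sketch, not necessarily the fix adopted in the repository) is to forward optim to TrainingArguments only when a name was actually chosen, so the RMSprop branch never hands None to the OptimizerNames validation:

from transformers import TrainingArguments

optim = None          # what the RMSprop branch sets in the example
warmup_steps = 150    # placeholder

# Only include the `optim` keyword when it is set; otherwise let TrainingArguments
# keep its own default optimizer name.
extra_args = {} if optim is None else {"optim": optim}
training_args = TrainingArguments(
    output_dir="./test",
    warmup_steps=warmup_steps,
    **extra_args,
)
print(training_args.optim)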
