Update SFT examples (#2244)
lewtun authored Oct 17, 2024
1 parent 494b4af commit a67f214
Showing 3 changed files with 38 additions and 29 deletions.
8 changes: 6 additions & 2 deletions README.md
@@ -66,13 +66,17 @@ You can use the TRL Command Line Interface (CLI) to quickly get started with Sup
**SFT:**

```diff
-trl sft --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name trl-lib/Capybara --output_dir Qwen2.5-0.5B-SFT
+trl sft --model_name_or_path Qwen/Qwen2.5-0.5B \
+    --dataset_name trl-lib/Capybara \
+    --output_dir Qwen2.5-0.5B-SFT
```

**DPO:**

```diff
-trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct --dataset_name argilla/Capybara-Preferences --output_dir Qwen2.5-0.5B-DPO
+trl dpo --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
+    --dataset_name argilla/Capybara-Preferences \
+    --output_dir Qwen2.5-0.5B-DPO
```

**Chat:**
56 changes: 29 additions & 27 deletions examples/scripts/sft.py
@@ -12,38 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
-# regular:
+# Full training
 python examples/scripts/sft.py \
-    --dataset_name trl-lib/ultrafeedback_binarized \
-    --model_name_or_path="facebook/opt-350m" \
-    --report_to="wandb" \
-    --learning_rate=1.41e-5 \
-    --per_device_train_batch_size=64 \
-    --gradient_accumulation_steps=16 \
-    --output_dir="sft_openassistant-guanaco" \
-    --logging_steps=1 \
-    --num_train_epochs=3 \
-    --max_steps=-1 \
-    --push_to_hub \
-    --gradient_checkpointing
+    --model_name_or_path Qwen/Qwen2-0.5B \
+    --dataset_name trl-lib/Capybara \
+    --learning_rate 2.0e-5 \
+    --num_train_epochs 1 \
+    --packing \
+    --per_device_train_batch_size 2 \
+    --gradient_accumulation_steps 8 \
+    --gradient_checkpointing \
+    --logging_steps 25 \
+    --eval_strategy steps \
+    --eval_steps 100 \
+    --output_dir Qwen2-0.5B-SFT \
+    --push_to_hub
 
-# peft:
+# LoRA
 python examples/scripts/sft.py \
-    --dataset_name trl-lib/ultrafeedback_binarized \
-    --model_name_or_path="facebook/opt-350m" \
-    --report_to="wandb" \
-    --learning_rate=1.41e-5 \
-    --per_device_train_batch_size=64 \
-    --gradient_accumulation_steps=16 \
-    --output_dir="sft_openassistant-guanaco" \
-    --logging_steps=1 \
-    --num_train_epochs=3 \
-    --max_steps=-1 \
-    --push_to_hub \
+    --model_name_or_path Qwen/Qwen2-0.5B \
+    --dataset_name trl-lib/Capybara \
+    --learning_rate 2.0e-4 \
+    --num_train_epochs 1 \
+    --packing \
+    --per_device_train_batch_size 2 \
+    --gradient_accumulation_steps 8 \
+    --gradient_checkpointing \
+    --logging_steps 25 \
+    --eval_strategy steps \
+    --eval_steps 100 \
     --use_peft \
-    --lora_r=64 \
-    --lora_alpha=16
+    --lora_r 32 \
+    --lora_alpha 16 \
+    --output_dir Qwen2-0.5B-SFT \
+    --push_to_hub
"""

from datasets import load_dataset
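Besides swapping the model and dataset, the updated docstring switches the flag spelling from `--flag=value` to `--flag value`. That change is purely cosmetic: TRL parses arguments with Hugging Face's `HfArgumentParser`, a subclass of `argparse.ArgumentParser`, and argparse treats both spellings identically. A stdlib-only sketch (toy flags, not the script's real parser):

```python
import argparse

# Toy parser standing in for the example script's argument parsing.
parser = argparse.ArgumentParser()
parser.add_argument("--lora_r", type=int)
parser.add_argument("--lora_alpha", type=int)

# "--flag value" and "--flag=value" parse to the same namespace.
a = parser.parse_args(["--lora_r", "32", "--lora_alpha", "16"])
b = parser.parse_args(["--lora_r=32", "--lora_alpha=16"])
assert a == b and a.lora_r == 32
```

So the diff's `--lora_r=64` → `--lora_r 32` changes only the value, not how the flag is consumed.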
3 changes: 3 additions & 0 deletions trl/trainer/sft_config.py
@@ -32,6 +32,8 @@ class SFTConfig(TrainingArguments):
             [`ConstantLengthDataset`] based on `dataset_text_field`.
         packing (`bool`, *optional*, defaults to `False`):
             Controls whether the [`ConstantLengthDataset`] packs the sequences of the dataset.
+        learning_rate (`float`, *optional*, defaults to `2e-5`):
+            Initial learning rate for the [`AdamW`] optimizer. The default value replaces that of
+            [`~transformers.TrainingArguments`].
         max_seq_length (`Optional[int]`, *optional*, defaults to `None`):
             Maximum sequence length for the [`ConstantLengthDataset`] and for automatically creating the dataset. If
             `None`, it uses the smaller value between `tokenizer.model_max_length` and `1024`.
@@ -58,6 +60,7 @@ class SFTConfig(TrainingArguments):

     dataset_text_field: str = "text"
     packing: bool = False
+    learning_rate: float = 2.0e-5
     max_seq_length: Optional[int] = None
     dataset_num_proc: Optional[int] = None
     dataset_batch_size: int = 1000
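The `sft_config.py` change overrides an inherited default by redeclaring the field in the subclass — with Python dataclasses, a field redeclared in a child class replaces the parent's default. A minimal stdlib sketch of the pattern, using toy stand-ins for the real `transformers`/`trl` classes (`5e-5` is the actual `TrainingArguments` default being replaced):

```python
from dataclasses import dataclass


@dataclass
class TrainingArguments:
    # Toy stand-in for transformers.TrainingArguments; 5e-5 is its real default.
    learning_rate: float = 5e-5


@dataclass
class SFTConfig(TrainingArguments):
    # Redeclaring the field replaces the inherited default, as this commit does.
    learning_rate: float = 2.0e-5


assert TrainingArguments().learning_rate == 5e-5
assert SFTConfig().learning_rate == 2.0e-5
```

Users who pass `--learning_rate` explicitly are unaffected; only the fallback default changes for SFT runs.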
