
<iframe src="https://trl-lib-trackio.hf.space/?project=trl-documentation&metrics=train*&runs=sft_qwen3-0.6B_capybara" style="width: 100%; min-width: 300px; max-width: 800px;" height="830" frameBorder="0"></iframe>

## Using SFTTrainer with Accelerate

`SFTTrainer` is built on top of the Transformers `Trainer` and relies on
Accelerate for device placement and distributed training.

When training with `SFTTrainer`, load the model with a plain
`from_pretrained` call and launch the training script with
`accelerate launch`. Accelerate then handles device placement,
multi-GPU execution, and the distributed strategy at runtime.

The `device_map="auto"` option is intended for inference-time model sharding
and should not be used during training, as it can conflict with
Accelerate-managed execution and lead to runtime errors.

### Minimal example

```bash
accelerate launch train_sft.py
```
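
On a multi-GPU machine, the same script can be launched across several
processes. As a sketch (assuming two local GPUs; adjust
`--num_processes` to match your hardware):

```bash
# one process per GPU; Accelerate shards the data across them
accelerate launch --num_processes 2 train_sft.py
```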
> ⚠️ **Important**
>
> Do not pass `device_map="auto"` to `from_pretrained()` when using
> `SFTTrainer`, as this option is intended for inference and can conflict
> with Accelerate-managed training.

```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer

model_name = "facebook/opt-125m"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Make sure a padding token is defined; falling back to EOS is a common choice.
tokenizer.pad_token = tokenizer.eos_token

# Load the model normally (no device_map); Accelerate handles placement at runtime.
model = AutoModelForCausalLM.from_pretrained(model_name)

# A tiny toy dataset in the language-modeling format ({"text": ...}).
dataset = Dataset.from_dict(
    {
        "text": [
            "Explain supervised fine-tuning.",
            "What is tensor parallelism?",
        ]
    }
)

training_args = TrainingArguments(
    output_dir="./sft_output",
    per_device_train_batch_size=1,
    num_train_epochs=1,
    logging_steps=1,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,  # named `tokenizer` in older TRL versions
)

trainer.train()
```
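
Saving the script as `train_sft.py` and running
`accelerate launch train_sft.py` starts training with Accelerate
managing device placement. If Accelerate has not been configured on
the machine yet, running `accelerate config` once lets you describe
your hardware and preferred distributed setup.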

## Expected dataset type and format

SFT supports both [language modeling](dataset_formats#language-modeling) and [prompt-completion](dataset_formats#prompt-completion) datasets. The [`SFTTrainer`] is compatible with both [standard](dataset_formats#standard) and [conversational](dataset_formats#conversational) dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
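
For illustration, here is a minimal sketch of what these formats look
like (field names follow the TRL dataset-format conventions linked
above):

```python
# Language modeling (standard): raw text the model learns to continue.
language_modeling = {"text": "The sky is blue."}

# Prompt-completion (standard): the prompt can be masked so that loss
# is computed on the completion only.
prompt_completion = {"prompt": "The sky is", "completion": " blue."}

# Conversational: a list of chat messages; the trainer applies the
# model's chat template automatically.
conversational = {
    "messages": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ]
}
```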