24 changes: 24 additions & 0 deletions docs/source/paper_index.md
@@ -861,6 +861,30 @@ SFTConfig(
)
```

### Entropy-Adaptive Fine-Tuning (EAFT)

**📜 Paper**: https://huggingface.co/papers/2601.02151

EAFT introduces an entropy-based gating mechanism on top of the standard cross-entropy loss:


$$\mathcal{L}_{\text{EAFT}}(\theta) = - \sum_{t=1}^{T} \tilde{H}_t \cdot \log P_\theta(y_t \mid x, y_{<t})$$


where $\tilde{H}_t$ is the normalized entropy of the model's predictive distribution at step $t$. This mechanism addresses catastrophic forgetting in supervised fine-tuning: token-level entropy helps distinguish genuine uncertainty from knowledge conflict, so confidently predicted (low-entropy) tokens contribute less to the gradient and general capabilities are better preserved.

The `eaft_alpha` parameter controls how strongly each token's loss is scaled by its entropy: the normalized entropy is raised to the power `eaft_alpha`, so `eaft_alpha=0` recovers the standard cross-entropy loss. A small illustrative sketch of the gate follows the configuration example below.

```python
from trl import SFTConfig

training_args = SFTConfig(
    loss_type="eaft",
    eaft_alpha=1.0,  # default
    ...
)
```
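
For intuition, the short sketch below contrasts the gate value of a confidently predicted token with that of an uncertain one. It is an illustration only, not the TRL implementation: the `normalized_entropy` helper is invented for this example, and the constant `3.0` normalizer is chosen to mirror the gate described above.

```python
import torch

def normalized_entropy(logits, max_entropy=3.0):
    # Entropy of the softmax distribution, divided by a constant so the
    # resulting weight stays roughly in [0, 1].
    log_p = torch.log_softmax(logits, dim=-1)
    entropy = -(log_p.exp() * log_p).sum()
    return entropy / max_entropy

confident = torch.tensor([8.0, 0.0, 0.0, 0.0])  # peaked distribution, low entropy
uncertain = torch.tensor([1.0, 1.0, 1.0, 1.0])  # flat distribution, high entropy

alpha = 1.0  # eaft_alpha
print(float(normalized_entropy(confident) ** alpha))  # ≈ 0.003: loss almost switched off
print(float(normalized_entropy(uncertain) ** alpha))  # ≈ 0.46: loss largely kept
```

Raising `eaft_alpha` above 1 sharpens this contrast, while `eaft_alpha=0` gives every token a weight of 1 and recovers plain cross-entropy.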

## Parameter-Efficient Fine-Tuning (PEFT)

For general details on using PEFT with TRL, please refer to the [PEFT Integration](peft_integration) guide.
105 changes: 105 additions & 0 deletions tests/test_eaft.py
@@ -0,0 +1,105 @@

import pytest
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM

from trl import SFTConfig, SFTTrainer
from trl.trainer.sft_trainer import eaft_loss_func

from .testing_utils import TrlTestCase


class TestEAFTLoss(TrlTestCase):
    def test_eaft_loss(self):
        batch_size = 2
        seq_len = 3
        vocab_size = 25

        # Create random logits and labels
        logits = torch.randn(batch_size, seq_len, vocab_size)
        labels = torch.randint(0, vocab_size, (batch_size, seq_len))

        # Use a dict for outputs to behave like the model output dict
        outputs = {"logits": logits}

        # Calculate loss
        loss = eaft_loss_func(outputs, labels, alpha=1.0)

        # Simple assertions
        assert torch.is_tensor(loss)
        assert loss.dim() == 0

    def test_eaft_loss_zero_alpha(self):
        batch_size = 2
        seq_len = 3
        vocab_size = 25
        logits = torch.randn(batch_size, seq_len, vocab_size)
        labels = torch.randint(0, vocab_size, (batch_size, seq_len))

        # ensure ignore_index handling matches
        labels[0, 0] = -100

        outputs = {"logits": logits}

        # EAFT with alpha=0
        eaft_loss = eaft_loss_func(outputs, labels, alpha=0.0)

        # Replicate the padding and shifting that `eaft_loss_func` applies internally so that the
        # alpha=0.0 result can be compared against standard cross-entropy on the same shifted labels.
        labels_padded = torch.nn.functional.pad(labels, (0, 1), value=-100)
        shift_labels = labels_padded[..., 1:].contiguous()
        flat_logits = logits.view(-1, vocab_size)
        flat_labels = shift_labels.view(-1)

        # standard CE loss check
        ce_loss = torch.nn.functional.cross_entropy(flat_logits, flat_labels, ignore_index=-100)

        torch.testing.assert_close(eaft_loss, ce_loss)
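
    def test_eaft_loss_upper_bounded_by_ce(self):
        # Sketch of an additional property check (an assumption, not part of the original PR): the
        # adaptive weight (H/3)^alpha stays below 1 because the top-20 entropy approximation caps H
        # at ln(20) ≈ 3, so with alpha=1.0 the EAFT loss should never exceed standard cross-entropy.
        batch_size, seq_len, vocab_size = 2, 3, 25
        logits = torch.randn(batch_size, seq_len, vocab_size)
        labels = torch.randint(0, vocab_size, (batch_size, seq_len))

        eaft_loss = eaft_loss_func({"logits": logits}, labels, alpha=1.0)

        labels_padded = torch.nn.functional.pad(labels, (0, 1), value=-100)
        shift_labels = labels_padded[..., 1:].contiguous()
        ce_loss = torch.nn.functional.cross_entropy(
            logits.view(-1, vocab_size), shift_labels.view(-1), ignore_index=-100
        )
        assert float(eaft_loss) <= float(ce_loss)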


class TestSFTTrainerEAFT(TrlTestCase):
    def setup_method(self):
        self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        self.dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train[:100]")
        self.model = AutoModelForCausalLM.from_pretrained(self.model_id)

    def test_train_eaft_loss(self):
        training_args = SFTConfig(
            output_dir=self.tmp_dir,
            loss_type="eaft",
            eaft_alpha=0.5,
            learning_rate=1e-3,
            report_to="none",
            max_steps=3,
        )
        trainer = SFTTrainer(
            model=self.model,
            args=training_args,
            train_dataset=self.dataset,
        )

        trainer.train()
        # check that loss is recorded
        assert trainer.state.log_history[-1]["train_loss"] is not None
        # check that it ran 3 steps
        assert trainer.state.global_step == 3

    def test_train_eaft_init_error(self):
        # should raise error if compute_loss_func is provided with loss_type="eaft"
        training_args = SFTConfig(
            output_dir=self.tmp_dir,
            loss_type="eaft",
            report_to="none",
        )

        with pytest.raises(ValueError, match="compute_loss_func"):
            SFTTrainer(
                model=self.model,
                args=training_args,
                train_dataset=self.dataset,
                compute_loss_func=lambda x, y: 0.0,
            )
10 changes: 8 additions & 2 deletions trl/trainer/sft_config.py
@@ -265,11 +265,17 @@ class SFTConfig(TrainingArguments):
        default="nll",
        metadata={
            "help": (
                'Type of loss to use. Possible values are `"nll"` (negative log-likelihood, default) and `"dft"` '
                "(Dynamic Fine-Tuning, as described in https://huggingface.co/papers/2508.05629)."
                'Type of loss to use. Possible values are:\n'
                '- `"nll"`: Negative Log-Likelihood (default)\n'
                '- `"dft"`: Dynamic Fine-Tuning, as described in [this paper](https://huggingface.co/papers/2508.05629)\n'
                '- `"eaft"`: Entropy-Adaptive Fine-Tuning, as described in [this paper](https://huggingface.co/papers/2601.02151)'
            )
        },
    )
    eaft_alpha: float = field(
        default=1.0,
        metadata={"help": "Alpha parameter for the EAFT loss; it is the exponent applied to the entropy-based adaptive weight."},
    )
    activation_offloading: bool = field(
        default=False,
        metadata={"help": "Whether to offload the activations to the CPU."},
77 changes: 75 additions & 2 deletions trl/trainer/sft_trainer.py
@@ -19,7 +19,7 @@
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from typing import Any
from typing import Any, Optional

import torch
import torch.nn as nn
@@ -482,6 +482,69 @@ def dft_loss(outputs, labels, num_items_in_batch=None):
    loss = (per_token_loss * loss_mask).sum() / num_items_in_batch
    return loss


def eaft_loss_func(outputs, labels, num_items_in_batch=None, alpha=1.0):
    """
    EAFT loss function, as presented in [Entropy-Adaptive Fine-Tuning](https://huggingface.co/papers/2601.02151),
    from https://github.com/ymxyll/LlamaFactory-EAFT/blob/feature/eaft/src/llamafactory/train/trainer_utils.py#L682
    by ymxyll, 2026, Apache-2.0 license.
    """
    logits = outputs.get("logits")
    if logits is None:
        return outputs.get("loss", torch.tensor(0.0))

    logits = logits.float()
    vocab_size = logits.size(-1)
    # Shift the labels one position to the left so that logits at position t are scored against the
    # next token; the final position is padded with the ignore index.
    labels = nn.functional.pad(labels, (0, 1), value=-100)
    shift_labels = labels[..., 1:].contiguous()
    logits = logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    shift_labels = shift_labels.to(logits.device)

    loss = _eaft_cross_entropy(logits, shift_labels, num_items_in_batch, alpha)
    return loss


def _eaft_cross_entropy(
    source: torch.Tensor,
    target: torch.Tensor,
    num_items_in_batch: Optional[torch.Tensor] = None,
    alpha: float = 1.0,
    ignore_index: int = -100,
) -> torch.Tensor:
    """
    EAFT cross-entropy loss function, as presented in [Entropy-Adaptive Fine-Tuning](https://huggingface.co/papers/2601.02151),
    from https://github.com/ymxyll/LlamaFactory-EAFT/blob/feature/eaft/src/llamafactory/train/trainer_utils.py#L699
    by ymxyll, 2026, Apache-2.0 license.
    """
    per_token_loss = nn.functional.cross_entropy(source, target, ignore_index=ignore_index, reduction="none")
    valid_mask = target != ignore_index
    if not valid_mask.any():
        return torch.tensor(0.0, device=source.device, dtype=source.dtype)

    valid_losses = per_token_loss[valid_mask]

    with torch.no_grad():
        source_detached = source[valid_mask].detach()

        # Approximate the predictive entropy from the top-20 logits only (renormalized), which keeps
        # the computation cheap for large vocabularies.
        topk_val, _ = torch.topk(source_detached, k=20, dim=-1)
        logsumexp_topk = torch.logsumexp(topk_val, dim=-1, keepdim=True)
        log_probs_topk = topk_val - logsumexp_topk
        probs_topk = torch.exp(log_probs_topk)
        entropy_approx = -(probs_topk * log_probs_topk).sum(dim=-1)

        # Normalize by ~ln(20) ≈ 3.0 so the weight lies roughly in [0, 1], then raise it to the power `alpha`.
        entropy_term = entropy_approx / 3.0
        adaptive_weight = torch.pow(entropy_term, alpha)

    weighted_losses = valid_losses * adaptive_weight

    if num_items_in_batch is not None:
        total_loss = weighted_losses.sum()
        if torch.is_tensor(num_items_in_batch):
            num_items_in_batch = num_items_in_batch.to(total_loss.device)
        loss = total_loss / num_items_in_batch
    else:
        loss = weighted_losses.mean()
    return loss


class SFTTrainer(BaseTrainer):
"""
@@ -877,8 +940,18 @@ def __init__(
                "passing a `compute_loss_func` is not allowed."
            )
            compute_loss_func = dft_loss
        elif args.loss_type == "eaft":
            if compute_loss_func is not None:
                raise ValueError(
                    "You passed a `compute_loss_func` together with `loss_type='eaft'` to the `SFTTrainer`. "
                    "When using `loss_type='eaft'`, the loss function is internally set to the EAFT loss, so "
                    "passing a `compute_loss_func` is not allowed."
                )
            compute_loss_func = lambda outputs, labels, num_items_in_batch=None: eaft_loss_func(
                outputs, labels, num_items_in_batch, args.eaft_alpha
            )
        else:
            raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll' and 'dft'.")
            raise ValueError(f"Invalid `loss_type` {args.loss_type} passed. Supported values are 'nll', 'dft' and 'eaft'.")

        super().__init__(
            model=model,