Description
When using GKD or MiniLLM trainers for on-policy knowledge distillation with student and teacher models that have different tokenizers, the training produces incorrect results. The teacher model receives student-tokenized input without re-tokenization, causing teacher logprobs to be computed in low-probability regions.
Original Source
This issue was brought to light by @HeMuyu0327:
https://x.com/HeMuyu0327/status/1987382662328168905
Problem Details
What Happens
In the current implementation:
- Student model generates rollouts using its own tokenizer (e.g., Qwen tokenizer)
- These student-tokenized IDs are passed directly to the teacher model
- Teacher model (e.g., Llama) interprets these IDs using its own vocabulary mapping
- Since the vocabularies differ, the teacher sees completely different tokens than intended
- Teacher logprobs are computed on nonsensical sequences → incorrect training signal (see the sketch after this list)
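In rough pseudocode, the problematic path looks like the sketch below. This is a simplification for illustration (the function name is hypothetical, not the actual trainer internals): the student-produced IDs go straight into the teacher's forward pass with no decode/re-encode step in between.
import torch
def compute_teacher_logprobs(student_ids, teacher_model):
    # student_ids were produced by the *student's* tokenizer, but they are fed
    # directly to the teacher; the teacher looks up whatever its own vocabulary
    # maps those IDs to, so the resulting logprobs describe different tokens.
    with torch.no_grad():
        teacher_logits = teacher_model(input_ids=student_ids).logits
    return torch.log_softmax(teacher_logits, dim=-1)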
Example
Consider this concrete example:
Student (Qwen) tokenizes "Hello" → token ID 123
- In Qwen's vocabulary: ID 123 = "Hello" ✓
- In Llama's vocabulary: ID 123 = "World" ✗
Teacher (Llama) receives ID 123 and computes logprobs for "World" instead of "Hello"
→ Wrong probability distribution!
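The specific IDs above are illustrative; the mismatch can be verified with a quick check like the following (a hypothetical snippet, not part of the reproduction below; the decoded strings will differ, though not necessarily in the exact way shown above):
from transformers import AutoTokenizer
student_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
teacher_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
ids = student_tok("Hello", add_special_tokens=False)["input_ids"]
print("IDs from student tokenizer:", ids)
print("Student decodes them as:", student_tok.decode(ids))  # "Hello"
print("Teacher decodes them as:", teacher_tok.decode(ids))  # different text, not "Hello"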
Reproduction
Minimal Example
from trl.experimental.gkd import GKDConfig, GKDTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
# Load models with DIFFERENT tokenizers
student_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
student_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")
config = GKDConfig(output_dir="./output", max_steps=10)
trainer = GKDTrainer(
model=student_model,
teacher_model=teacher_model,
args=config,
train_dataset=dataset,
processing_class=student_tokenizer,
)
# This will produce incorrect results because:
# - Student generates with Qwen tokenizer
# - Teacher receives Qwen-tokenized IDs
# - Teacher interprets IDs using Llama vocabulary
# - Teacher logprobs computed on wrong tokens!
trainer.train()
Expected Behavior
Training should work correctly with different tokenizers by:
- Preserving the text from student-generated rollouts
- Re-tokenizing the text using the teacher's tokenizer
- Computing teacher logprobs on correctly tokenized sequences (a sketch follows below)
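A minimal sketch of that decode/re-tokenize step, assuming direct access to both tokenizers (this is not the current TRL implementation; the helper name is hypothetical):
import torch
def teacher_logprobs_on_retokenized_text(student_ids, student_tok, teacher_tok, teacher_model):
    # 1. Preserve the text of the student-generated rollout.
    texts = student_tok.batch_decode(student_ids, skip_special_tokens=True)
    # 2. Re-tokenize that text with the teacher's own tokenizer.
    if teacher_tok.pad_token is None:
        teacher_tok.pad_token = teacher_tok.eos_token
    teacher_inputs = teacher_tok(texts, return_tensors="pt", padding=True).to(teacher_model.device)
    # 3. Compute teacher logprobs on correctly tokenized sequences.
    with torch.no_grad():
        logits = teacher_model(**teacher_inputs).logits
    return torch.log_softmax(logits, dim=-1), teacher_inputs["input_ids"]
Note that after re-tokenization the teacher's sequences no longer align token-for-token with the student's, so the distillation loss also needs a cross-tokenizer alignment step; the sketch only covers the re-tokenization itself.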
Actual Behavior
Depending on the setup, training:
- Produces incorrect results (teacher logprobs computed on wrong tokens)
- Fails with cryptic errors (e.g., index errors when a student token ID exceeds the teacher's vocabulary size)
- "Succeeds" but with degraded distillation quality
System Info
Environment
- TRL version: main branch (commit: 9bc6206)
- Transformers version: latest
- Python version: 3.9+
- Affected components: GKDTrainer, MiniLLMTrainer