
Cross-tokenizer distillation fails in GKD and MiniLLM trainers #4562

@sambhavnoobcoder

Description

When using GKD or MiniLLM trainers for on-policy knowledge distillation with student and teacher models that have different tokenizers, the training produces incorrect results. The teacher model receives student-tokenized input without re-tokenization, causing teacher logprobs to be computed in low-probability regions.

Original Source

This issue was brought to light by @HeMuyu0327:
https://x.com/HeMuyu0327/status/1987382662328168905

Problem Details

What Happens

In the current implementation:

  1. Student model generates rollouts using its own tokenizer (e.g., Qwen tokenizer)
  2. These student-tokenized IDs are passed directly to the teacher model
  3. Teacher model (e.g., Llama) interprets these IDs using its own vocabulary mapping
  4. Since the vocabularies differ, the teacher sees completely different tokens than intended
  5. Teacher logprobs are computed on nonsensical sequences → incorrect training signal

Example

Consider this concrete example:

Student (Qwen) tokenizes "Hello" → token ID 123
- In Qwen's vocabulary: ID 123 = "Hello" ✓
- In Llama's vocabulary: ID 123 = "World" ✗

Teacher (Llama) receives ID 123 and computes logprobs for "World" instead of "Hello"
→ Wrong probability distribution!
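The mismatch is easy to verify directly. The snippet below is illustrative and not part of the original report: it decodes the same student-produced token IDs with both tokenizers (the exact strings will differ from the "Hello"/"World" illustration above, but the two decodings will not agree). Access to the gated meta-llama checkpoint is assumed.

from transformers import AutoTokenizer

# Same checkpoints as the reproduction below
student_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
teacher_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

# Tokenize with the student tokenizer, then decode the raw IDs with both tokenizers
ids = student_tok("Hello", add_special_tokens=False)["input_ids"]
print(student_tok.decode(ids))  # the text the student actually generated
print(teacher_tok.decode(ids))  # whatever those IDs happen to mean in Llama's vocabulary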

Reproduction

Minimal Example

from trl.experimental.gkd import GKDConfig, GKDTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

# Load models with DIFFERENT tokenizers
student_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
teacher_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
student_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

dataset = load_dataset("trl-internal-testing/zen", "conversational_language_modeling", split="train")

config = GKDConfig(output_dir="./output", max_steps=10)

trainer = GKDTrainer(
    model=student_model,
    teacher_model=teacher_model,
    args=config,
    train_dataset=dataset,
    processing_class=student_tokenizer,
)

# This will produce incorrect results because:
# - Student generates with Qwen tokenizer
# - Teacher receives Qwen-tokenized IDs
# - Teacher interprets IDs using Llama vocabulary
# - Teacher logprobs computed on wrong tokens!
trainer.train()

Expected Behavior

Training should work correctly with different tokenizers by:

  1. Preserving the text from student-generated rollouts
  2. Re-tokenizing the text using the teacher's tokenizer
  3. Computing teacher logprobs on correctly tokenized sequences
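Below is a minimal sketch of that flow, assuming nothing about the trainers' internals; the helper name and signature are hypothetical, not TRL API. It decodes the student rollout to text, re-encodes it with the teacher's tokenizer, and takes log-probabilities over the teacher's own token IDs. Note that after re-tokenization the student and teacher sequences no longer align one-to-one, so the distillation loss still has to handle cross-tokenizer alignment (e.g. via text-span matching).

import torch
import torch.nn.functional as F

def teacher_logprobs_from_student_rollout(student_ids, student_tok, teacher_tok, teacher_model):
    # 1. Preserve the text of the student-generated rollout
    text = student_tok.decode(student_ids, skip_special_tokens=True)
    # 2. Re-tokenize that text with the teacher's tokenizer
    teacher_ids = teacher_tok(text, return_tensors="pt").input_ids.to(teacher_model.device)
    # 3. Compute teacher logprobs on a correctly tokenized sequence
    with torch.no_grad():
        logits = teacher_model(teacher_ids).logits
    logprobs = F.log_softmax(logits[:, :-1], dim=-1)
    # Log-probability of each teacher token given its prefix
    return logprobs.gather(-1, teacher_ids[:, 1:].unsqueeze(-1)).squeeze(-1)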

Actual Behavior

Depending on the setup, training does one of the following:

  • Produces incorrect results (teacher logprobs computed on wrong tokens)
  • Fails with cryptic errors (e.g. out-of-range token IDs when the student's vocabulary is larger than the teacher's)
  • "Succeeds" but with degraded distillation quality

System Info

Environment

  • TRL version: main branch (commit: 9bc6206)
  • Transformers version: Latest
  • Python version: 3.9+
  • Affected components: GKDTrainer, MiniLLMTrainer

Labels: 🏋 GKD (Related to GKD), 🐛 bug (Something isn't working)