135 changes: 135 additions & 0 deletions TEST_HADW_README.md
@@ -0,0 +1,135 @@
# Testing GRPO with HA-DW

This directory contains a test script to verify the HA-DW implementation on your local machine.

## Requirements

Make sure you have the necessary dependencies installed. Since HA-DW is introduced on this branch, TRL itself should be installed from source (e.g. `pip install -e .` from the repository root):

```bash
pip install torch transformers datasets accelerate
```
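
A quick way to confirm the environment before running the test (optional; these are plain import checks using standard APIs):

```python
# Optional sanity check: confirm the required packages import and report
# whether the MPS backend is visible to PyTorch.
import torch
import transformers
import datasets
import accelerate
import trl

print("torch:", torch.__version__)
print("trl:", trl.__version__)
print("MPS available:", torch.backends.mps.is_available())
```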

## Running the Test

### Test with HA-DW enabled (default)

```bash
python test_hadw_grpo.py
```

This will:
- Load a small Qwen 2.5 0.5B model
- Create a synthetic math dataset (32 simple addition problems)
- Train with GRPO + HA-DW for 1 epoch
- Display HA-DW metrics during training

### Test baseline GRPO (for comparison)

```bash
python test_hadw_grpo.py --no-hadw
```

This runs the same test but with HA-DW disabled, allowing you to compare the training dynamics.

## What to Expect

The script will:

1. ✅ Check MPS availability
2. ✅ Load the model (Qwen2.5-0.5B-Instruct)
3. ✅ Create a small synthetic dataset
4. ✅ Initialize the GRPO trainer with HA-DW
5. ✅ Run training for 1 epoch
6. ✅ Display HA-DW metrics:
- `hadw/capability_prior`: Model's evolving capability estimate
- `hadw/capability_posterior`: Updated capability after each batch
- `hadw/batch_accuracy`: Accuracy on current batch
- `hadw/eta_t`: Adaptive forgetting factor
- `hadw/reweighting_mean`: Average reweighting factor
- `hadw/reweighting_std`: Std dev of reweighting factors
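
To make these metrics concrete, here is a minimal sketch of the bookkeeping they track, pieced together from the `GRPOConfig` docstrings (η_t = η · σ_t over a window of recent batch accuracies, plus an exponential reweighting scaled by λ_scale). The exact update and reweighting forms below are assumptions for illustration, not the trainer's actual code:

```python
import math
from collections import deque

def hadw_step(prior, batch_accuracy, history, eta=0.1):
    """One illustrative HA-DW capability update (hypothetical sketch).

    history is a deque of recent batch accuracies (maxlen = hadw_history_window);
    eta_t = eta * sigma_t, where sigma_t is the std dev over that window.
    """
    history.append(batch_accuracy)
    mean = sum(history) / len(history)
    sigma_t = math.sqrt(sum((x - mean) ** 2 for x in history) / len(history))
    eta_t = eta * sigma_t  # adaptive forgetting factor (hadw/eta_t)
    # Blend the prior capability estimate with the current batch accuracy.
    posterior = (1 - eta_t) * prior + eta_t * batch_accuracy
    return posterior, eta_t

def hadw_reweight(advantage, group_accuracy, capability, lambda_scale=1.0):
    # Exponential adjustment of a group-relative advantage; the sign convention
    # here is an assumption -- prompts the model finds harder than its current
    # capability get up-weighted.
    return advantage * math.exp(lambda_scale * (capability - group_accuracy))

history = deque(maxlen=5)
posterior, eta_t = hadw_step(prior=0.25, batch_accuracy=0.5, history=history)
print(posterior, eta_t)  # with a single sample sigma_t is 0, so posterior == prior
```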

## Expected Output

You should see output like:

```
================================================================================
Testing GRPO with HA-DW on MPS
================================================================================
✓ MPS is available

📦 Loading model: Qwen/Qwen2.5-0.5B-Instruct
Device: mps
   Using dtype: torch.float32

📊 Creating synthetic dataset...
Dataset size: 32 samples
Example prompt: What is 0 + 3? Answer with just the number.
Example answer: 3

⚙️ Configuring GRPO with HA-DW...
✓ HA-DW enabled: True
✓ Eta: 0.1
✓ Lambda scale: 1.0
✓ History window: 5
  ✓ Num generations: 2

🚀 Initializing GRPO Trainer...
✓ Trainer initialized successfully

🏋️ Starting training...
[Training logs...]

📈 HA-DW Metrics:
hadw/capability_prior:
- First: 0.2500
- Last: 0.3125
- Mean: 0.2812
hadw/capability_posterior:
- First: 0.2625
- Last: 0.3250
- Mean: 0.2937
[...]

================================================================================
✅ Test completed successfully!
================================================================================
```

## Troubleshooting

### MPS not available
If you see "MPS not available", the script will fall back to CPU. This is normal on non-Apple Silicon machines.

### Out of memory
If you run out of memory, try:
- Reducing `per_device_train_batch_size` in the script (currently 2)
- Reducing `max_completion_length` (currently 32)
- Using an even smaller dataset

Note that `num_generations` is already at 2, the smallest group size for which a group-relative baseline is meaningful.
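
For example, a lower-memory configuration might look like this (illustrative values; remember that the effective batch size must stay divisible by `num_generations`):

```python
# Reduced-memory settings for the test script (a sketch -- tune to your machine).
from trl import GRPOConfig
from test_hadw_grpo import create_synthetic_math_dataset

config = GRPOConfig(
    output_dir="./test_hadw_output",
    per_device_train_batch_size=2,  # already the minimum that fits num_generations=2
    num_generations=2,
    max_completion_length=16,       # shorter completions cut activation memory
    gradient_accumulation_steps=1,
    use_hadw=True,
)
dataset = create_synthetic_math_dataset(num_samples=16)  # smaller dataset
```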

### Model download issues
The first run will download the model (roughly 1 GB). Ensure you have:
- Internet connection
- Sufficient disk space
- HuggingFace access (no token needed for this public model)
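
If the download is flaky, you can pre-fetch the weights into the local cache with `huggingface_hub` (optional):

```python
# Pre-download the model so the test script starts from a warm cache.
from huggingface_hub import snapshot_download

snapshot_download("Qwen/Qwen2.5-0.5B-Instruct")
```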

## Comparing with and without HA-DW

To see the effect of HA-DW, run both versions and compare the metrics:

```bash
# With HA-DW
python test_hadw_grpo.py > with_hadw.log 2>&1

# Without HA-DW
python test_hadw_grpo.py --no-hadw > without_hadw.log 2>&1

# Compare
diff with_hadw.log without_hadw.log
```
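
A raw `diff` of the two logs will be noisy (loss values, timing), so a small filter that keeps only the HA-DW lines is handy. This hypothetical helper assumes the metric names appear verbatim in the trainer's log output:

```python
# filter_hadw.py -- print only the log lines that mention HA-DW metrics.
import sys

for path in sys.argv[1:]:
    with open(path) as f:
        for line in f:
            if "hadw/" in line:
                print(f"{path}: {line.rstrip()}")
```

Run it as `python filter_hadw.py with_hadw.log without_hadw.log`; the baseline log should contain no `hadw/` lines at all.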

You should observe that HA-DW:
- Adjusts advantages based on prompt difficulty
- Tracks model capability evolution across batches
- Applies adaptive reweighting to correct bias
193 changes: 193 additions & 0 deletions test_hadw_grpo.py
@@ -0,0 +1,193 @@
#!/usr/bin/env python3
"""
Test script for GRPO with HA-DW on MPS.

This script trains a small model on a synthetic math dataset to verify
that the HA-DW implementation works correctly.

Note: BitsAndBytes quantization is not well-supported on MPS, so this
script uses the full precision model instead.
"""

import torch
from datasets import Dataset
from transformers import AutoTokenizer
from trl import GRPOConfig, GRPOTrainer


def create_synthetic_math_dataset(num_samples=50):
"""Create a small synthetic math dataset for testing."""
prompts = []

# Simple addition problems
for i in range(num_samples):
a = i % 10
b = (i + 3) % 10
prompts.append({
"prompt": f"What is {a} + {b}? Answer with just the number.",
"answer": str(a + b)
})

return Dataset.from_list(prompts)


def accuracy_reward(prompts, completions, answer, **kwargs):
"""
Simple reward function that checks if the completion contains the correct answer.

Returns:
List of rewards (1.0 for correct, 0.0 for incorrect)
"""
rewards = []
for completion, correct_answer in zip(completions, answer):
        # Whitespace-tokenize the completion and check whether the correct
        # answer appears as a standalone token (so e.g. "13" does not match "3").
        completion_clean = completion.strip()
        if correct_answer in completion_clean.split():
rewards.append(1.0)
else:
rewards.append(0.0)
return rewards


def main(use_hadw=True):
print("=" * 80)
if use_hadw:
print("Testing GRPO with HA-DW on MPS")
else:
print("Testing GRPO (baseline) on MPS")
print("=" * 80)

# Check MPS availability
if not torch.backends.mps.is_available():
print("⚠️ MPS not available! Falling back to CPU.")
device = "cpu"
else:
print("✓ MPS is available")
device = "mps"

# Model configuration
# Using a small model that works well on Apple Silicon
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
print(f"\n📦 Loading model: {model_name}")
print(f" Device: {device}")

# Model initialization kwargs for MPS
# Note: Using fp32 for better numerical stability on MPS
# The paper used bf16/fp32, but MPS doesn't support bf16 well
model_init_kwargs = {
"torch_dtype": torch.float32,
"device_map": None, # Let the trainer handle device placement
}
print(f" Using dtype: {model_init_kwargs['torch_dtype']}")

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

# Create dataset
print("\n📊 Creating synthetic dataset...")
dataset = create_synthetic_math_dataset(num_samples=32) # Small dataset for quick testing
print(f" Dataset size: {len(dataset)} samples")
print(f" Example prompt: {dataset[0]['prompt']}")
print(f" Example answer: {dataset[0]['answer']}")

# Configure GRPO with HA-DW
print(f"\n⚙️ Configuring GRPO{'with HA-DW' if use_hadw else ' (baseline)'}...")
config = GRPOConfig(
output_dir="./test_hadw_output",
# Model initialization
model_init_kwargs=model_init_kwargs,
# HA-DW parameters
use_hadw=use_hadw,
hadw_eta=0.1,
hadw_lambda_scale=1.0,
hadw_history_window=5, # Smaller window for small dataset
# GRPO parameters
num_generations=2, # 2 generations per prompt (must divide batch size)
max_completion_length=32,
temperature=0.7,
# Training parameters
num_train_epochs=1,
per_device_train_batch_size=2,
gradient_accumulation_steps=1,
logging_steps=1,
save_steps=100,
learning_rate=1e-6,
# Optimization - Using fp32 for numerical stability
# (MPS doesn't handle bf16 well, and fp16 can cause overflow in HA-DW)
fp16=False,
bf16=False,
remove_unused_columns=False,
# Disable features not needed for testing
report_to=[],
save_strategy="no",
)

print(f" ✓ HA-DW enabled: {config.use_hadw}")
print(f" ✓ Eta: {config.hadw_eta}")
print(f" ✓ Lambda scale: {config.hadw_lambda_scale}")
print(f" ✓ History window: {config.hadw_history_window}")
print(f" ✓ Num generations: {config.num_generations}")

# Initialize trainer
print("\n🚀 Initializing GRPO Trainer...")
try:
trainer = GRPOTrainer(
model=model_name,
reward_funcs=accuracy_reward,
args=config,
train_dataset=dataset,
processing_class=tokenizer,
)
print(" ✓ Trainer initialized successfully")
except Exception as e:
print(f" ✗ Failed to initialize trainer: {e}")
raise

# Run training
print("\n🏋️ Starting training...")
print("-" * 80)
try:
trainer.train()
print("-" * 80)
print(" ✓ Training completed successfully!")
except Exception as e:
print(f" ✗ Training failed: {e}")
raise

    # Print HA-DW metrics (only meaningful when HA-DW is enabled)
    if use_hadw:
        print("\n📈 HA-DW Metrics:")
        metrics = getattr(trainer, '_metrics', {}).get('train', {})
        hadw_keys = [k for k in metrics if k.startswith('hadw/')]
        if hadw_keys:
            for key in hadw_keys:
                values = metrics[key]
                if values:
                    print(f"   {key}:")
                    print(f"     - First: {values[0]:.4f}")
                    print(f"     - Last: {values[-1]:.4f}")
                    if len(values) > 1:
                        print(f"     - Mean: {sum(values)/len(values):.4f}")
        else:
            print("   No HA-DW metrics found (this is unexpected)")

print("\n" + "=" * 80)
print("✅ Test completed successfully!")
print("=" * 80)


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(description="Test GRPO with HA-DW")
parser.add_argument(
"--no-hadw",
action="store_true",
help="Disable HA-DW (run baseline GRPO for comparison)"
)
args = parser.parse_args()

main(use_hadw=not args.no_hadw)
33 changes: 33 additions & 0 deletions trl/trainer/grpo_config.py
@@ -793,6 +793,39 @@ class GRPOConfig(TrainingArguments):
"This is described in the [DeepSeek-V3.2 paper](https://huggingface.co/papers/2512.02556)."
},
)
use_hadw: bool = field(
default=False,
metadata={
"help": "Whether to use History-Aware Adaptive Difficulty Weighting (HA-DW) to mitigate biased advantage "
"estimation in group-relative RL. This method was introduced in the paper [Your Group-Relative Advantage "
"Is Biased](https://huggingface.co/papers/2601.08521). When enabled, the trainer dynamically adjusts "
"advantage weights based on prompt difficulty and the model's evolving capability across batches."
},
)
hadw_eta: float = field(
default=0.1,
metadata={
"help": "Base forgetting factor (η) for HA-DW's evolving difficulty anchor. Controls the influence of "
"historical information when updating the model's capability belief. The adaptive forgetting factor "
"is computed as η_t = η * σ_t, where σ_t measures training stability. Only used when `use_hadw=True`."
},
)
hadw_lambda_scale: float = field(
default=1.0,
metadata={
"help": "Scaling factor (λ_scale) for HA-DW's reweighting function. Controls the magnitude of the "
"exponential adjustment applied to advantages. Higher values lead to stronger reweighting. "
"Only used when `use_hadw=True`."
},
)
hadw_history_window: int = field(
default=10,
metadata={
"help": "Number of recent batches (m) used to compute the standard deviation for HA-DW's adaptive "
"forgetting factor. Larger values provide more stable estimates but are less responsive to rapid "
"capability changes. Only used when `use_hadw=True`."
},
)

# Parameters that control the logging
log_completions: bool = field(