
@anonx3247 commented Jan 20, 2026

Summary

This PR implements History-Aware Adaptive Difficulty Weighting (HA-DW) for the GRPO trainer, based on the paper "Your Group-Relative Advantage Is Biased".

Problem

The paper identifies a fundamental issue in group-based RL: the group-relative advantage estimator is inherently biased:

  • Underestimates advantages for hard prompts (p_t < 0.5)
  • Overestimates advantages for easy prompts (p_t > 0.5)
  • Only unbiased when p_t = 0.5

This systematic bias causes the policy to under-learn from hard questions while over-exploiting easy ones, ultimately hurting both training stability and generalization.

Solution

HA-DW addresses this bias through two key components:

1. Evolving Difficulty Anchor

Tracks the model's solving capability across batches using a Kalman-style update:

C_t^+ = (1 - η_t) * C_t^- + η_t * y_t

where y_t is the observed batch accuracy and η_t = η * σ_t adapts the base forgetting factor η to training stability, with σ_t computed over a recent history window (see hadw_history_window below).
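
In code, the anchor update is a small recursive filter. The sketch below is illustrative rather than the PR's actual implementation; the class and attribute names are hypothetical, and it assumes σ_t is the standard deviation of recent batch accuracies over the history window:

import numpy as np
from collections import deque

class DifficultyAnchor:
    """Kalman-style tracker for the model's solving capability C_t (sketch)."""

    def __init__(self, eta=0.1, history_window=10, init_capability=0.5):
        self.eta = eta                      # base forgetting factor (hadw_eta)
        self.capability = init_capability   # prior belief C_t^-
        self.history = deque(maxlen=history_window)

    def update(self, batch_accuracy):
        # Assumption: sigma_t is the std of recent batch accuracies, so the
        # effective forgetting factor eta_t grows when training is unstable.
        self.history.append(batch_accuracy)
        sigma_t = float(np.std(self.history)) if len(self.history) > 1 else 1.0
        eta_t = self.eta * sigma_t
        # Posterior update: C_t^+ = (1 - eta_t) * C_t^- + eta_t * y_t
        self.capability = (1 - eta_t) * self.capability + eta_t * batch_accuracy
        return self.capability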

2. Adaptive Reweighting

Computes reweighting factors that correct biased advantage estimates:

Φ_{t,i} = λ_scale * exp(D_{t,i} * M_t)

where:

  • D_{t,i} = -sgn(Â_{t,i}) * sgn(diff^his_t) determines the direction of adjustment
  • M_t = |p̂_t - C_t^-| quantifies how far the batch's empirical solve rate p̂_t deviates from the capability anchor
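
A minimal sketch of the reweighting step, under the reading that diff^his_t is the signed deviation p̂_t - C_t^- (so that M_t = |diff^his_t|); the function name and this interpretation are ours, not taken from the PR:

import torch

def hadw_reweight(advantages, batch_accuracy, capability_prior, lambda_scale=1.0):
    """Scale group-relative advantages by Phi_{t,i} = lambda_scale * exp(D_{t,i} * M_t)."""
    # Assumed: diff^his_t = p_hat_t - C_t^- (signed difficulty deviation).
    diff_his = batch_accuracy - capability_prior
    m_t = abs(diff_his)                           # M_t, deviation magnitude
    sign_diff = (diff_his > 0) - (diff_his < 0)   # scalar sgn(diff^his_t)
    # D_{t,i} = -sgn(A_hat) * sgn(diff^his_t): boosts advantages the group
    # statistic underestimates and shrinks the ones it overestimates.
    d = -torch.sign(advantages) * sign_diff
    phi = lambda_scale * torch.exp(d * m_t)
    return advantages * phi

When p̂_t coincides with the anchor, Φ collapses to λ_scale for every sample, so the correction is uniform and the relative ordering of advantages is untouched.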

Changes

GRPOConfig

Added four new hyperparameters (all opt-in, default disabled):

  • use_hadw (bool, default=False): Enable/disable HA-DW
  • hadw_eta (float, default=0.1): Base forgetting factor for capability updates
  • hadw_lambda_scale (float, default=1.0): Scaling factor for reweighting
  • hadw_history_window (int, default=10): Window for computing training stability

GRPOTrainer

  • Added state tracking for capability belief and history buffer
  • Implemented _compute_hadw_reweighting() method to compute reweighting factors
  • Integrated HA-DW into _generate_and_score_completions() advantage computation
  • Added comprehensive logging for HA-DW metrics
  • Added fp16-specific numerical stability safeguards
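
Roughly, the integration point looks like this; a sketch of the call flow reusing the DifficultyAnchor and hadw_reweight sketches above, not the actual TRL code:

# inside _generate_and_score_completions, after advantages are computed (sketch)
if self.args.use_hadw:
    # Assumption: rewards are binary correctness signals, so the batch
    # accuracy doubles as the observation y_t.
    batch_accuracy = (rewards > 0).float().mean().item()
    capability_prior = self._hadw_anchor.capability  # C_t^- before the update
    advantages = hadw_reweight(
        advantages, batch_accuracy, capability_prior,
        lambda_scale=self.args.hadw_lambda_scale,
    )
    self._hadw_anchor.update(batch_accuracy)  # roll the belief forward to C_t^+
    # ...log hadw/batch_accuracy, hadw/capability_*, hadw/reweighting_* here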

Results from Paper

The paper demonstrates consistent improvements when HA-DW is integrated with GRPO and its variants:

| Model | Algorithm | MATH500 | AIME25 | AMC23 | Minerva | OlympiadBench | Avg |
|---|---|---|---|---|---|---|---|
| Qwen-3-4B | GRPO | 75.4 | 19.6 | 60.3 | 33.8 | 43.5 | 46.5 |
| Qwen-3-4B | GRPO+HA-DW | 78.0 | 20.4 | 63.4 | 36.8 | 44.7 | 48.7 |
| Qwen-3-4B | GSPO | 75.8 | 20.0 | 62.2 | 35.3 | 42.3 | 47.1 |
| Qwen-3-4B | GSPO+HA-DW | 77.6 | 19.6 | 68.6 | 37.1 | 43.2 | 49.2 |
| Qwen-3-4B | DAPO | 76.8 | 18.3 | 60.0 | 35.7 | 43.2 | 46.8 |
| Qwen-3-4B | DAPO+HA-DW | 78.6 | 21.3 | 65.0 | 37.5 | 45.3 | 49.5 |

Similar improvements are observed on Qwen-3-8B and LLaMA-3.2-3B models.

Testing

Local Testing on Apple Silicon (MPS)

We successfully tested the implementation on Apple Silicon with the following setup:

  • Model: Qwen2.5-0.5B-Instruct
  • Device: MPS (Apple M-series GPU)
  • Precision: fp32 (matches paper's experimental setup)
  • Dataset: 32 synthetic math problems (simple addition)
  • Batch size: 2 samples
  • Generations: 2 per prompt

Results:
✅ HA-DW successfully integrated and functional
✅ Adaptive reweighting activated when the batch had mixed results (50% accuracy)
✅ Model learned successfully (0% → 100% accuracy)
✅ No numerical instability or crashes
✅ All HA-DW metrics logged correctly:

  • hadw/capability_prior and hadw/capability_posterior tracked model evolution
  • hadw/eta_t showed adaptive forgetting factor
  • hadw/reweighting_mean = 1.13, hadw/reweighting_std = 0.74 when activated

Example HA-DW activation (from training logs):

# Step 6 - Mixed batch (50% correct, 50% incorrect)
'hadw/batch_accuracy': 0.5
'hadw/reweighting_mean': 1.1276259422302246  # Active reweighting
'hadw/reweighting_std': 0.7369400262832642   # Differentiation between samples

Numerical Stability

  • Added fp16-specific clamping to prevent exponential overflow
  • Test script uses fp32 by default (matches paper's setup)
  • Successfully handles edge cases (all correct/incorrect batches)
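
The fp16 safeguard boils down to bounding the exponent before calling exp, since fp16 overflows for exponents above roughly 11 (largest representable value ≈ 65504). A sketch; the bound used here is illustrative, not necessarily the value in the patch:

import torch

def safe_exp(exponent, max_exp=10.0):
    # Clamp only in half precision, where exp(x) overflows for x ≳ 11.09.
    if exponent.dtype == torch.float16:
        exponent = exponent.clamp(-max_exp, max_exp)
    return torch.exp(exponent)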

Test Script Included

  • test_hadw_grpo.py: Standalone test script for quick validation
  • TEST_HADW_README.md: Comprehensive testing documentation
  • Can run with/without HA-DW for comparison (--no-hadw flag)

Backward Compatibility

This implementation is fully backward compatible:

  • HA-DW is disabled by default (use_hadw=False)
  • Existing code and training scripts work without modification
  • No performance impact when HA-DW is disabled
  • Existing tests pass without modification

Usage Example

from trl import GRPOTrainer, GRPOConfig

config = GRPOConfig(
    use_hadw=True,           # Enable HA-DW
    hadw_eta=0.1,            # Base forgetting factor
    hadw_lambda_scale=1.0,   # Reweighting scale
    hadw_history_window=10,  # History window
    # ... other GRPO parameters
)

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",
    reward_funcs=accuracy_reward,
    args=config,
    train_dataset=dataset,
)
trainer.train()

References

  • "Your Group-Relative Advantage Is Biased": https://huggingface.co/papers/2601.08521

anonx3247 and others added 5 commits January 19, 2026 12:05
This commit implements HA-DW as described in the paper "Your Group-Relative
Advantage Is Biased" (https://huggingface.co/papers/2601.08521).

The paper identifies a fundamental issue in group-based RL: the group-relative
advantage estimator is inherently biased, systematically underestimating
advantages for hard prompts (p_t < 0.5) and overestimating them for easy
prompts (p_t > 0.5).

HA-DW addresses this by:
1. Tracking an evolving difficulty anchor (C_t) that captures the model's
   solving capability across batches using a Kalman-style update
2. Computing adaptive reweighting factors (Φ_{t,i}) that adjust advantage
   estimates based on prompt difficulty relative to the model's current
   capability
3. Applying these factors to correct the systematic bias in advantage
   estimation

Key changes:
- Added HA-DW hyperparameters to GRPOConfig:
  - use_hadw: Enable/disable HA-DW (default: False)
  - hadw_eta: Base forgetting factor for capability belief updates (default: 0.1)
  - hadw_lambda_scale: Scaling factor for reweighting (default: 1.0)
  - hadw_history_window: Window size for computing training stability (default: 10)

- Added HA-DW state tracking to GRPOTrainer:
  - _hadw_capability_prior: Prior belief of model's solving capability
  - _hadw_history_buffer: History of recent capability beliefs

- Implemented _compute_hadw_reweighting method that:
  - Computes batch accuracy and updates capability belief
  - Calculates reweighting factors using exponential adjustment
  - Logs HA-DW metrics for monitoring

- Integrated HA-DW into advantage computation in _generate_and_score_completions

The implementation is opt-in (disabled by default) to maintain backward
compatibility. When enabled, HA-DW consistently improves performance across
GRPO and its variants (GSPO, DAPO) on mathematical reasoning benchmarks.

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

- Add fp16-specific clamping in HA-DW exponential to prevent overflow
- Update test script to use fp32 instead of fp16 for better stability on MPS
- fp32 matches paper's experimental setup (which used fp32/bf16)

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>

@qgallouedec (Member) commented:

Thanks for the PR.

The method adds a good amount of complexity, including a buffer which I'm not sure would work in a distributed setup.
Considering that the method is pretty new, and that the claimed gains are not massive, I'd prefer to have it in experimental, and only consider moving it into the GRPOTrainer once the claimed results are reproduced.
Also, for this PR to be merged, it will need to align better with the repo: tests go in tests/, not in a standalone file; add a section to the paper index; properly document the new params in the config class; etc.
