Conversation

@Xiuyu-Li (Contributor)

This PR adds temperature as a configurable argument in Config, allowing users to adjust the sampling temperature used for RL training. It also updates the math_rl example (tinker_cookbook/recipes/math_rl/train.py) to demonstrate how to apply this setting in a training script.

Note: This change only affects the TinkerTokenCompleter class, which is the default completer used by most RL recipe scripts. The TinkerMessageCompleter class (in tinker_cookbook/completers.py) remains unchanged: it still uses a hard-coded temperature=1.0 in its SamplingParams, left as-is for now to avoid changing behavior that may be intentional. As a result, this update does not affect scripts that rely on TinkerMessageCompleter, such as text_arena, twenty_questions, and similar examples.
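
To make the shape of the change concrete, here is a minimal sketch (illustrative only; the dataclass names are invented and the real classes live in tinker_cookbook/completers.py and the math_rl Config): the configured temperature is threaded from the training config into the sampling parameters built by the token completer, instead of being hard-coded to 1.0.

```python
# Illustrative sketch only -- not the actual tinker_cookbook classes.
from dataclasses import dataclass

@dataclass
class SamplingParamsSketch:
    max_tokens: int
    temperature: float = 1.0  # previously effectively hard-coded to 1.0

@dataclass
class TrainConfigSketch:
    model_name: str
    temperature: float = 1.0  # the new knob exposed by this PR

def build_sampling_params(cfg: TrainConfigSketch, max_tokens: int) -> SamplingParamsSketch:
    # The token completer forwards the configured temperature instead of a constant.
    return SamplingParamsSketch(max_tokens=max_tokens, temperature=cfg.temperature)

print(build_sampling_params(TrainConfigSketch("meta-llama/Llama-3.2-1B", temperature=0.6), 20))
```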

@Tiiiger (Collaborator) left a comment


LGTM! I'll click on merge once I can confirm the backend actually supports this correctly.

@joschu (Collaborator) left a comment


I'm not sure logprobs are calculated the right way when temperature != 1 -- for this to make sense, we'd need for the sampling engine to compute the logprob of the temperature-scaled model, not the original model.
Have you done experiments @Xiuyu-Li?
Marking as "Requesting Changes" until we've validated that this makes sense.

@Xiuyu-Li (Contributor, Author) commented Nov 16, 2025

> I'm not sure logprobs are calculated the right way when temperature != 1 -- for this to make sense, we'd need for the sampling engine to compute the logprob of the temperature-scaled model, not the original model. Have you done experiments @Xiuyu-Li? Marking as "Requesting Changes" until we've validated that this makes sense.

Thanks for raising this. I had checked with a few short runs at temperature=0.6 earlier and observed the expected behavior (lower entropy than at temperature=1.0 under the same setup), but I had not done a comprehensive test.

To validate the correctness more rigorously, I spent some time designing a frontend-side test in tinker_cookbook/tests/validate_temperature_logprobs.py. The script verifies that sample_async returns logprobs from the temperature-scaled distribution, not the base model distribution, using two complementary checks:

First, at a fixed context, the script repeatedly samples the first token across a configurable grid of temperatures (default 0.5–1.8), recording log p_τ(token) from sample_async with max_tokens = 1. For all unique tokens ever sampled, it then calls compute_logprobs_async once at τ = 1.0 to get log p_1(token | prompt). Finally, it computes pairwise differences
Δ_τ(i, j) = log p_τ(i) − log p_τ(j) and Δ_1(i, j) = log p_1(i) − log p_1(j), and checks that Δ_τ(i, j) ≈ (1/τ) × Δ_1(i, j) over all token pairs. It reports the mean and maximum absolute error for each temperature.
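
This pairwise construction is what lets the check run without full-vocabulary logits: at temperature τ, log p_τ(i) = logit_i/τ − log Z_τ, and the log-partition term is identical for every token at a fixed context, so it cancels in the differences and Δ_τ(i, j) = (1/τ) × Δ_1(i, j) exactly. A toy numpy sketch of the arithmetic being checked (synthetic logits standing in for the backend; no tinker calls):

```python
# Toy reference for the pairwise check: for a softmax at temperature tau,
# log p_tau(i) = logit_i / tau - log Z_tau, so the partition term cancels in
# pairwise differences and Delta_tau(i, j) = (1/tau) * Delta_1(i, j).
import itertools
import numpy as np

def log_softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max()
    return x - np.log(np.exp(x).sum())

rng = np.random.default_rng(0)
logits = rng.normal(size=32)  # stand-in for next-token logits at a fixed context

for tau in (0.5, 0.7, 1.0, 1.2, 1.5, 1.8):
    logp_tau = log_softmax(logits / tau)  # stand-in for what sample_async reports
    logp_one = log_softmax(logits)        # stand-in for compute_logprobs_async at tau = 1
    errs = [
        abs((logp_tau[i] - logp_tau[j]) - (logp_one[i] - logp_one[j]) / tau)
        for i, j in itertools.combinations(range(len(logits)), 2)
    ]
    print(f"tau={tau:.1f}  mean_err={np.mean(errs):.2e}  max_err={np.max(errs):.2e}")
```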

While the temperature-scaling check validates correctness for max_tokens = 1, we also want to ensure that sample_async behaves consistently when generating multi-token sequences (max_tokens > 1). To test this, at a fixed temperature (default 0.5) the script samples a sequence with max_tokens > 1 (default length 20). For each token position, it resamples with max_tokens = 1 until the same token is observed and compares the returned logprob.
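
The logic of this second check can be sketched against a stubbed one-step sampler (the helpers below are made up for illustration; the real script issues sample_async calls and compares the logprobs it returns):

```python
# Sketch of the sequence-level check against a stubbed one-step sampler.
# `toy_step_logprobs` and `sample_one` are invented stand-ins for what the real
# script gets from sample_async (token + logprob); the comparison logic is the point.
import numpy as np

rng = np.random.default_rng(0)
VOCAB, TAU = 16, 0.5

def toy_step_logprobs(context: list[int]) -> np.ndarray:
    """Stub model: deterministic temperature-scaled log-probs given the context."""
    logits = np.sin(np.arange(VOCAB, dtype=float) + len(context)) / TAU
    return logits - np.log(np.exp(logits - logits.max()).sum()) - logits.max()

def sample_one(context: list[int]) -> tuple[int, float]:
    """Stand-in for a max_tokens=1 sampling call: returns (token, reported logprob)."""
    logp = toy_step_logprobs(context)
    token = int(rng.choice(VOCAB, p=np.exp(logp)))
    return token, float(logp[token])

# "Generate" a 20-token sequence, recording the prefix and logprob at each step.
context: list[int] = [1, 2, 3]
steps = []
for _ in range(20):
    token, logprob = sample_one(context)
    steps.append((list(context), token, logprob))
    context.append(token)

# Resample each position with the same prefix until the same token comes up,
# then compare the two reported logprobs (they should agree up to numerical noise).
max_diff = 0.0
for prefix, token, logprob in steps:
    while True:
        token2, logprob2 = sample_one(prefix)
        if token2 == token:
            max_diff = max(max_diff, abs(logprob2 - logprob))
            break
print(f"max per-position logprob difference: {max_diff:.2e}")
```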

These two checks together validate temperature handling in sample_async without requiring access to logits computed by the backend.

Results

Validated on both Llama-3.2-1B and Qwen3-4B-Instruct-2507.

python -m tinker_cookbook.tests.validate_temperature_logprobs base_model=meta-llama/Llama-3.2-1B

gives:

===========================================================================
TEMPERATURE SCALING VALIDATION
===========================================================================
Model: meta-llama/Llama-3.2-1B, 20 trials per temperature
    Temp    Unique Tokens     Pairs     Mean Diff      Max Diff
---------------------------------------------------------------------------
   0.500               10        45      0.000000      0.000000
   0.700               16       120      0.000001      0.000002
   1.000               18       153      0.000000      0.000000
   1.200               19       171      0.000000      0.000001
   1.500               20       190      0.000000      0.000001
   1.800               19       171      0.000000      0.000001

===========================================================================
SEQUENCE-LEVEL CONSISTENCY CHECK (multi-token logprob validation)
===========================================================================
Generate with max_tokens=20 at temp=0.5, then resample each position individually to verify logprob consistency.
    Temp    Length   Matches     Mean Diff      Max Diff
---------------------------------------------------------------------------
   0.500        20        20      0.000004      0.000074

and

python -m tinker_cookbook.tests.validate_temperature_logprobs base_model=Qwen/Qwen3-4B-Instruct-2507

gives:

===========================================================================
TEMPERATURE SCALING VALIDATION
===========================================================================
Model: Qwen/Qwen3-4B-Instruct-2507, 20 trials per temperature
    Temp    Unique Tokens     Pairs     Mean Diff      Max Diff
---------------------------------------------------------------------------
   0.500                2         1      0.000000      0.000000
   0.700                2         1      0.000293      0.000293
   1.000                7        21      0.000000      0.000000
   1.200                6        15      0.000000      0.000001
   1.500               13        78      0.000000      0.000001
   1.800               14        91      0.000741      0.005185

===========================================================================
SEQUENCE-LEVEL CONSISTENCY CHECK (multi-token logprob validation)
===========================================================================
Generate with max_tokens=20 at temp=0.5, then resample each position individually to verify logprob consistency.
    Temp    Length   Matches     Mean Diff      Max Diff
---------------------------------------------------------------------------
   0.500        20        20      0.005104      0.050266

Both models show the expected behavior: the one-token logprobs follow the 1/τ scaling relationship, and the per-step logprobs from multi-token generation remain consistent with repeated one-token sampling. Small numerical discrepancies show up across models and runs; these could stem from backend or implementation-level differences, though I don’t have enough visibility to confirm. They appear to fall within a reasonable tolerance based on the tests above.

Looking forward to hearing what you and the team think! @joschu

@joschu (Collaborator) commented Nov 17, 2025

Hi @Xiuyu-Li, thanks for going to the trouble of checking this. I'm now convinced that you're correct.

I also took a look at the backend code, and it does appear that we return the same logprobs used for sampling, even when temperature scaling is applied.

Changing the temperature is still a bit sketchy and might lead to wrong results in certain cases; e.g., right now compute_logprobs doesn't accept a temperature parameter, so KL penalties will not work correctly. So we should warn people against changing the parameter. But it's good to have the option.
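
To spell out that caveat with a toy numpy example (my own sketch, not backend code): a single τ=1 logprob from compute_logprobs cannot be rescaled into the sampler's τ-scaled logprob, because the two differ by a context-dependent log-partition constant on top of the 1/τ factor, so any term that mixes them (such as a KL penalty) ends up comparing different distributions.

```python
# Per token, log p_tau(a) = (1/tau) * log p_1(a) + c, where
# c = (1/tau) * log Z_1 - log Z_tau depends on the full next-token distribution
# at that context, not on the sampled token alone.
import numpy as np

def log_softmax(x: np.ndarray) -> np.ndarray:
    x = x - x.max()
    return x - np.log(np.exp(x).sum())

logits = np.random.default_rng(1).normal(size=8)
tau = 0.6
logp_tau = log_softmax(logits / tau)  # what the sampler reports at tau = 0.6
logp_one = log_softmax(logits)        # what compute_logprobs returns today (tau = 1)

offset = logp_tau - logp_one / tau
print(offset)               # identical entries: the context-dependent constant c
print(float(offset.std()))  # ~0, confirming c does not depend on the token
```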

@Xiuyu-Li (Contributor, Author)

Sounds good. Thanks for helping add the warning comment!

@joschu merged commit 2ec1bcb into thinking-machines-lab:main on Nov 17, 2025
2 checks passed