Add configurable temperature parameter for RL rollout sampling #86
Conversation
Tiiiger
left a comment
LGTM! I'll click on merge once I can confirm the backend actually supports this correctly.
joschu
left a comment
I'm not sure logprobs are calculated the right way when temperature != 1 -- for this to make sense, we'd need the sampling engine to compute the logprob of the temperature-scaled model, not the original model.
Have you done experiments @Xiuyu-Li?
Marking as "Requesting Changes" until we've validated that this makes sense.
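For concreteness, a minimal sketch of the distinction (plain PyTorch, independent of the Tinker API; the vocab size and temperature below are made up):

```python
# Sketch: the logprob of the temperature-scaled model is log_softmax(logits / tau),
# which is NOT the same as the base model's log_softmax(logits), nor (1/tau) * log_softmax(logits).
import torch
import torch.nn.functional as F

logits = torch.randn(32000)   # stand-in for per-token logits at some position
tau = 0.7                     # sampling temperature

base_logprobs = F.log_softmax(logits, dim=-1)           # logprobs of the original model
scaled_logprobs = F.log_softmax(logits / tau, dim=-1)   # logprobs of the temperature-scaled model

# Sampling at temperature tau draws from exp(scaled_logprobs); for the RL math
# to be consistent, the sampler must report scaled_logprobs, not base_logprobs.
token = torch.multinomial(scaled_logprobs.exp(), num_samples=1)
print(base_logprobs[token].item(), scaled_logprobs[token].item())
```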
Xiuyu-Li
left a comment
Thanks for raising this — I checked with a few short runs at temperatures other than 1. To validate the correctness more rigorously, I spent some time designing a frontend-side test in `tinker_cookbook/tests/validate_temperature_logprobs.py`.

First, at a fixed context, the script repeatedly samples the first token across a configurable grid of temperatures (default 0.5–1.8), recording the logprobs returned by the sampler and checking that they follow the expected 1/τ scaling (a sketch of this property is at the end of this comment). While the temperature-scaling check only validates correctness for a single sampled token, a second check samples longer completions and verifies that the per-step logprobs stay consistent with repeated one-token sampling at the same temperature. These two checks together validate temperature handling in the sampling path.

Results

Validated on both Llama-3.2-1B and Qwen3-4B-Instruct-2507. Running

```
python -m tinker_cookbook.tests.validate_temperature_logprobs base_model=meta-llama/Llama-3.2-1B
```

and

```
python -m tinker_cookbook.tests.validate_temperature_logprobs base_model=Qwen/Qwen3-4B-Instruct-2507
```

shows the expected behavior for both models: temperature scaling aligns with the 1/τ relationship using one-token sampling, and per-step logprobs remain consistent with repeated one-token sampling. Small numerical discrepancies are observed across models and runs; these could stem from backend or implementation-level differences, though I don't have enough visibility to confirm. They appear to fall within a reasonable tolerance based on the tests above.

Looking forward to hearing what you and the team think! @joschu
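For reference, the 1/τ property that the first check relies on boils down to the following (a self-contained numpy simulation, not the actual test code; the logits and token ids are made up for illustration):

```python
# If the sampler reports logprobs of the temperature-scaled model, then for two
# fixed tokens a and b, (logprob_a - logprob_b) * tau is constant across
# temperatures: it equals the raw logit gap.
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=1000)           # stand-in for model logits at a fixed context

def scaled_logprobs(logits, tau):
    """log_softmax(logits / tau), computed stably."""
    z = logits / tau
    z = z - z.max()
    return z - np.log(np.exp(z).sum())

temperatures = np.linspace(0.5, 1.8, 8)  # grid similar to the script's default 0.5-1.8
token_a, token_b = 3, 7                  # two arbitrary fixed token ids

gaps = []
for tau in temperatures:
    lp = scaled_logprobs(logits, tau)    # what a correct sampler should report
    gaps.append((lp[token_a] - lp[token_b]) * tau)

# Every entry equals the raw logit gap, independent of temperature.
assert np.allclose(gaps, logits[token_a] - logits[token_b])
print(gaps)
```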
joschu
left a comment
Hi @Xiuyu-Li, thanks for going to the trouble of checking this. I'm now convinced that you're correct. I also took a look at the backend code and confirmed that we appear to be returning the same logprobs used for sampling, even when temperature scaling is applied. Changing the temperature is still a bit sketchy and might lead to wrong results in certain cases: for example, `compute_logprobs` doesn't currently accept a temperature parameter, so KL penalties will not work correctly. So we should warn people against changing the parameter. But it's good to have the option.
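To illustrate the KL concern, a rough sketch (plain PyTorch with made-up tensors, assuming a temperature-unaware `compute_logprobs` returns logprobs at temperature 1; this is not the actual backend code):

```python
# The sampler reports logprobs under the temperature-scaled policy, while a
# temperature-unaware compute_logprobs-style call returns logprobs under the
# tau = 1 model, so a naive per-token KL estimate mixes two different distributions.
import torch
import torch.nn.functional as F

tau = 0.7
logits = torch.randn(32000)       # stand-in for current-policy logits at one position
ref_logits = torch.randn(32000)   # stand-in for reference-policy logits

sampling_logprobs = F.log_softmax(logits / tau, dim=-1)     # what sampling returns
ref_logprobs_t1 = F.log_softmax(ref_logits, dim=-1)         # temperature-unaware reference logprobs
ref_logprobs_scaled = F.log_softmax(ref_logits / tau, dim=-1)

token = torch.multinomial(sampling_logprobs.exp(), num_samples=1)

# Naive per-token KL term: inconsistent, the two terms use different temperatures.
naive_kl_term = sampling_logprobs[token] - ref_logprobs_t1[token]
# Consistent alternative (an assumption about the fix, not current behavior):
# evaluate both terms under the same temperature-scaled distributions.
consistent_kl_term = sampling_logprobs[token] - ref_logprobs_scaled[token]
print(naive_kl_term.item(), consistent_kl_term.item())
```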
Xiuyu-Li
left a comment
Sounds good. Thanks for helping add the warning comment!
This PR adds `temperature` as a configurable argument in `Config`, allowing users to adjust the sampling temperature used for RL training. It also updates the math_rl example (`tinker_cookbook/recipes/math_rl/train.py`) to demonstrate how to easily apply this setting within a training script.

Note: This change only affects the `TinkerTokenCompleter` class, which is the default completer used by most RL recipe scripts. The `TinkerMessageCompleter` class (in `tinker_cookbook/completers.py`) remains unchanged — it still uses a hard-coded `temperature=1.0` in its `SamplingParams`, left as-is for now to avoid altering potential design-specific behavior. As a result, this update does not impact scripts that rely on `TinkerMessageCompleter`, such as `text_arena`, `twenty_questions`, and similar examples.
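For illustration, a self-contained sketch of what the change amounts to (the classes below are simplified stand-ins, not the actual `TinkerTokenCompleter` or tinker `SamplingParams` code):

```python
# Simplified sketch: the completer previously hard-coded temperature=1.0 when
# building its sampling parameters; with this change, the value configured on
# the RL Config is threaded through to rollout sampling instead.
from dataclasses import dataclass

@dataclass
class SamplingParamsSketch:
    max_tokens: int
    temperature: float = 1.0   # previously always left at the default

@dataclass
class TokenCompleterSketch:
    max_tokens: int
    temperature: float = 1.0   # new: plumbed in from the training Config

    def sampling_params(self) -> SamplingParamsSketch:
        return SamplingParamsSketch(
            max_tokens=self.max_tokens,
            temperature=self.temperature,  # rollouts sample at the configured temperature
        )

# Usage: a Config-level temperature of 0.8 would flow into rollout sampling.
params = TokenCompleterSketch(max_tokens=256, temperature=0.8).sampling_params()
print(params)
```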