[GRPO] feat: Geometric Sequence Masking #4891
base: main
Conversation
…nce_sampling_max`

This is done to better align both the min and max bound arguments for TIS and MIS.

docs: renaming `vllm_importance_sampling_cap` to `vllm_importance_sampling_max`

Co-authored-by: LeonEricsson <70749762+LeonEricsson@users.noreply.github.com>
@casinca whenever you get the time, I'd appreciate feedback on this unification. I think it turned out fairly clean, but my concern is that it muddies OPSM's connection to the paper: it becomes harder to understand the reference implementation when comparing it to the paper, especially without this PR as context. That said, I do appreciate how it clarifies the relationship between OPSM and Geometric Sequence Masking; hopefully it helps people understand why DS implemented OPSM the way they did. Details are still WIP, just want your take on whether this direction makes sense. The alternative is to completely separate the logic of geo-mask and OPSM. @qgallouedec if you're interested as well
```python
per_sequence_logps_diff = per_sequence_logps_diff / mask.sum(dim=-1, keepdim=True).clamp(
    min=1.0
)
```
Just minor, and I'm not sure what Quentin will think about it, but since it's now a mean, the `per_sequence_logps_diff` variable name might sound misleading: depending on the path, it can be either a sum or a mean.
```python
if raw_importance_sampling_ratio is not None:
    avg_seq_kl = avg_seq_kl + raw_importance_sampling_ratio
```
Imo, a comment would be good here to explain that if `raw_importance_sampling_ratio` is not None, the ratio ends up using the vLLM logprobs (`sampling_per_token_logps`), and otherwise the local `old_per_token_logps`.
The logic looks good to me and checks out with the math. I agree it's harder to understand compared to eq. 9 in the paper, since you now use 2 log ratios to create the proper KL mean for OPSM.

Isn't it only the case when the number of grad updates = 1 in TRL? (did you mean …?)

I added 2 minor suggestions (comments), but it's just my opinion.
I was wondering, as an alternative: instead of having 2 log ratios to build the correct OPSM average, could we only compute either one? This way we only ever need one at all times. We would get 2 paths, with the KL mean calculation in common, in a separate staticmethod that returns either one:

```python
# (~pseudo code)
@staticmethod
def get_kl_mean(
    per_token_logps: torch.Tensor,
    sampling_per_token_logps: torch.Tensor,
    mask: torch.Tensor,
    # here the toggling logic between geo or OPSM might need to be refined
    geo_mean: bool = True,  # if False, return the OPSM KL mean instead
) -> torch.Tensor:
    if geo_mean:
        kl_div = per_token_logps.detach() - sampling_per_token_logps
    else:  # OPSM KL
        kl_div = sampling_per_token_logps - per_token_logps.detach()
    # Sequence-level mean KL (ignoring prompt + padding)
    seq_kl_sum = (kl_div * mask).sum(dim=1, keepdim=True)
    avg_seq_kl = seq_kl_sum / mask.sum(dim=1, keepdim=True).clamp(min=1.0)
    return avg_seq_kl
```

For Geo (`geo_mean=True`), it's ready for creating the mask with both-sided/bidirectional conditionals. For OPSM (`geo_mean=False`), we remove the KL mean block from the original `get_off_policy_mask`:

```python
# simplified signature: we remove mask, sampling/old and current logprobs arguments
@staticmethod
def get_off_policy_mask(
    advantages: torch.Tensor,
    kl_mean: torch.Tensor,  # returned var from get_kl_mean(geo_mean=False)
    off_policy_threshold: float,
) -> torch.Tensor:
    ...
```
Initially, when I thought about refactoring for both Geo and OPSM to coexist, this is how I was seeing it. Does it look legit to you? If I didn't make a mistake in this approach, the good thing is that OPSM is kept intact and easier to understand against eq. 9 from the paper, vs having 2 log ratios to create the proper KL mean. There are 2 separate paths (Geo or OPSM) with the KL mean calculation logic in common.

Note: Not sure how you want to name the lower and upper bounds for geo masking here, but since we already (at least temporarily) set …
ye! typo, thanks
(not sure why quoting screws up the TeX math mode)

I think you've overlooked that … We could use a variant of your proposed `get_kl_mean`. If we want to keep the OPSM function intact, I think the cleaner path is to introduce geo-mask as done in this PR (trl/trl/trainer/grpo_trainer.py, lines 1982 to 1985 in 89ac01e) and leave OPSM completely independent of …

That said, I still lean toward my original proposal. Paired with some documentation, I think it makes clearer what OPSM is actually doing, even if it doesn't map directly to the paper.
This PR:
First I want to reiterate the convention used in the GRPO trainer, and establish some notation.
- `old_per_token_logps`: log-probs from the trainer at sampling time ($\pi_{\text{old}}$)
- `sampling_per_token_logps`: log-probs from the generation engine at sampling time ($\mu_{\text{old}}$)
- `per_token_logps`: log-probs of the active policy being optimized ($\pi_\theta$)

Geometric Sequence Masking
TRL's existing importance sampling correction (TIS/MIS) uses per-token ratios:

$$r_t = \frac{\pi_{\text{old}}(y_t \mid x, y_{<t})}{\mu_{\text{old}}(y_t \mid x, y_{<t})}$$

When applied at the sequence level (Sequence-level IS), this turns into a product of per-token ratios:

$$r_{\text{seq}} = \prod_{t=1}^{|y|} r_t$$

which is length-biased, systematically favoring shorter sequences in masking/truncation decisions. To counteract this, one can use the geometric mean of the importance ratios, which measures the average per-token divergence independent of length (shown here also in log-space):

$$r_{\text{geo}} = \Big(\prod_{t=1}^{|y|} r_t\Big)^{1/|y|} = \exp\!\Big(\frac{1}{|y|}\sum_{t=1}^{|y|} \log r_t\Big)$$
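To make the length bias concrete, here is a small pure-Python illustration (made-up numbers, not TRL code): two sequences with the same average per-token log-ratio, where the sequence-level product collapses with length while the geometric mean is identical for both.

```python
import math

# Hypothetical per-token log-ratios log(pi_old / mu_old): same average
# divergence per token, but very different sequence lengths.
short_logr = [-0.1] * 10    # 10-token sequence
long_logr = [-0.1] * 100    # 100-token sequence

# Sequence-level IS: product of per-token ratios (sum in log-space).
prod_short = math.exp(sum(short_logr))   # exp(-1), moderate
prod_long = math.exp(sum(long_logr))     # exp(-10), vanishingly small: length bias

# Geometric mean: normalize the log-sum by length, identical for both lengths.
geo_short = math.exp(sum(short_logr) / len(short_logr))
geo_long = math.exp(sum(long_logr) / len(long_logr))
```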
This PR uses the geometric mean of the importance ratio to implement a two-sided mask, very similar to how the current `vllm_importance_sampling_mode = "sequence_mask"` mode is implemented. This is as simple as normalizing the sequence-level log-ratio by length (trl/trl/trainer/grpo_trainer.py, lines 1982 to 1985 in 89ac01e).
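As a rough sketch of the idea (plain Python; the `geo_mask` helper and the `lo`/`hi` bound names are hypothetical, and the actual PR operates on batched torch tensors inside the trainer): a sequence is kept when its geometric-mean importance ratio falls inside the two-sided bounds.

```python
import math

def geo_mask(per_token_log_ratios, lo, hi):
    # Geometric-mean importance ratio, computed in log-space:
    # mean of per-token log-ratios, i.e. the length-normalized log-sum.
    mean_log_ratio = sum(per_token_log_ratios) / max(len(per_token_log_ratios), 1)
    ratio = math.exp(mean_log_ratio)
    # Keep (1) sequences whose average per-token ratio lies inside [lo, hi],
    # mask out (0) the rest.
    return 1 if lo <= ratio <= hi else 0
```

Because the decision uses the per-token average, a long sequence is not penalized merely for accumulating many small per-token deviations.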
DeepSeek OPSM as Geometric Sequence Masking
DeepSeek's Off-Policy Sequence Masking (OPSM) technique is a form of Geometric Sequence Masking that addresses both training-inference mismatch and policy staleness (see discussion). It can be shown to be equivalent to negated Geometric Masking:

$$\frac{1}{|y|}\sum_{t=1}^{|y|} \log \frac{\mu_{\text{old}}(y_t \mid x, y_{<t})}{\pi_\theta(y_t \mid x, y_{<t})} > \delta \quad\Longleftrightarrow\quad \frac{1}{|y|}\sum_{t=1}^{|y|} \log \frac{\pi_\theta(y_t \mid x, y_{<t})}{\mu_{\text{old}}(y_t \mid x, y_{<t})} < -\delta$$
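A quick numeric check of this equivalence (pure Python with made-up log-probs; in the trainer these correspond to `sampling_per_token_logps` and `per_token_logps`):

```python
# Hypothetical per-token log-probs for a 3-token completion.
sampling_lp = [-1.0, -0.8, -1.2]   # mu_old: generation engine
policy_lp = [-1.1, -0.7, -1.3]     # pi_theta: active policy

# OPSM statistic: mean_t [log mu_old - log pi_theta]
opsm = sum(s - p for s, p in zip(sampling_lp, policy_lp)) / len(policy_lp)

# Geometric-mean log-ratio: mean_t [log pi_theta - log mu_old]
geo_log = sum(p - s for p, s in zip(policy_lp, sampling_lp)) / len(policy_lp)
```

The OPSM statistic is exactly the negated geometric-mean log-ratio, so thresholding one from above is thresholding the other from below.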
However, this expression conflates the training-inference mismatch with the policy staleness, solving for both sources of off-policyness at the same time. This differs from TRL's existing IS correction (between $\pi_{\text{old}}$ and $\mu_{\text{old}}$), which only addresses training-inference mismatch. To express OPSM in terms of Geometric Sequence Masking, we factor the IS weight into training-inference mismatch and policy staleness:

$$\log \frac{\pi_\theta(y_t)}{\mu_{\text{old}}(y_t)} = \underbrace{\log \frac{\pi_{\text{old}}(y_t)}{\mu_{\text{old}}(y_t)}}_{\text{training-inference mismatch}} + \underbrace{\log \frac{\pi_\theta(y_t)}{\pi_{\text{old}}(y_t)}}_{\text{policy staleness}}$$
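The factorization is just additivity of log-ratios; a one-token numeric check with hypothetical log-probs:

```python
# Hypothetical log-probs of the same token under the three distributions.
logp_theta = -1.20   # pi_theta: active policy being optimized
logp_old = -1.00     # pi_old: trainer at sampling time
logp_mu = -0.90      # mu_old: generation engine at sampling time

mismatch = logp_old - logp_mu       # log(pi_old / mu_old)
staleness = logp_theta - logp_old   # log(pi_theta / pi_old)
total = logp_theta - logp_mu        # log(pi_theta / mu_old), what OPSM uses

# The conflated OPSM log-ratio splits exactly into the two components.
assert abs(total - (mismatch + staleness)) < 1e-12
```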
Further reading:
https://richardli.xyz/post/rl-collapse-part3/
https://arxiv.org/abs/2512.01374
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.