
Conversation


@LeonEricsson LeonEricsson commented Jan 24, 2026

This PR:

  1. Introduces Geometric Sequence Masking as a length-invariant alternative to existing TIS/MIS importance sampling correction
  2. Refactors DeepSeek OPSM as a special case of Geometric Sequence Masking

First, I want to reiterate the convention used in the GRPO trainer and establish some notation:

  • $\pi_{\text{old}}$ = old_per_token_logps: log-probs from the trainer at sampling time
  • $\mu_{\text{old}}$ = sampling_per_token_logps: log-probs from the generation engine at sampling time
  • $\pi$ = per_token_logps: log-probs of the active policy being optimized

Geometric Sequence Masking

TRL's existing importance sampling correction (TIS/MIS) uses per-token ratios:

$$ \rho_t = \frac{\pi_\text{old}(y_t| x, y_{ < t })}{\mu_\text{old} (y_t | x, y_{ < t})} $$

When applied at the sequence-level (Sequence-level IS), this turns into a product of per-token ratios:

$$ \prod_{t=0}^{T} \rho_t $$

which is length-biased, systematically favoring shorter sequences in masking/truncation decisions. To counteract this, one can use the geometric mean of the importance ratios, which measures the average per-token divergence independent of length (shown below both as a product and in log-space):

$$ \rho_\text{geo} = \left( \prod_{t=0}^{T} \rho_t \right)^{1/T} \quad \Rightarrow \quad \log \rho_\text{geo} = \frac{1}{T} \sum_{t=0}^T \log \rho_t $$
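As a quick numerical illustration (a standalone sketch, not code from this PR): two sequences with the same average per-token divergence but different lengths get very different sequence-level products, while their geometric means coincide.

import torch

# Toy per-token log-ratios: same average divergence (0.1 nats), different lengths.
log_rho_short = torch.full((10,), 0.1)
log_rho_long = torch.full((100,), 0.1)

# Product of per-token ratios (log-space sum) grows with length -> length bias.
print(log_rho_short.sum().exp())   # e^1   ~ 2.7
print(log_rho_long.sum().exp())    # e^10  ~ 22026

# Geometric mean (log-space average) is identical regardless of length.
print(log_rho_short.mean().exp())  # e^0.1 ~ 1.105
print(log_rho_long.mean().exp())   # e^0.1 ~ 1.105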

This PR uses the geometric mean of the importance ratio to implement a two-sided mask, very similar to how the current vllm_importance_sampling_mode = "sequence_mask" mode is implemented. This amounts to simply normalizing the sequence-level log-prob difference by length:

if self.vllm_importance_sampling_mode == "sequence_geometric_mean_mask":
    per_sequence_logps_diff = per_sequence_logps_diff / mask.sum(dim=-1, keepdim=True).clamp(
        min=1.0
    )
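For orientation, a rough sketch of how the normalized diff could then drive a two-sided mask, in the spirit of the existing sequence_mask mode (the threshold names and the final weighting step here are illustrative, not the trainer's actual fields):

# Hypothetical log-space bounds for the accepted band (illustrative names only).
log_ratio_min, log_ratio_max = -0.1, 0.1

# After normalization, per_sequence_logps_diff holds log(rho_geo) per sequence.
within_band = (per_sequence_logps_diff >= log_ratio_min) & (per_sequence_logps_diff <= log_ratio_max)

# Sequences whose average per-token divergence leaves the band are masked out of the loss.
importance_sampling_weights = within_band.float()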

DeepSeek OPSM as Geometric Sequence Masking

DeepSeek's Off-Policy Sequence Masking (OPSM) technique is a form of Geometric Sequence Masking that addresses both training-inference mismatch and policy staleness (see discussion). It can be shown to be equivalent to negated Geometric Masking:

$$ \underbrace{\frac{1}{T} \sum_{t=0}^T \log \left( \frac{\mu_\text{old} (y_t | x, y_{ < t})}{\pi(y_t| x, y_{ < t })} \right)}_{\text{DeepSeek OPSM}} = - \underbrace{\frac{1}{T} \sum_{t=0}^T \log \left( \frac{\pi(y_t| x, y_{ < t })}{\mu_\text{old} (y_t | x, y_{ < t})} \right)}_{\text{Geo Mask}} $$

However, this expression conflates training-inference mismatch with policy staleness, correcting for both sources of off-policyness at the same time. This differs from TRL's existing IS correction (between $\pi_{\text{old}}$ and $\mu_{\text{old}}$), which only addresses training-inference mismatch. To express OPSM in terms of Geometric Sequence Masking, we factor the IS weight into training-inference mismatch and policy staleness:

$$ = - \frac{1}{T} \sum_{t=0}^T \log \left( \underbrace{\frac{\pi_\text{old}(y_t| x, y_{ < t })}{\mu_\text{old} (y_t | x, y_{ < t})}}_{\text{training-inference} \,=\, \rho_t} \cdot \underbrace{\frac{\pi(y_t| x, y_{ < t })}{\pi_\text{old}(y_t| x, y_{ < t })}}_{\text{policy staleness}} \right) $$

$$ = - \left( \underbrace{ \frac{1}{T} \sum_{t=0}^T \log \left( \frac{\pi_\text{old}(y_t| x, y_{ < t })}{\mu_\text{old} (y_t | x, y_{ < t})} \right)}_{\log \rho_\text{geo}} + \frac{1}{T} \sum_{t=0}^T \log \left( \frac{\pi(y_t| x, y_{ < t })}{\pi_\text{old}(y_t| x, y_{ < t })} \right) \right) \quad (> \delta, \text{ DS masking condition}) $$
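In code, this factorization corresponds to summing two masked mean log-ratios; here is a hedged sketch using the notation above (the tensors and the delta threshold are stand-ins for illustration, not the trainer's exact variables):

# Training-inference mismatch term: mean_t log(pi_old / mu_old) = log(rho_geo)
log_rho_geo = ((old_per_token_logps - sampling_per_token_logps) * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)

# Policy staleness term: mean_t log(pi / pi_old)
log_staleness = ((per_token_logps - old_per_token_logps) * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)

# DeepSeek's OPSM statistic is the negated sum; the sequence is masked when it exceeds delta.
opsm_kl = -(log_rho_geo + log_staleness)
opsm_masked = opsm_kl > delta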

Further reading:
https://richardli.xyz/post/rl-collapse-part3/
https://arxiv.org/abs/2512.01374

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@LeonEricsson changed the title from "Geometric mean is" to "[GRPO] feat: Geometric Sequence Masking" on Jan 24, 2026

LeonEricsson commented Jan 24, 2026

@casinca whenever you get the time, I'd appreciate feedback on this unification. I think it turned out fairly clean, but my concern is that it muddies OPSM's connection to the paper: the reference implementation becomes harder to understand when compared against the paper, especially without this PR as context. That said, I do appreciate how it clarifies the relationship between OPSM and Geometric Sequence Masking; hopefully it helps people understand why DS implemented OPSM the way they did.

Details are still WIP, just want your take on whether this direction makes sense. The alternative is to completely separate the logic of geo-mask and OPSM.

@qgallouedec if you're interested as well

Comment on lines +1983 to +1985
per_sequence_logps_diff = per_sequence_logps_diff / mask.sum(dim=-1, keepdim=True).clamp(
    min=1.0
)

Just minor, and not sure what Quentin will think about it, but since it's now a mean, the per_sequence_logps_diff variable name might sound misleading: depending on the path, it can be either a sum or a mean.

Comment on lines +2287 to +2290

if raw_importance_sampling_ratio is not None:
    avg_seq_kl = avg_seq_kl + raw_importance_sampling_ratio


@casinca casinca Jan 25, 2026


Imo, a comment would be good to explain that if raw_importance_sampling_ratio is not None, the ratio ends up using the vLLM log-probs (sampling_per_token_logps); otherwise it uses the local old_per_token_logps.
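For illustration only, such a comment could read something like this (wording is just a sketch):

# If raw_importance_sampling_ratio is not None, the ratio ends up using the vLLM
# log-probs (sampling_per_token_logps); otherwise it uses the local old_per_token_logps
# and nothing is added here.
if raw_importance_sampling_ratio is not None:
    avg_seq_kl = avg_seq_kl + raw_importance_sampling_ratio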


casinca commented Jan 25, 2026

The logic looks good to me and checks out with the math. I agree it's harder to understand compared to eq. 9 in the paper, since you now use two log ratios to create the proper KL mean for OPSM.

> π = old_per_token_logps: log-probs of the active policy being optimized

Isn't it only the case when num of grad updates = 1 in TRL? did you mean per_token_logps?

I added 2 minor suggestions (comments) but it's just my opinion.

 


 

I was wondering, as an alternative: instead of having two log ratios to build the correct OPSM average, could we only compute either $\rho_t = \frac{\pi(y_t \mid x, y_{<t})}{\mu(y_t \mid x, y_{<t})}$ if sequence_geometric_mean_mask, or $\rho^{-1}$ if self.off_policy_mask_threshold is not None?

This way we only ever need per_token_logps and sampling_per_token_logps at all times (since sampling_per_token_logps = inputs.get("sampling_per_token_logps", old_per_token_logps)).

We would get two paths, sharing the KL mean calculation in a separate staticmethod that returns either $\log \rho_{\text{geo}}$ or $\log \rho_{\text{opsm}}$.
We would need to make sure OPSM & Geo masking aren't both set in hparams; it has to be one or the other.

# (~pseudo code)

@staticmethod
def get_kl_mean(
    per_token_logps: torch.Tensor,
    sampling_per_token_logps: torch.Tensor,
    mask: torch.Tensor,
    # the toggling logic between geo and OPSM might need to be refined here
    geo_mean: bool = True,  # if False, return the OPSM KL mean instead
) -> torch.Tensor:
    if geo_mean:
        # log(pi) - log(mu): averaging this gives log(rho_geo)
        kl_div = per_token_logps.detach() - sampling_per_token_logps
    else:  # OPSM KL: log(mu) - log(pi)
        kl_div = sampling_per_token_logps - per_token_logps.detach()

    # Sequence-level mean KL (ignoring prompt + padding)
    seq_kl_sum = (kl_div * mask).sum(dim=1, keepdim=True)
    avg_seq_kl = seq_kl_sum / mask.sum(dim=1, keepdim=True).clamp(min=1.0)

    return avg_seq_kl

For Geo (geo_mean=True) we get, just like in the blog:

$$ \log \rho_{\text{geo}} = \frac{1}{T} \sum_{t=0}^{T-1} \log \frac{\pi(y_t \mid x, y_{< t})}{\mu(y_t \mid x, y_{< t})} $$

and it's ready for creating the mask with two-sided/bidirectional conditions:

$$ {g}_{\text{geo-mask}} = C_{\text{min}} \leq \log \rho_{\text{geo}} \leq C_{\text{max}} $$

(with $C_\text{min}$ and $C_\text{max}$ as hparams in logspace, just like $\delta$ is already for OPSM)
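In code, with the avg_seq_kl returned by get_kl_mean(geo_mean=True), the band check could look roughly like this (c_min / c_max stand in for whatever the log-space hparams end up being called):

# avg_seq_kl is log(rho_geo) per sequence, shape (batch, 1)
log_rho_geo = avg_seq_kl

# Keep sequences whose average per-token log-ratio stays inside the band.
geo_mask = (log_rho_geo >= c_min) & (log_rho_geo <= c_max)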

 
 

For OPSM (geo_mean=False), we get back the second condition of eq. 9 from the paper:

$$ \log \rho_{\text{opsm}} = \frac{1}{T} \sum_{t=0}^{T-1} \log \frac{\mu(y_t \mid x, y_{< t})}{\pi(y_t \mid x, y_{< t})} $$

We remove the KL mean block from the original get_off_policy_mask (since it's now computed in the method above) and pass $\log \rho_{\text{opsm}}$ as an argument:

# simplified signature: we remove mask, sampling/old and current logprobs arguments

@staticmethod
def get_off_policy_mask(
    advantages: torch.Tensor,
    kl_mean: torch.Tensor,  # returned var from get_kl_mean(geo_mean=False)
    off_policy_threshold: float,
) -> torch.Tensor:
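Purely for illustration (not the reference implementation), one plausible body under the eq. 9 reading, where a sequence is dropped when its advantage is negative and its mean KL exceeds the threshold, could be:

    # hypothetical sketch: drop sequences that are both penalized (negative advantage)
    # and strongly off-policy (mean KL above the threshold); keep everything else
    too_off_policy = kl_mean.squeeze(-1) > off_policy_threshold
    return ~(too_off_policy & (advantages < 0))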

 

Initially, when I thought about refactoring for both Geo and OPSM to coexist, this is how I was seeing it. Does it look legit to you?

If I didn't make a mistake in this approach, the good thing is that OPSM is kept intact and easier to understand against eq. 9 from the paper, versus having two log ratios to create the proper KL mean. There are two separate paths (Geo or OPSM) that share the KL mean calculation.


Note: not sure how you want to name the lower and upper bounds for geo masking here, but since we already (at least temporarily) use C_min and C_max for IcePop-style IS masking in my other PR, this might create confusion. It's a detail, but mentioning it just in case.


LeonEricsson commented Jan 25, 2026

> The logic looks good to me and checks out with the math. I agree it's harder to understand compared to eq. 9 in the paper, since you now use two log ratios to create the proper KL mean for OPSM.

> π = old_per_token_logps: log-probs of the active policy being optimized
>
> Isn't it only the case when num of grad updates = 1 in TRL? did you mean per_token_logps?

ye! typo, thanks

> I was wondering, as an alternative: instead of having two log ratios to build the correct OPSM average, could we only compute either $\rho_t = \frac{\pi(y_t \mid x, y_{<t})}{\mu(y_t \mid x, y_{<t})}$ if sequence_geometric_mean_mask, or $\rho^{-1}$ if self.off_policy_mask_threshold is not None?
>
> This way we only ever need per_token_logps and sampling_per_token_logps at all times (since sampling_per_token_logps = inputs.get("sampling_per_token_logps", old_per_token_logps)).

(not sure why quoting screws up the tex math mode)

I think you've overlooked that sequence_geometric_mean_mask doesn't use the current policy $\pi$ to calculate its IS weight; our IS correction methods only target the training-inference mismatch through:

$$ \frac{\pi_\text{old}(y_t| x, y_{ < t})}{\mu_\text{old}(y_t | x, y_{ < t})} $$

We could use a variant of your proposed get_kl_mean() with $\pi_{old}$ as input for geo-mask calls, but since we don't use $\pi$ at all, the ratios can (and should to avoid recomputation) be computed at different times: the sequence_geometric_mean_mask KL can be calculated once per rollout generation in _generate_and_score_completions(), whereas the OPSM KL needs to be computed every mini-batch/gradient step with the current $\pi$.
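A rough sketch of the two computation sites (masked_mean is a hypothetical helper that averages over completion tokens; the call sites are illustrative):

# Once per rollout, in _generate_and_score_completions(): both logps are fixed for the rollout.
#   log(rho_geo) = mean_t [ log pi_old - log mu_old ]
geo_log_ratio = masked_mean(old_per_token_logps - sampling_per_token_logps, mask)

# Every mini-batch / gradient step, since pi changes after each update:
#   log(rho_opsm) = mean_t [ log mu_old - log pi ]
opsm_log_ratio = masked_mean(sampling_per_token_logps - per_token_logps.detach(), mask)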

If we want to keep the OPSM function intact, I think the cleaner path is to introduce geo-mask as done in this PR:

if self.vllm_importance_sampling_mode == "sequence_geometric_mean_mask":
    per_sequence_logps_diff = per_sequence_logps_diff / mask.sum(dim=-1, keepdim=True).clamp(
        min=1.0
    )

and leave OPSM completely independent of sequence_geometric_mean_mask (i.e. leave it as is on main), perhaps with a warning that combining OPSM with another correction method isn't recommended.

That said, I still lean toward my original proposal. Paired with some documentation, I think it makes it clearer what OPSM is actually doing, even if it doesn't map directly to the paper.

