
For nucleus sampling, top-p sampling appears to happen on the softmax-normalized top-k logits #1250

Open
j-frei opened this issue Jul 3, 2024 · 1 comment
Labels
bug Something isn't working

Comments


j-frei commented Jul 3, 2024

Describe the bug
This issue refers to the code line at:
https://github.com/EleutherAI/gpt-neox/blob/1cee5b7c7074302de4867ad5cac3f1ea26f7a7d7/megatron/text_generation_utils.py#L100C43-L100C50

To my understanding, top-p should be applied to the pre-top-k-filtered token probabilities.
However, when both top-k and top-p are enabled, top-p is applied to the post-top-k-filtered logits, because an additional softmax is computed over the already-masked logit values.

Expected behavior
Given a large top-p value and a very small top-k value (with k > 1), the top-p filter should have no additional effect beyond top-k. With the current implementation, however, the softmax over the top-k-masked logits renormalizes the k surviving probabilities to sum to 1, so the cumulative probability crosses top_p earlier and top-p can remove tokens that top-k kept.

If the current implementation is intentional and this is the behavior you expect from nucleus sampling, feel free to ignore this issue.
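
For illustration, here is a small, self-contained sketch (the probabilities are made-up toy values, not taken from the repository) showing how applying top-p after the top-k renormalization can drop a token that top-p over the original distribution would keep:

import torch
import torch.nn.functional as F

top_k, top_p = 3, 0.8

# Toy next-token distribution with probabilities [0.5, 0.25, 0.15, 0.1]
logits = torch.log(torch.tensor([[0.5, 0.25, 0.15, 0.1]]))

# Cumulative probabilities of the *original* distribution (already sorted):
# [0.50, 0.75, 0.90, 1.00]. With the usual "keep the first token above the
# threshold" shift, top_p=0.8 keeps tokens 1-3, i.e. exactly the top-k set,
# so top-p removes nothing that top-k kept.
print(torch.cumsum(F.softmax(logits, dim=-1), dim=-1))

# What the current code effectively does: mask everything outside the top-k
# with -inf and apply softmax again, renormalizing the survivors.
masked = logits.clone()
masked[logits < torch.topk(logits, top_k)[0][..., -1, None]] = -float("Inf")

# Renormalized cumulative probabilities: [0.556, 0.833, 1.0, 1.0]. The sum
# already exceeds 0.8 at the second token, so after the shift the third
# token is removed as well, even though it survived top-k.
print(torch.cumsum(F.softmax(masked, dim=-1), dim=-1))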

j-frei added the bug label on Jul 3, 2024

j-frei commented Jul 3, 2024

To my understanding, the function should instead use the original input logits when computing the top-p mask, for example:

import torch
import torch.nn.functional as F


def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")):
    """
    Filters the logits using top_k / top_p, filling any filtered vocab items with filter_value (defaults to -inf).

    This function has been mostly taken from huggingface conversational ai code at
    https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    logits: torch.Tensor -> logits of megatron model.
    top_k: integer -> integer between 0 and the model's vocab size. Filters out any logits with a probability less than that of the top_kth token.
    top_p: float -> Top-p (nucleus) sampling chooses from the smallest possible set of tokens whose cumulative probability exceeds the probability top_p.

    returns: (filtered) logits"""

    masked_logits = logits.clone()
    if top_k > 0:
        # Remove all tokens with a probability less than the
        # last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        masked_logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        # Sort the *original* (unmasked) logits so the cumulative probabilities
        # are computed over the full distribution, not the top-k-renormalized one
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token
        # above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        # Scatter the sorted-order mask back into vocabulary order, per batch row
        for i in range(sorted_indices.size(0)):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            masked_logits[i][indices_to_remove] = filter_value

    return masked_logits
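
For completeness, a minimal usage sketch (the tensors are made-up toy values; assumes torch and torch.nn.functional as F are imported as above):

# Hypothetical batch of 2 sequences over a vocab of size 4
logits = torch.log(torch.tensor([[0.5, 0.25, 0.15, 0.1],
                                 [0.1, 0.2, 0.3, 0.4]]))
filtered = filter_logits(logits, top_k=3, top_p=0.8)
# Sample the next token from the filtered distribution
next_token = torch.multinomial(F.softmax(filtered, dim=-1), num_samples=1)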
