PRM "True" probability #3

Open
rawsh opened this issue Sep 20, 2024 · 1 comment
rawsh commented Sep 20, 2024

Curious why both PRMs take the softmax probability of the 'True' token?

import torch
import torch.nn as nn

class Mistral_PRM(nn.Module):
    def __init__(self, base):
        super().__init__()
        self.base_model = base  # wrapped causal LM that returns per-token logits

    def forward(self, input_ids, attention_mask):
        # logits shape: (batch, seq_len, vocab_size)
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask).logits
        probs = torch.softmax(outputs, dim=-1)
        # probability of token 7081 ('True') at the last position, shape (batch,)
        output = probs[:, -1, 7081]
        return output
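
For context, here is a minimal usage sketch (not from the repo) of how this forward pass would be called, assuming a Hugging Face Mistral checkpoint and tokenizer; the checkpoint name and prompt format below are hypothetical:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-v0.1"  # hypothetical choice of base checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
prm = Mistral_PRM(AutoModelForCausalLM.from_pretrained(model_name))

# Hypothetical prompt format: a question plus one candidate reasoning step.
prompt = "Question: ...\nStep: ...\nIs this step correct?"
enc = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    reward = prm(enc["input_ids"], enc["attention_mask"])
print(reward.item())  # scalar in [0, 1]: P('True') at the last position

Note that the hard-coded id 7081 is tied to one specific tokenizer; with a different vocabulary, something like tokenizer.convert_tokens_to_ids("True") would be needed to look up the correct index.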
zhoubiansining (Contributor) commented:
For this PRM implementation, we use the softmax probability of a single token to represent the reward in the range [0, 1]. The choice of that single token is not very important; 'True' is just a relatively natural setting. Besides, this PRM implementation is only used for comparison in the research process; it is not a necessary component of our approach.
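
To make the "any fixed token works" point concrete, here is a minimal training-objective sketch (an assumption on my part, not the repo's actual training code): if the PRM is fit with a binary cross-entropy target on that one token's probability, any fixed vocabulary index is interchangeable, as long as training and inference agree on it.

import torch
import torch.nn.functional as F

TRUE_TOKEN_ID = 7081  # arbitrary but fixed; must match the id used in forward()

def prm_loss(logits, labels):
    # logits: (batch, seq_len, vocab_size); labels: (batch,) floats in {0.0, 1.0}
    probs = torch.softmax(logits, dim=-1)
    p_true = probs[:, -1, TRUE_TOKEN_ID]  # predicted step reward in [0, 1]
    return F.binary_cross_entropy(p_true, labels)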
