SequenceClassification models locate the last non-padding token by finding the first occurrence of the pad token. When the pad token is the eos token, what this computes is the position of the first eos token. However, many LLM chat templates append the eos token at the end of the prompt. Also, when a model such as Llama has no default pad token, using eos as the pad token is common practice.
For example, in Llama 3 Instruct models the eos token is <|eot_id|>:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer.eos_token, model.config.pad_token_id, tokenizer.pad_token
# ('<|eot_id|>', None, None)
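Next, two QA conversations that share the same prompt are encoded with the chat template. The exact pairs from the original report are not reproduced here, so the contents below are placeholders (assumptions):

# Hypothetical QA pairs: same prompt, different responses.
prompt = [{"role": "user", "content": "What is the capital of France?"}]
conv1 = prompt + [{"role": "assistant", "content": "The capital of France is Paris."}]
conv2 = prompt + [{"role": "assistant", "content": "I don't know."}]

conv1_tokenized = tokenizer.apply_chat_template(conv1, return_tensors="pt")
conv2_tokenized = tokenizer.apply_chat_template(conv2, return_tensors="pt")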
In the Llama chat template, the eos token is appended at the end of the prompt. If we do not pad the input_ids and do not set pad_token_id, the scores are correct:
with torch.no_grad():
    score1 = model(conv1_tokenized).logits[0][0].item()
    score2 = model(conv2_tokenized).logits[0][0].item()
print(f"Score for response 1: {score1}")
print(f"Score for response 2: {score2}")
# Score for response 1: 1.7297050952911377
# Score for response 2: 1.43972647190094
If we set pad_token_id = eos_token_id, we get the same score for different QA pairs that share the same prompt:
model.config.pad_token_id = tokenizer.eos_token_id
with torch.no_grad():
    score1 = model(conv1_tokenized).logits[0][0].item()
    score2 = model(conv2_tokenized).logits[0][0].item()
print(f"Score for response 1: {score1}")
print(f"Score for response 2: {score2}")
# Score for response 1: -1.857212781906128
# Score for response 2: -1.857212781906128
This is because the score of a SequenceClassification model is taken from the last non-pad token's logits, and with eos as the pad token the "last non-pad token" resolves to the token right before the first eos token. This is incorrect, especially when training reward models with preference pairs that share the same prompt.
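For reference, the selection logic is roughly the following (a paraphrased sketch, not the exact library source):

import torch

def last_non_pad_index(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Sketch of the indexing used when config.pad_token_id is set:
    # take the position of the first pad token in each row, minus one.
    first_pad = torch.eq(input_ids, pad_token_id).int().argmax(-1)
    return (first_pad - 1) % input_ids.shape[-1]

Because the chat template inserts <|eot_id|> after every message, the first match falls inside the prompt, so both conversations are scored at the same prompt token.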
This bug makes sense to me! I think very few people are finetuning chat models for SequenceClassification, but that doesn't mean we shouldn't support it. I agree that indexing on the EOS token should only be used when no attention mask is available.
Would you be willing to make a PR to modify the SequenceClassification behaviour?
System Info
None
Who can help?
@ArthurZucker
Information
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
Reproduction
When a pad_token_id is defined in the configuration, the SequenceClassification model finds the last non-padding token in each row by computing the first position of the pad token; see the code at https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L951
Encoding the QA pairs with the chat template and scoring them reproduces the behaviour described above.
The complete reproduction script is as follows:
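Since the original script was not preserved in this report, the version below is a reconstructed sketch; the QA pairs are placeholders and only need to share the same prompt.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "meta-llama/Llama-3.2-1B-Instruct"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=1)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Placeholder QA pairs (assumption): same prompt, different responses.
prompt = [{"role": "user", "content": "What is the capital of France?"}]
conv1 = prompt + [{"role": "assistant", "content": "The capital of France is Paris."}]
conv2 = prompt + [{"role": "assistant", "content": "I don't know."}]

conv1_tokenized = tokenizer.apply_chat_template(conv1, return_tensors="pt")
conv2_tokenized = tokenizer.apply_chat_template(conv2, return_tensors="pt")

def score_pair():
    with torch.no_grad():
        s1 = model(conv1_tokenized).logits[0][0].item()
        s2 = model(conv2_tokenized).logits[0][0].item()
    return s1, s2

# Without pad_token_id set, the last token of each sequence is used: scores differ.
print("pad_token_id unset:", score_pair())

# With pad_token_id = eos_token_id, the token before the first <|eot_id|> is used,
# which is the last prompt token: both conversations get the same score.
model.config.pad_token_id = tokenizer.eos_token_id
print("pad_token_id = eos_token_id:", score_pair())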
Expected behavior
Maybe we should prioritize using the attention mask to calculate the position of the last non-pad token.
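A minimal sketch of that idea, assuming a right-padded batch with a standard attention mask (a suggestion, not the actual patch):

import torch

def last_token_index_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # attention_mask: (batch, seq_len) with 1 for real tokens and 0 for padding.
    # The last real token of each row sits at (number of real tokens - 1).
    return attention_mask.sum(dim=-1) - 1

With this, an eos token that also serves as the pad token no longer truncates the sequence at the end of the prompt, because padding positions are identified by the mask rather than by token id.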