Maybe the way the SequenceClassification model calculates the last non-pad token is not reasonable. #35352

Open
liangxuZhang opened this issue Dec 20, 2024 · 1 comment

@liangxuZhang (Contributor)

System Info

None

Who can help?

@ArthurZucker

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

When a pad_token_id is defined in the configuration, the SequenceClassification model finds the last non-padding token in each row by locating the first position of the pad token. The relevant code is at https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L951:

sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1

When the pad token is the eos token, this method returns the position of the first eos token. However, many LLM chat templates append an eos token at the end of the prompt, and when a model such as Llama has no default pad token, using the eos token as the pad token is common practice.
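As a minimal illustration (with made-up token ids and a hypothetical pad_token_id of 2), argmax over the equality mask returns the first matching position, not the last real token:

import torch

input_ids = torch.tensor([[5, 6, 7, 2, 8, 9, 2]])  # hypothetical ids; 2 plays the role of eos/pad
pad_token_id = 2

# argmax over the boolean mask returns the FIRST position equal to pad_token_id,
# so the result points just before the first eos, not at the last real token.
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
print(sequence_lengths)  # tensor([2]), although the last non-pad token is at index 5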

For example, in Llama, the eos token is <|eot_id|>

from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
tokenizer.eos_token, model.config.pad_token_id, tokenizer.pad_token
# ('<|eot_id|>', None, None)

Encoding the QA pairs with the chat template gives:

prompt = "what is 1+1?"
response1 = "1+1=2"
response2 = "1+1=3"

conv1 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response1}]
conv2 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response2}]

conv1_tokenized = tokenizer.apply_chat_template(conv1, tokenize=True, return_tensors="pt")
conv2_tokenized = tokenizer.apply_chat_template(conv2, tokenize=True, return_tensors="pt")

# conv1_tokenized input_ids:
# tensor([[128000, 128006,    882, 128007,    271,  12840,    374,    220,     16,
#             10,     16,     30, 128009, 128006,  78191, 128007,    271,     16,
#              10,     16,     28,     17, 128009, 128006,  78191, 128007,    271]])
# conv1_tokenized decoded:
# '<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nwhat is 1+1?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n1+1=2<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n'

In the Llama chat template, the eos token is appended at the end of the prompt. If we do not pad the input_ids and do not set pad_token_id, the scores are correct:

import torch

with torch.no_grad():
    score1 = model(conv1_tokenized).logits[0][0].item()
    score2 = model(conv2_tokenized).logits[0][0].item()
print(f"Score for response 1: {score1}")
print(f"Score for response 2: {score2}")
# Score for response 1: 1.7297050952911377
# Score for response 2: 1.43972647190094

If we set pad_token_id = eos_token_id, we get the same score for different QA pairs that share the same prompt:

model.config.pad_token_id = tokenizer.eos_token_id
with torch.no_grad():
    score1 = model(conv1_tokenized).logits[0][0].item()
    score2 = model(conv2_tokenized).logits[0][0].item()
print(f"Score for response 1: {score1}")
print(f"Score for response 2: {score2}")
# Score for response 1: -1.857212781906128
# Score for response 2: -1.857212781906128

This is because the SequenceClassification model's score is taken from the logits of the last non-pad token, and with this calculation the "last non-pad token" is the token just before the first eos token, i.e. the last token of the prompt. This is especially problematic when training reward models on preference pairs that share the same prompt.
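To make this concrete, here is a quick check (a sketch, reusing the conv1_tokenized/conv2_tokenized tensors from above) of what the formula from modeling_llama.py computes for the two conversations:

import torch

# Apply the same formula as modeling_llama.py to both tokenized conversations.
pad_token_id = tokenizer.eos_token_id  # <|eot_id|> (128009 in the tensors above)
for name, ids in [("conv1", conv1_tokenized), ("conv2", conv2_tokenized)]:
    seq_len = torch.eq(ids, pad_token_id).int().argmax(-1) - 1
    print(name, seq_len.item())
# Both rows resolve to the same index (11 for the tensors above),
# i.e. the last token of the shared prompt, which is why the two scores are identical.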

The complete reproduction script is as follows:

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained("meta-llama/Llama-3.2-1B-Instruct", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B-Instruct")
prompt = "what is 1+1?"
response1 = "1+1=2"
response2 = "1+1=3"

conv1 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response1}]
conv2 = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response2}]

conv1_tokenized1 = tokenizer.apply_chat_template(conv1, tokenize=True, return_tensors="pt")
conv2_tokenized1 = tokenizer.apply_chat_template(conv2, tokenize=True, return_tensors="pt")
model.config.pad_token_id = tokenizer.eos_token_id
with torch.no_grad():
    score1 = model(conv1_tokenized1).logits[0][0].item()
    score2 = model(conv2_tokenized1).logits[0][0].item()
print(f"Score for response 1: {score1}")
print(f"Score for response 2: {score2}")

Expected behavior

Maybe we should prioritize using the attention mask to calculate the position of the last non-pad token.
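For instance, something along these lines (a sketch, not an actual patch; the helper name is made up) could use the attention mask when it is available and only fall back to scanning for the pad token otherwise:

import torch

def last_non_pad_index(input_ids, attention_mask=None, pad_token_id=None):
    """Hypothetical helper: index of the last real (non-pad) token in each row."""
    if attention_mask is not None:
        # With right padding the mask is 1s followed by 0s, so the count of
        # real tokens minus one is the index of the last real token.
        return attention_mask.sum(dim=-1) - 1
    if pad_token_id is not None:
        # Current behaviour: position just before the first pad/eos token.
        return torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
    # No padding information: fall back to the final position.
    return torch.full((input_ids.shape[0],), input_ids.shape[-1] - 1, device=input_ids.device)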

@Rocketknight1 (Member)

This bug makes sense to me! I think very few people are finetuning chat models for SequenceClassification, but that doesn't mean we shouldn't support it. I agree that indexing on the EOS token should only be used when no attention mask is available.

Would you be willing to make a PR to modify the SequenceClassification behaviour?
