Extracting the logits of TruthX model #3

Open
adam-wiacek opened this issue Apr 26, 2024 · 0 comments

adam-wiacek commented Apr 26, 2024

I used the code from your example, and the generate method of llama2-chat-7B-TruthX produces different output than the base llama2-chat-7B. However, I have a problem when I try to extract the logits (and token probabilities) of llama2-chat-7B-TruthX: the probabilities are almost identical to those of the base Llama. Since the outputs generated by the two models are so different, the difference in token probabilities (between the base and TruthX models) should also be significant. Could you help me with that?

This is the code I use to extract the token probabilities and save them to a file.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# comment / uncomment to get probs of TruthX / base model
llama2chat = "Llama-2-7b-chat-TruthX"  # downloaded locally from 'https://huggingface.co/ICTNLP/Llama-2-7b-chat-TruthX'
# llama2chat = "daryl149/llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(
    llama2chat, trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained(
    llama2chat,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
    trust_remote_code=True,
).cuda()

question = "What are the benefits of eating an apple a day?"

# using TruthfulQA prompt
from llm import PROF_PRIMER as TRUTHFULQA_PROMPT

encoded_inputs = tokenizer(TRUTHFULQA_PROMPT.format(question), return_tensors="pt")[
    "input_ids"
]
# NOTE: this overwrites the TruthfulQA-prompted inputs with the bare question
encoded_inputs = tokenizer(question, return_tensors="pt")["input_ids"]
outputs = model.generate(encoded_inputs.cuda(), max_new_tokens=4000)[0, encoded_inputs.shape[-1] :]
outputs_text = (
    tokenizer.decode(outputs, skip_special_tokens=True).split("Q:")[0].strip()
)
print(outputs_text)

# save next-token probabilities from a forward pass over the input tokens (no generation)
with torch.no_grad():
    logits = model(encoded_inputs.cuda()).logits.cpu().type(torch.float32)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    torch.save(probs, f"{llama2chat.replace('/', '_')}.pt")
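
For reference, here is a minimal sketch of how I then compare the two saved tensors. It assumes both .pt files produced by the script above already exist (the file names below just mirror the replace('/', '_') naming used when saving), that both tokenizers give the same input length for the question, and the variable names (truthx_probs, base_probs, tv_distance) are just placeholders I picked for the comparison.

import torch

# Load the probability tensors saved by the script above
# (names assumed to follow the replace('/', '_') convention used there).
truthx_probs = torch.load("Llama-2-7b-chat-TruthX.pt")
base_probs = torch.load("daryl149_llama-2-7b-chat-hf.pt")

# Per-position total variation distance between the two next-token
# distributions; values near 0 mean the models assign almost the same
# probabilities at that position.
tv_distance = 0.5 * (truthx_probs - base_probs).abs().sum(dim=-1)
print(tv_distance.squeeze(0))

In my case these values stay close to zero at every position, which is what I mean above by "almost no difference" in the probabilities.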