This is my code:
from datasets import load_dataset
from transformers import set_seed
from openprompt.data_utils import InputExample
import os
from tqdm import tqdm
device = "cuda"
classes = ["negative", "positive"]
set_seed(1024)
from accelerate import Accelerator
accelerator = Accelerator()
data_path = 'data'
test_path = os.path.join(data_path, 'test.json')
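# note: load_dataset on a single json file exposes all rows under the "train" split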
test_dataset = load_dataset('json', data_files=test_path)['train'] # 1 positive 0 negative
y_true = test_dataset['label']
dataset = []
import copy
data = []
copy_test_dataset = copy.deepcopy(test_dataset)
# convert each HF dataset row into an OpenPrompt InputExample
# (guid carries the gold label here; it is not used for prediction)
for example in copy_test_dataset:
    temp_data = {"guid": example["label"], "text_a": example["sentence"]}
    data.append(temp_data)
for item in data:
    dataset.append(InputExample(guid=item["guid"], text_a=item["text_a"]))
from openprompt import plms
from openprompt.plms import *
from transformers import LlamaConfig, LlamaForCausalLM, LlamaTokenizer
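# OpenPrompt has no built-in "llama" entry, so register one that reuses the
# generic causal-LM wrapper (LMTokenizerWrapper)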
plms._MODEL_CLASSES["llama"] = ModelClass(**{
    "config": LlamaConfig,
    "tokenizer": LlamaTokenizer,
    "model": LlamaForCausalLM,
    "wrapper": LMTokenizerWrapper,
})
from openprompt.plms import load_plm
plm, tokenizer, model_config, WrapperClass = load_plm("llama", "huggyllama/llama-7b")
tokenizer.pad_token_id = 0  # the LLaMA tokenizer ships without a pad token
from openprompt.prompts import ManualTemplate
promptTemplate = ManualTemplate(
    text=' {"placeholder":"text_a"} This sentence was {"mask"}',
    tokenizer=tokenizer,
)
from openprompt.prompts import ManualVerbalizer
promptVerbalizer = ManualVerbalizer(
    classes=classes,
    label_words={"negative": ["bad"], "positive": ["good", "wonderful", "great"]},
    tokenizer=tokenizer,
)
from openprompt import PromptForClassification
promptModel = PromptForClassification(template=promptTemplate, plm=plm, verbalizer=promptVerbalizer)
from openprompt import PromptDataLoader
data_loader = PromptDataLoader(dataset=dataset, tokenizer=tokenizer, template=promptTemplate,
                               tokenizer_wrapper_class=WrapperClass, batch_size=1)
import torch
promptModel.eval()
print(promptModel)
promptModel, data_loader = accelerator.prepare(promptModel, data_loader)
promptModel.to(device)  # note: accelerator.prepare should already move the model to the device
predictions = []
with torch.no_grad():
    for batch in tqdm(data_loader, desc="Processing batches"):
        batch = {k: v.to(device) for k, v in batch.items()}
        print(batch)
        logits = promptModel(batch)
        print(logits)
        exit()  # debugging: stop after inspecting the first batch
        preds = torch.argmax(logits, dim=-1)
        for i in preds:
            predictions.append(i.item())
from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y_true, predictions)
print('Accuracy: %.2f' % (accuracy * 100))
The output logits are:
tensor([[-1.3863, -1.3863]])
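Note that -1.3863 is exactly ln(1/4): both class scores correspond to a perfectly uniform distribution over the four label words ("bad", "good", "wonderful", "great"), which looks like degenerate values at the mask position rather than a real prediction. As a sanity check, here is a minimal sketch that bypasses OpenPrompt and queries the raw LM directly (the model name is taken from the code above; the example sentence is made up):

# sanity check outside OpenPrompt: if the raw model ranks the label words
# sensibly here, the problem is in the OpenPrompt wiring (mask position,
# wrapper, or accelerate setup), not in the checkpoint itself
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

name = "huggyllama/llama-7b"
tok = LlamaTokenizer.from_pretrained(name)
lm = LlamaForCausalLM.from_pretrained(name).eval()

prompt = "I love this movie. This sentence was"  # made-up example input
inputs = tok(prompt, return_tensors="pt")
with torch.no_grad():
    next_token_logits = lm(**inputs).logits[0, -1]  # logits for the next token
probs = torch.softmax(next_token_logits, dim=-1)
for word in ["bad", "good", "wonderful", "great"]:
    # SentencePiece may split a word into several pieces; for simplicity we
    # score only the first piece
    tid = tok.encode(word, add_special_tokens=False)[0]
    print(word, probs[tid].item())

If these probabilities look reasonable, the next thing to check would be which position of the model output OpenPrompt reads as the mask for a decoder-only model.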