-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict_batch.py
77 lines (62 loc) · 2.51 KB
/
predict_batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
from datasets import Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast, BertForSequenceClassification

import model_config
# Load the test set; column "0" is read as the input text and column "1"
# as the label.
test_data_path = model_config.test_data_path
test_data = pd.read_csv(test_data_path)
# For a quick smoke test, cap the number of rows that are read instead:
# test_data = pd.read_csv(test_data_path, nrows=32)
texts, labels = (test_data[column].tolist() for column in ("0", "1"))
print(len(texts))

# Restore the tokenizer and the sequence-classification model from the
# locations given in model_config.
model_path = model_config.model_path
model_tokenizer_path = model_config.model_name_tokenizer_path
tokenizer = BertTokenizerFast.from_pretrained(model_tokenizer_path)
model = BertForSequenceClassification.from_pretrained(model_path)
# Batch-level preprocessing: turns raw samples into model-ready tensors.
def collate_fn(batch):
    """Tokenize a batch of samples and attach their labels.

    Args:
        batch: Iterable of items, each indexable by the keys ``"text"``
            and ``"label"``.

    Returns:
        The encoding produced by the module-level ``tokenizer`` (padding
        enabled, truncated to ``max_length=64``, PyTorch tensors) with an
        extra ``"labels"`` entry holding the labels as a tensor.
    """
    batch_texts = []
    batch_labels = []
    for sample in batch:
        batch_texts.append(sample["text"])
        batch_labels.append(sample["label"])

    encoded = tokenizer(
        batch_texts,
        padding=True,
        truncation=True,
        max_length=64,
        return_tensors="pt",
    )
    encoded["labels"] = torch.tensor(batch_labels)
    return encoded
batch_size = model_config.test_batch_size

# Wrap the raw texts/labels in a Dataset and batch them through collate_fn.
dataset = Dataset.from_dict({"text": texts, "label": labels})
data_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)

# Make evaluation mode explicit so dropout is disabled during inference.
model.eval()

# Collect one predicted class id per test sample, in dataset order.
predictions = []
for batch in data_loader:
    # The labels are only needed for the metrics, not for the forward pass.
    inputs = {k: v for k, v in batch.items() if k != "labels"}
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    if model.config.num_labels == 1:
        # Single-output head: keep the original 0.5 cut-off on the raw
        # output, but flatten with reshape(-1) instead of squeeze() so a
        # final batch of size 1 still yields a list rather than a bare
        # scalar (which would make predictions.extend() raise TypeError).
        # NOTE(review): if this head emits an unnormalised logit rather
        # than a 0..1 score, apply torch.sigmoid first - confirm against
        # how the model was trained.
        batch_predictions = (logits.reshape(-1) > 0.5).long().tolist()
    else:
        # Two or more output units (this includes the usual binary head
        # with num_labels == 2): logits are (batch, num_labels), so the
        # predicted class is the index of the largest logit. Thresholding
        # the raw logits here would produce one boolean per class instead
        # of one class id per sample and break the metric computation.
        batch_predictions = torch.argmax(logits, dim=1).tolist()

    predictions.extend(batch_predictions)
# Compute accuracy plus support-weighted precision and recall; classes with
# no predicted/true samples contribute 0 instead of raising a warning.
weighted_kwargs = {"average": "weighted", "zero_division": 0}
accuracy = accuracy_score(labels, predictions)
precision = precision_score(labels, predictions, **weighted_kwargs)
recall = recall_score(labels, predictions, **weighted_kwargs)

# Report the results with four decimal places.
for metric_name, metric_value in (
    ("Accuracy", accuracy),
    ("Precision", precision),
    ("Recall", recall),
):
    print(f"{metric_name}: {metric_value:.4f}")
# Draw the confusion matrix as an annotated heat map (rows: true labels,
# columns: predicted labels).
cm = confusion_matrix(labels, predictions)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False)
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.title("Confusion Matrix")
# savefig does not create missing parent directories; make sure "img/"
# exists so the run does not fail here after all inference has finished.
os.makedirs("img", exist_ok=True)
plt.savefig("img/confusion_matrix2.png")
plt.show()