-
Notifications
You must be signed in to change notification settings - Fork 51
/
predict.py
103 lines (87 loc) · 3.67 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
import numpy as np
import json
import sys
from sklearn.metrics import roc_auc_score
from data_loader import ValTestDataLoader
from model import Net
# can be changed according to config.txt
exer_n = 17746
knowledge_n = 123
student_n = 4163
def test(epoch):
data_loader = ValTestDataLoader('test')
net = Net(student_n, exer_n, knowledge_n)
device = torch.device('cpu')
print('testing model...')
data_loader.reset()
load_snapshot(net, 'model/model_epoch' + str(epoch))
net = net.to(device)
net.eval()
correct_count, exer_count = 0, 0
pred_all, label_all = [], []
while not data_loader.is_end():
input_stu_ids, input_exer_ids, input_knowledge_embs, labels = data_loader.next_batch()
input_stu_ids, input_exer_ids, input_knowledge_embs, labels = input_stu_ids.to(device), input_exer_ids.to(
device), input_knowledge_embs.to(device), labels.to(device)
out_put = net(input_stu_ids, input_exer_ids, input_knowledge_embs)
out_put = out_put.view(-1)
# compute accuracy
for i in range(len(labels)):
if (labels[i] == 1 and out_put[i] > 0.5) or (labels[i] == 0 and out_put[i] < 0.5):
correct_count += 1
exer_count += len(labels)
pred_all += out_put.tolist()
label_all += labels.tolist()
pred_all = np.array(pred_all)
label_all = np.array(label_all)
# compute accuracy
accuracy = correct_count / exer_count
# compute RMSE
rmse = np.sqrt(np.mean((label_all - pred_all) ** 2))
# compute AUC
auc = roc_auc_score(label_all, pred_all)
print('epoch= %d, accuracy= %f, rmse= %f, auc= %f' % (epoch, accuracy, rmse, auc))
with open('result/model_test.txt', 'a', encoding='utf8') as f:
f.write('epoch= %d, accuracy= %f, rmse= %f, auc= %f\n' % (epoch, accuracy, rmse, auc))
def load_snapshot(model, filename):
f = open(filename, 'rb')
model.load_state_dict(torch.load(f, map_location=lambda s, loc: s))
f.close()
def get_status():
'''
An example of getting student's knowledge status
:return:
'''
net = Net()
load_snapshot(net, 'model/model_epoch12') # load model
net.eval()
with open('result/student_stat.txt', 'w', encoding='utf8') as output_file:
for stu_id in range(student_n):
# get knowledge status of student with stu_id (index)
status = net.get_knowledge_status(torch.LongTensor([stu_id])).tolist()[0]
output_file.write(str(status) + '\n')
def get_exer_params():
'''
An example of getting exercise's parameters (knowledge difficulty and exercise discrimination)
:return:
'''
net = Net()
load_snapshot(net, 'model/model_epoch12') # load model
net.eval()
exer_params_dict = {}
for exer_id in range(exer_n):
# get knowledge difficulty and exercise discrimination of exercise with exer_id (index)
k_difficulty, e_discrimination = net.get_exer_params(torch.LongTensor([exer_id]))
exer_params_dict[exer_id + 1] = (k_difficulty.tolist()[0], e_difficulty.tolist()[0])
with open('result/exer_params.txt', 'w', encoding='utf8') as o_f:
o_f.write(str(exer_params_dict))
if __name__ == '__main__':
if (len(sys.argv) != 2) or (not sys.argv[1].isdigit()):
print('command:\n\tpython predict.py {epoch}\nexample:\n\tpython predict.py 70')
exit(1)
# global student_n, exer_n, knowledge_n
with open('config.txt') as i_f:
i_f.readline()
student_n, exer_n, knowledge_n = list(map(eval, i_f.readline().split(',')))
test(int(sys.argv[1]))