Skip to content

Commit 4fb9a76

Browse files
committed
original
0 parents  commit 4fb9a76

File tree

7 files changed

+831
-0
lines changed

7 files changed

+831
-0
lines changed

AccCalu.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# -*-coding:utf-8-*-
2+
# import os
3+
import os.path
4+
import random
5+
import sys
6+
7+
import numpy as np
8+
import matplotlib.pyplot as plt
9+
10+
11+
# 绘制roc曲线
12+
def plot_roc(predict, groundTruth):
13+
predictArr = np.array(predict)
14+
groundArr = np.array(groundTruth)
15+
16+
pos_num = np.sum(groundArr == 1)
17+
neg_num = np.sum(groundArr == 0)
18+
19+
m = len(groundTruth)
20+
21+
index = predictArr.flatten().argsort()
22+
sorted_predict = np.sort(predictArr.flatten());
23+
24+
groundArr = groundArr[index]
25+
x = np.zeros(m+1)
26+
y = np.zeros(m+1)
27+
yoden = np.zeros(m)
28+
auc = 0.0
29+
x[0] = 1
30+
y[0] = 1
31+
yoden[0] = 0
32+
33+
for i in range(1, m):
34+
TP = float(np.sum(groundArr[i:] == 1))
35+
FP = float(np.sum(groundArr[i:] == 0))
36+
x[i] = FP / neg_num
37+
y[i] = TP / pos_num
38+
auc += (y[i] + y[i-1]) * (x[i-1] - x[i]) / 2
39+
yoden[i] = y[i] + (1 - x[i]) - 1
40+
x[m] = 0
41+
y[m] = 0
42+
auc += y[m - 1] * x[m - 1] / 2
43+
44+
print 'best thresh value = ', sorted_predict[np.argmax(yoden)]
45+
# fp = float(np.sum(predictArr.flatten()[0:3000] >= sorted_predict[np.argmax(yoden)]))
46+
# fn = float(np.sum(predict[:] < sorted_predict[np.argmax(yoden)]))
47+
# print 'test acc = ', (fp + fn) / (len(predict) + len(groundTruth))
48+
49+
plt.title("ROC curve of %s (AUC = %.4f)" % ('face', auc))
50+
plt.xlabel("False Positive Rate")
51+
plt.ylabel("True Positive Rate")
52+
plt.plot(x, y) # use pylab to plot x and y
53+
plt.show() # show the plot on the screen
54+
return sorted_predict[np.argmax(yoden)]
55+
56+
57+
def GetData(txt_yes, txt_no):
58+
matchList = []
59+
dismatchList = []
60+
wholelist = []
61+
groundTrth = []
62+
try:
63+
fileObj = open(txt_yes)
64+
#dismatchlist = []
65+
for line in fileObj.readlines():
66+
curLine = line.strip().split()
67+
if len(curLine) == 1:
68+
score = float(curLine[0])
69+
matchList.append(score)
70+
wholelist.append(score)
71+
groundTrth.append(1)
72+
else:
73+
print 'txt format is invalid'
74+
except IOError:
75+
print txt_yes, 'is not exits'
76+
else:
77+
print 'succeed'
78+
79+
try:
80+
fileObj = open(txt_no)
81+
for line in fileObj.readlines():
82+
curLine = line.strip().split()
83+
if len(curLine) == 1:
84+
score = float(curLine[0])
85+
wholelist.append(score)
86+
dismatchList.append(score)
87+
groundTrth.append(0)
88+
else:
89+
print 'txt format is invalid'
90+
except IOError:
91+
print txt_no, 'is not exits'
92+
else:
93+
print 'succeed'
94+
return wholelist, groundTrth,matchList,dismatchList
95+
96+
97+
if __name__ == '__main__':
98+
99+
# the first param is rootdir
100+
# the second param is thresh value
101+
txt_yes = r'D:\LWF_Yes.txt'
102+
txt_no = r'D:\LWF_No.txt'
103+
falseNum = 0
104+
wholelist,groundtruth, matchList,dismatchList = GetData(txt_yes, txt_no)
105+
print len(wholelist)
106+
print len(groundtruth)
107+
print len(matchList)
108+
print len(dismatchList)
109+
value_thresh = plot_roc(wholelist, groundtruth)
110+
for i in matchList:
111+
if i < value_thresh:
112+
falseNum += 1
113+
for i in dismatchList:
114+
if i > value_thresh:
115+
falseNum += 1
116+
print float(falseNum)/len(wholelist)
117+
print 1 - float(falseNum)/len(wholelist)
118+
print falseNum
119+
120+

0 commit comments

Comments
 (0)