|
| 1 | +# -*-coding:utf-8-*- |
| 2 | +# import os |
| 3 | +import os.path |
| 4 | +import random |
| 5 | +import sys |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +import matplotlib.pyplot as plt |
| 9 | + |
| 10 | + |
def _roc_points(predict, groundTruth):
    """Compute the ROC polyline, its AUC and the Youden-optimal threshold.

    Samples are sorted by ascending score. Point ``i`` of the curve is the
    operating point obtained by calling every sample at sorted position
    ``>= i`` positive, i.e. the decision rule ``score >= sorted_score[i]``.

    Args:
        predict: sequence (or array) of scores, one per sample.
        groundTruth: sequence of labels, 1 for positive and 0 for negative.

    Returns:
        Tuple ``(x, y, auc, best_thresh)`` where ``x``/``y`` are arrays of
        length ``m + 1`` holding the false/true positive rates (running from
        (1, 1) down to (0, 0)), ``auc`` is the trapezoidal area under the
        curve and ``best_thresh`` is the score maximising ``TPR - FPR``.

    Raises:
        ValueError: if there are no samples, if scores and labels differ in
            length, or if either class is absent (the rates are undefined).
    """
    scores = np.asarray(predict).flatten()
    labels = np.asarray(groundTruth).flatten()

    m = labels.size
    if m == 0:
        raise ValueError('plot_roc needs at least one sample')
    if scores.size != m:
        raise ValueError('predict and groundTruth must have the same length '
                         '(%d != %d)' % (scores.size, m))

    pos_num = np.sum(labels == 1)
    neg_num = np.sum(labels == 0)
    if pos_num == 0 or neg_num == 0:
        raise ValueError('ROC needs at least one positive (1) and one '
                         'negative (0) label')

    order = scores.argsort()
    sorted_scores = scores[order]
    labels = labels[order]

    # Suffix counts: element i is the number of positives / negatives among
    # labels[i:]. One reversed cumulative sum replaces a fresh np.sum per
    # threshold.
    tp = np.cumsum((labels == 1)[::-1])[::-1].astype(float)
    fp = np.cumsum((labels == 0)[::-1])[::-1].astype(float)

    x = np.zeros(m + 1)
    y = np.zeros(m + 1)
    x[0] = 1
    y[0] = 1
    x[1:m] = fp[1:] / neg_num
    y[1:m] = tp[1:] / pos_num
    # x[m] and y[m] stay 0: the threshold above every score.

    # Trapezoidal rule over consecutive points; x decreases along the curve.
    auc = np.sum((y[1:] + y[:-1]) * (x[:-1] - x[1:]) / 2)

    # Youden's J = sensitivity + specificity - 1, for the m real thresholds.
    yoden = y[:m] + (1 - x[:m]) - 1
    best_thresh = sorted_scores[np.argmax(yoden)]
    return x, y, auc, best_thresh


def plot_roc(predict, groundTruth, name='face'):
    """Plot the ROC curve and return the threshold maximising Youden's J.

    Prints the chosen threshold, draws the curve with matplotlib (the AUC is
    shown in the title) and calls ``plt.show()``.

    Args:
        predict: sequence of scores, one per sample.
        groundTruth: sequence of labels, 1 for positive and 0 for negative.
        name: label inserted into the plot title.

    Returns:
        The score threshold with the highest ``TPR - FPR``; samples with
        ``score >= threshold`` are the ones treated as positive.

    Raises:
        ValueError: if the input is empty, mismatched in length, or contains
            only one class.
    """
    x, y, auc, best_thresh = _roc_points(predict, groundTruth)

    print('best thresh value =  %s' % best_thresh)

    plt.title("ROC curve of %s (AUC = %.4f)" % (name, auc))
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.plot(x, y)
    plt.show()
    return best_thresh
| 55 | + |
| 56 | + |
def _read_scores(path):
    """Read one float score per line from *path*.

    A line is accepted only if it holds exactly one whitespace-separated
    token that parses as a float; any other line is reported with
    'txt format is invalid' and skipped. If the file cannot be opened or
    read, a message is printed and the scores collected so far (possibly
    none) are returned; otherwise 'succeed' is printed.
    """
    scores = []
    try:
        # The context manager closes the handle even if reading fails.
        with open(path) as handle:
            for line in handle:
                fields = line.split()
                if len(fields) != 1:
                    print('txt format is invalid')
                    continue
                try:
                    scores.append(float(fields[0]))
                except ValueError:
                    # A lone non-numeric token is malformed input too.
                    print('txt format is invalid')
    except IOError:
        print('%s is not exits' % (path,))
    else:
        print('succeed')
    return scores


def GetData(txt_yes, txt_no):
    """Load matching and non-matching scores from two text files.

    Args:
        txt_yes: path of the file with scores of matching pairs (label 1).
        txt_no: path of the file with scores of non-matching pairs (label 0).

    Returns:
        Tuple ``(wholelist, groundTrth, matchList, dismatchList)``:
        all scores (matches first, then non-matches), the parallel list of
        0/1 labels, the matching scores only and the non-matching scores
        only. A missing or unreadable file contributes no scores.
    """
    matchList = _read_scores(txt_yes)
    dismatchList = _read_scores(txt_no)

    wholelist = matchList + dismatchList
    groundTrth = [1] * len(matchList) + [0] * len(dismatchList)
    return wholelist, groundTrth, matchList, dismatchList
| 95 | + |
| 96 | + |
def main(argv=None):
    """Evaluate match scores: plot the ROC curve and report the error rate.

    Args:
        argv: optional list of command-line arguments (defaults to
            ``sys.argv[1:]``). The first entry, if present, is the path of
            the matching-score file and the second the path of the
            non-matching-score file; the original hard-coded paths are used
            for whichever is missing.

    Prints the sizes of the loaded lists, then the error rate, the accuracy
    and the number of misclassified samples at the threshold returned by
    ``plot_roc``.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    txt_yes = args[0] if len(args) > 0 else r'D:\LWF_Yes.txt'
    txt_no = args[1] if len(args) > 1 else r'D:\LWF_No.txt'

    wholelist, groundtruth, matchList, dismatchList = GetData(txt_yes, txt_no)
    print(len(wholelist))
    print(len(groundtruth))
    print(len(matchList))
    print(len(dismatchList))

    if not wholelist:
        # Nothing to plot and the error rate would divide by zero.
        sys.exit('no scores loaded; nothing to evaluate')

    value_thresh = plot_roc(wholelist, groundtruth)

    # plot_roc treats score >= threshold as positive, so a matching pair
    # below the threshold and a non-matching pair at or above it are both
    # errors.
    false_neg = sum(1 for score in matchList if score < value_thresh)
    false_pos = sum(1 for score in dismatchList if score >= value_thresh)
    falseNum = false_neg + false_pos

    error_rate = float(falseNum) / len(wholelist)
    print(error_rate)
    print(1 - error_rate)
    print(falseNum)


if __name__ == '__main__':
    main()
| 119 | + |
| 120 | + |
0 commit comments