comp_f1_diag_proc.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue May 12 09:45:01 2020
@author: antonio
"""
import pandas as pd
import argparse
import warnings


###### Auxiliary functions ######
def read_gs(gs_path):
    gs_data = pd.read_csv(gs_path, sep="\t", names=['clinical_case', 'code'],
                          dtype={'clinical_case': object, 'code': object})
    gs_data.code = gs_data.code.str.lower()
    return gs_data

def read_run(pred_path, valid_codes):
    run_data = pd.read_csv(pred_path, sep="\t", names=['clinical_case', 'code'],
                           dtype={'clinical_case': object, 'code': object})
    run_data.code = run_data.code.str.lower()
    run_data = run_data[run_data['code'].isin(valid_codes)]
    if run_data.shape[0] == 0:
        warnings.warn('None of the predicted codes are considered valid codes')
    return run_data
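
# Both input files are read as headerless, two-column TSVs of the form
# "<clinical_case_id>\t<code>"; codes are lower-cased before comparison.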

def calculate_metrics(df_gs, df_pred):
    # Predicted Positives (unique clinical_case/code pairs in the predictions):
    Pred_Pos_per_cc = df_pred.drop_duplicates(
        subset=['clinical_case', "code"]).groupby("clinical_case")["code"].count()
    Pred_Pos = df_pred.drop_duplicates(subset=['clinical_case', "code"]).shape[0]

    # Gold Standard Positives:
    GS_Pos_per_cc = df_gs.drop_duplicates(
        subset=['clinical_case', "code"]).groupby("clinical_case")["code"].count()
    GS_Pos = df_gs.drop_duplicates(subset=['clinical_case', "code"]).shape[0]

    # True Positives: predicted codes that also appear in the gold standard
    # of the same clinical case.
    cc = set(df_gs.clinical_case.tolist())
    TP_per_cc = pd.Series(dtype=float)
    for c in cc:
        pred = set(df_pred.loc[df_pred['clinical_case'] == c, 'code'].values)
        gs = set(df_gs.loc[df_gs['clinical_case'] == c, 'code'].values)
        TP_per_cc[c] = len(pred.intersection(gs))
    TP = sum(TP_per_cc.values)

    # Calculate final metrics:
    P_per_cc = TP_per_cc / Pred_Pos_per_cc
    P = TP / Pred_Pos
    R_per_cc = TP_per_cc / GS_Pos_per_cc
    R = TP / GS_Pos
    F1_per_cc = (2 * P_per_cc * R_per_cc) / (P_per_cc + R_per_cc)
    if (P + R) == 0:
        F1 = 0
        warnings.warn('Global F1 score automatically set to zero to avoid division by zero')
        return P_per_cc, P, R_per_cc, R, F1_per_cc, F1
    F1 = (2 * P * R) / (P + R)

    return P_per_cc, P, R_per_cc, R, F1_per_cc, F1
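
# Minimal worked example of calculate_metrics on hypothetical data (a sketch,
# kept as a comment so the module's behavior is unchanged):
#
#   >>> gs = pd.DataFrame({'clinical_case': ['doc1', 'doc1', 'doc2'],
#   ...                    'code': ['a00', 'b01', 'a00']})
#   >>> pred = pd.DataFrame({'clinical_case': ['doc1', 'doc2', 'doc2'],
#   ...                      'code': ['a00', 'a00', 'c02']})
#   >>> _, P, _, R, _, F1 = calculate_metrics(gs, pred)
#   >>> round(P, 3), round(R, 3), round(F1, 3)
#   (0.667, 0.667, 0.667)
#
# Here TP = 2 (doc1: a00; doc2: a00), Pred_Pos = 3 and GS_Pos = 3, so
# P = R = 2/3 and F1 = 2*P*R / (P+R) = 2/3.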

def parse_arguments():
    '''
    DESCRIPTION: Parse command line arguments
    '''
    parser = argparse.ArgumentParser(description='process user given parameters')
    parser.add_argument("-g", "--gs_path", required=True, dest="gs_path",
                        help="path to GS file")
    parser.add_argument("-p", "--pred_path", required=True, dest="pred_path",
                        help="path to predictions file")
    parser.add_argument("-c", "--valid_codes_path", required=True,
                        dest="codes_path", help="path to valid codes TSV")
    args = parser.parse_args()

    gs_path = args.gs_path
    pred_path = args.pred_path
    codes_path = args.codes_path

    return gs_path, pred_path, codes_path
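
# Example invocation (file names are placeholders for illustration only):
#   python comp_f1_diag_proc.py -g gold_standard.tsv -p predictions.tsv -c valid_codes.tsv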


if __name__ == '__main__':
    gs_path, pred_path, codes_path = parse_arguments()

    ###### 0. Load valid codes lists: ######
    valid_codes = set(pd.read_csv(codes_path, sep='\t', header=None,
                                  usecols=[0])[0].tolist())
    valid_codes = set([x.lower() for x in valid_codes])

    ###### 1. Load GS and Predictions ######
    df_gs = read_gs(gs_path)
    df_run = read_run(pred_path, valid_codes)

    ###### 2. Calculate score ######
    P_per_cc, P, R_per_cc, R, F1_per_cc, F1 = calculate_metrics(df_gs, df_run)

    ###### 3. Show results ######
    print('\n-----------------------------------------------------')
    print('Clinical case name\t\t\tPrecision')
    print('-----------------------------------------------------')
    for index, val in P_per_cc.items():
        print(str(index) + '\t\t' + str(round(val, 3)))
    print('-----------------------------------------------------')
    if any(P_per_cc.isna()):
        warnings.warn('Some documents do not have predicted codes, ' +
                      'document-wise Precision not computed for them.')
    print('\nMicro-average precision = {}\n'.format(round(P, 3)))

    print('\n-----------------------------------------------------')
    print('Clinical case name\t\t\tRecall')
    print('-----------------------------------------------------')
    for index, val in R_per_cc.items():
        print(str(index) + '\t\t' + str(round(val, 3)))
    print('-----------------------------------------------------')
    if any(R_per_cc.isna()):
        warnings.warn('Some documents do not have Gold Standard codes, ' +
                      'document-wise Recall not computed for them.')
    print('\nMicro-average recall = {}\n'.format(round(R, 3)))

    print('\n-----------------------------------------------------')
    print('Clinical case name\t\t\tF-score')
    print('-----------------------------------------------------')
    for index, val in F1_per_cc.items():
        print(str(index) + '\t\t' + str(round(val, 3)))
    print('-----------------------------------------------------')
    if any(P_per_cc.isna()):
        warnings.warn('Some documents do not have predicted codes, ' +
                      'document-wise F-score not computed for them.')
    if any(R_per_cc.isna()):
        warnings.warn('Some documents do not have Gold Standard codes, ' +
                      'document-wise F-score not computed for them.')
    print('\nMicro-average F-score = {}\n'.format(round(F1, 3)))

    print('\n__________________________________________________________')
    print('\nMICRO-AVERAGE STATISTICS:')
    print('\nMicro-average precision = {}'.format(round(P, 3)))
    print('\nMicro-average recall = {}'.format(round(R, 3)))
    print('\nMicro-average F-score = {}\n'.format(round(F1, 3)))
    print('\n{}|{}|{}'.format(round(P, 3), round(R, 3), round(F1, 3)))