forked from CellProfiling/cyto-challenge
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsolution_checker.py
169 lines (141 loc) · 5.18 KB
/
solution_checker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""Calculate score for answers."""
import argparse
import csv
import sys
import collections
# Keys of the per-class statistics dicts built by calc_precision/calc_recall.
PREC, REC = 'prec', 'rec'  # ratios: tp / (tp + fp) and tp / (tp + fn)
TRUE_POS, FALSE_POS, FALSE_NEG = 'tp', 'fp', 'fn'  # raw counts
class ScoreError(Exception):
    """Raised when a submission cannot be scored.

    The message describes what went wrong, e.g. a missing file, an ID that
    is absent from the solution key, or an unknown class.
    """
def parseargs():
    """Read the command line.

    Returns:
        Tuple ``(submitted_answer, solution_key, include)``. ``include`` is
        the list of ``-i/--include`` values in the order given, or ``None``
        when the option was not used.
    """
    cli = argparse.ArgumentParser()
    for positional in ('submitted_answer', 'solution_key'):
        cli.add_argument(positional)
    include_help = ('Include the specified class. '
                    'Can be used multiple times')
    cli.add_argument('-i', '--include', action='append', help=include_help)
    opts = cli.parse_args()
    return opts.submitted_answer, opts.solution_key, opts.include
def read_key_file(f_path):
    """Read a CSV key file mapping IDs to class labels.

    Each non-blank row has the form ``ID, class1, class2, ...``. Surrounding
    whitespace is stripped from the ID and from every class. Blank rows
    (such as a trailing empty line) are skipped. If an ID occurs on more
    than one row, the last row wins.

    Args:
        f_path: path of the file to read (``str`` or path-like object).

    Returns:
        Dict mapping each ID to its list of classes (possibly empty).

    Raises:
        ScoreError: if the file does not exist or ``f_path`` is not
            something ``open()`` accepts as a filename.
    """
    try:
        # newline='' is what the csv module expects, so quoted fields and
        # \r\n line endings are handled by the reader rather than open().
        fil = open(f_path, 'r', newline='')
    except FileNotFoundError as err:
        print('Could not find the file', f_path, file=sys.stderr)
        raise ScoreError('Could not find the file {}'.format(f_path)) from err
    except TypeError as err:
        # open() raises TypeError for arguments that are not str, bytes,
        # path-like or an integer file descriptor.
        print('Not a valid filename', file=sys.stderr)
        raise ScoreError('Not a valid filename') from err

    ids = {}
    with fil:
        for line in csv.reader(fil):
            if not line:
                # csv.reader yields [] for a blank line; there is no ID.
                continue
            ids[line[0].strip()] = [x.strip() for x in line[1:]]
    return ids
def calc_precision(submitted, solution):
    """Compute per-class precision of the submitted labels.

    For every ID in ``submitted``, each predicted class is a true positive
    when the solution lists it for that ID and a false positive otherwise.

    Args:
        submitted: mapping of ID -> list of predicted classes.
        solution: mapping of ID -> list of true classes.

    Returns:
        Dict mapping class -> ``{TRUE_POS: int, FALSE_POS: int, PREC: float}``
        in order of first appearance in the submission.

    Raises:
        ScoreError: if a submitted ID is missing from ``solution``.
    """
    counts = {}
    for image_id, predicted in submitted.items():
        truth = solution.get(image_id)
        if truth is None:
            print(image_id, 'could not be found in solution key',
                  file=sys.stderr)
            raise ScoreError(
                '{} could not be found in solution key'.format(image_id))
        for label in predicted:
            if label not in counts:
                counts[label] = {TRUE_POS: 0, FALSE_POS: 0}
            counts[label][TRUE_POS if label in truth else FALSE_POS] += 1

    # Every tally has at least one count, so the denominator is never 0.
    for tally in counts.values():
        hits = tally[TRUE_POS]
        tally[PREC] = hits / (hits + tally[FALSE_POS])
    return counts
def calc_recall(submitted, solution):
    """Calculate per-class recall of the submitted labels.

    Every ID in ``solution`` contributes. Each of its true classes is a true
    positive if the submission lists it for that ID and a false negative
    otherwise. An ID with no submitted answer therefore counts all of its
    classes as false negatives, so leaving images out cannot raise recall.

    Args:
        submitted: mapping of ID -> list of predicted classes.
        solution: mapping of ID -> list of true classes.

    Returns:
        Dict mapping class -> ``{TRUE_POS: int, FALSE_NEG: int, REC: float}``
        for every class that occurs in ``solution``.

    Raises:
        ScoreError: if a submitted ID is missing from ``solution``.
    """
    # Same validation as calc_precision: every answered ID must be known.
    for sub in submitted:
        if solution.get(sub) is None:
            print(sub, 'could not be found in solution key', file=sys.stderr)
            raise ScoreError(
                '{} could not be found in solution key'.format(sub))

    recall = collections.defaultdict(lambda: {TRUE_POS: 0, FALSE_NEG: 0})
    for sol_id, sol_key in solution.items():
        # An unanswered ID predicts nothing, so all its classes are misses.
        sub_key = submitted.get(sol_id, ())
        for key in sol_key:
            if key in sub_key:
                recall[key][TRUE_POS] += 1
            else:
                recall[key][FALSE_NEG] += 1

    # Every entry has at least one count, so the denominator is never 0.
    for counts in recall.values():
        tp_recall = counts[TRUE_POS]
        counts[REC] = tp_recall / (tp_recall + counts[FALSE_NEG])
    return dict(recall)
def calc_f1_score(precision, recall):
    """Combine per-class precision and recall into per-class F1 scores.

    Args:
        precision: mapping of class -> stats dict containing ``PREC``.
        recall: mapping of class -> stats dict containing ``REC``.

    Returns:
        Dict mapping every class in ``precision`` to its F1 score. A class
        absent from ``recall`` is treated as having recall 0.
    """
    scores = {}
    for cls, stats in precision.items():
        p = stats[PREC]
        r = recall[cls][REC] if cls in recall else 0
        # When p + r is 0 the numerator is 0 too; divide by 1 to yield 0.
        denom = (p + r) or 1.0
        scores[cls] = 2 * ((p * r) / denom)
    return scores
def score(submitted_answer, solution_key, include_classes=None):
    """Score a submitted answer file against a solution key file.

    Per-class precision, recall and F1 are computed from the two files and
    then macro-averaged, so each selected class contributes equally. A class
    that was never predicted has no precision entry, and a class that never
    occurs in the solution has no recall entry. Such classes still take part
    in the average, with 0 for the metrics they lack.

    Args:
        submitted_answer: path of the submitted answer file.
        solution_key: path of the solution key file.
        include_classes: optional list of classes to average over (from
            ``-i/--include``). When empty or None, every class seen in the
            submission or the solution is used. Duplicates are ignored.

    Returns:
        Tuple ``(recall, precision, f1_score)`` of averaged scores. The
        same values are also printed to stdout.

    Raises:
        ScoreError: if a file cannot be read, a submitted ID is missing
            from the solution, a requested class is unknown, or there is
            no class to score.
    """
    submitted = read_key_file(submitted_answer)
    solution = read_key_file(solution_key)
    if len(submitted) != len(solution):
        print('Differring number of answers and solutions', file=sys.stderr)
        # file= belongs to print(); passed to str.format it was ignored.
        print('Num answers: {}, Num solutions: {}'.format(
            len(submitted), len(solution)), file=sys.stderr)
    precision = calc_precision(submitted, solution)
    recall = calc_recall(submitted, solution)
    f1_score = calc_f1_score(precision, recall)

    # Deterministic order: predicted classes first, then classes that only
    # appear in the solution (never predicted, so their F1 is 0).
    available = list(f1_score)
    available.extend(key for key in recall if key not in f1_score)
    if include_classes:
        keys = list(dict.fromkeys(include_classes))
    else:
        keys = available
    if not keys:
        print('No classes available to score', file=sys.stderr)
        raise ScoreError('No classes available to score')

    known = set(available)
    fin_f_score = 0.0
    fin_r_score = 0.0
    fin_p_score = 0.0
    for key in keys:
        if key not in known:
            print(key, 'not in the available keys', file=sys.stderr)
            raise ScoreError('{} not in the available keys'.format(key))
        fin_f_score += f1_score.get(key, 0.0)
        # A class can be absent from either dict; it then contributes 0.
        fin_r_score += recall[key][REC] if key in recall else 0.0
        fin_p_score += precision[key][PREC] if key in precision else 0.0
    # Average over the classes that were actually summed.
    num_keys = len(keys)
    fin_f_score /= num_keys
    fin_r_score /= num_keys
    fin_p_score /= num_keys
    print('Recall:', fin_r_score)
    print('Precision:', fin_p_score)
    print('F1 score:', fin_f_score)
    return fin_r_score, fin_p_score, fin_f_score
def main():
    """Run the checker on the files and options given on the command line."""
    # parseargs() yields exactly (submitted_answer, solution_key, include).
    score(*parseargs())
# Run the checker only when executed as a script, not when imported.
if __name__ == '__main__':
    main()