Commit 5b2e480

wip add unittests
1 parent 643d670 commit 5b2e480

3 files changed (+195, -13 lines)
evaluateInstanceSegmentation/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,2 +1,2 @@
-from .evaluate import evaluate_file
+from .evaluate import evaluate_file, evaluate_volume
 from .summarize import summarize_metric_dict
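With evaluate_volume re-exported at package level, callers can import it directly, as the new test module below does:

    from evaluateInstanceSegmentation import evaluate_file, evaluate_volume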

evaluateInstanceSegmentation/evaluate.py

Lines changed: 16 additions & 12 deletions
@@ -37,14 +37,19 @@ class Metrics:
     def __init__(self, fn):
         self.metricsDict = {}
         self.metricsArray = []
-        self.fn = fn
-        self.outFl = open(self.fn+".txt", 'w')
+        if fn is not None:
+            self.fn = fn
+            self.outFl = open(self.fn + '.txt', 'w')
+        else:
+            self.fn = None
+            self.outFl = None
 
     def save(self):
-        self.outFl.close()
-        logger.info("saving %s", self.fn)
-        tomlFl = open(self.fn+".toml", 'w')
-        toml.dump(self.metricsDict, tomlFl)
+        if self.outFl is not None:
+            self.outFl.close()
+            logger.info("saving %s", self.fn)
+            tomlFl = open(self.fn+".toml", 'w')
+            toml.dump(self.metricsDict, tomlFl)
 
     def addTable(self, name, dct=None):
         levels = name.split(".")
@@ -67,8 +72,9 @@ def getTable(self, name, dct=None):
         return self.getTable(name, dct=dct[levels[0]])
 
     def addMetric(self, table, name, value):
-        as_str = "{}: {}".format(name, value)
-        self.outFl.write(as_str+"\n")
+        if self.outFl is not None:
+            as_str = "{}: {}".format(name, value)
+            self.outFl.write(as_str+"\n")
         self.metricsArray.append(value)
         tbl = self.getTable(table)
         tbl[name] = value
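The None-guards above make the file writer optional. A minimal sketch of the intended in-memory use, assuming Metrics is importable from evaluateInstanceSegmentation.evaluate and that addTable("general") creates the corresponding nested dict entry:

    from evaluateInstanceSegmentation.evaluate import Metrics

    # with fn=None nothing is opened or written; values only accumulate in memory
    metrics = Metrics(None)
    metrics.addTable("general")
    metrics.addMetric("general", "Num GT", 2)
    print(metrics.metricsDict)  # the nested dict is still populated
    metrics.save()              # no-op: outFl is None, so no .toml is dumped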
@@ -642,7 +648,7 @@ def evaluate_volume(gt_labels, pred_labels, outFn,
         ap = precision * recall
         aps.append(ap)
         if (precision + recall) > 0:
-            fscore = (2. * precision * recall) / max(1, precision + recall)
+            fscore = (2. * precision * recall) / (precision + recall)
         else:
             fscore = 0.0
         fscores.append(fscore)
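The removed max(1, ...) clamp silently deflated the F-score whenever precision + recall < 1. A quick check with illustrative values:

    precision, recall = 0.3, 0.2
    old = (2. * precision * recall) / max(1, precision + recall)  # 0.12 (denominator clamped to 1)
    new = (2. * precision * recall) / (precision + recall)        # 0.24, the correct F1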
@@ -696,7 +702,7 @@ def evaluate_volume(gt_labels, pred_labels, outFn,
         for gt_i in np.arange(1, num_gt_labels + 1):
             if gt_i in max_gt_ind_unique:
                 pred_union = np.zeros(
-                    pred_labels_rel.shape[1:],
+                    pred_labels_rel.shape[1:],
                     dtype=pred_labels_rel.dtype)
                 for pred_i in np.arange(num_pred_labels + 1)[max_gt_ind == gt_i]:
                     mask = pred_labels_rel[pred_i - 1] > 0
@@ -708,13 +714,11 @@ def evaluate_volume(gt_labels, pred_labels, outFn,
             else:
                 # if gt has overlapping instances, but not prediction
                 if len(gt_labels_rel.shape) > len(pred_labels_rel.shape):
-                    print("gt cov for overlapping gt, but not pred")
                     for i in range(1, recallMat.shape[0]):
                         gt_cov.append(np.sum(recallMat[i, max_gt_ind==i]))
                 # if none has overlapping instances
                 else:
                     gt_cov = np.sum(recallMat[1:, 1:], axis=1)
-                    print("gt cov: ", gt_cov)
             gt_skel_coverage = np.mean(gt_cov)
             metrics.addMetric(tblNameGen, "gt_skel_coverage", gt_cov)
             metrics.addMetric(tblNameGen, "avg_gt_skel_coverage", gt_skel_coverage)

tests/test_metrics.py

Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
+import unittest
+import os
+import numpy as np
+from skimage.morphology import skeletonize_3d
+
+from evaluateInstanceSegmentation import evaluate_file, evaluate_volume
+
+# work in progress
+class TestMetrics(unittest.TestCase):
+    # how to name stuff, convex and tubular instead of nuclei and neuron?
+    def set_expected(self, num_gt, num_pred, avg_gt_cov, avg_tp_cov, avg_f1_cov, gt_cov, tp_cov,
+                     avAP59, avAP19, avFscore59, avFscore19,
+                     tp_0_5, fp_0_5, fn_0_5, fs_0_5, fm_0_5):
+        # define metric dictionary structure for expected values
+        expected = {
+            "general": {},
+            "confusion_matrix": {"th_0_5": {}}
+        }
+        # set values
+        expected["general"]["Num GT"] = num_gt
+        expected["general"]["Num Pred"] = num_pred
+        expected["general"]["avg_gt_skel_coverage"] = avg_gt_cov
+        expected["general"]["avg_tp_skel_coverage"] = avg_tp_cov
+        expected["general"]["avg_f1_cov_score"] = avg_f1_cov
+        expected["general"]["gt_skel_coverage"] = gt_cov
+        expected["general"]["tp_skel_coverage"] = tp_cov
+        expected["confusion_matrix"]["avAP59"] = avAP59
+        expected["confusion_matrix"]["avAP19"] = avAP19
+        expected["confusion_matrix"]["avFscore59"] = avFscore59
+        expected["confusion_matrix"]["avFscore19"] = avFscore19
+        expected["confusion_matrix"]["th_0_5"]["AP_TP"] = tp_0_5
+        expected["confusion_matrix"]["th_0_5"]["AP_FP"] = fp_0_5
+        expected["confusion_matrix"]["th_0_5"]["AP_FN"] = fn_0_5
+        expected["confusion_matrix"]["th_0_5"]["false_split"] = fs_0_5
+        expected["confusion_matrix"]["th_0_5"]["false_merge"] = fm_0_5
+
+        return expected
+
+
+    def check_results(self, results, expected):
+        # check general
+        res = results["general"]
+        exp = expected["general"]
+        self.assertEqual(res["Num GT"], exp["Num GT"])
+        self.assertEqual(res["Num Pred"], exp["Num Pred"])
+        self.assertEqual(round(res["avg_gt_skel_coverage"], 4),
+                         round(exp["avg_gt_skel_coverage"], 4))
+        self.assertEqual(round(res["avg_tp_skel_coverage"], 4),
+                         round(exp["avg_tp_skel_coverage"], 4))
+        self.assertEqual(res["avg_f1_cov_score"], exp["avg_f1_cov_score"])
+        self.assertListEqual(list(np.round(res["gt_skel_coverage"], 4)),
+                             list(np.round(exp["gt_skel_coverage"], 4)))
+        self.assertListEqual(list(res["tp_skel_coverage"]), list(exp["tp_skel_coverage"]))
+        # check confusion table
+        res = results["confusion_matrix"]
+        exp = expected["confusion_matrix"]
+        self.assertEqual(res["avAP59"], exp["avAP59"])
+        self.assertEqual(res["avAP19"], exp["avAP19"])
+        self.assertEqual(res["avFscore59"], exp["avFscore59"])
+        self.assertEqual(res["avFscore19"], exp["avFscore19"])
+        # check error quantities for confusion table at threshold 0.5
+        res = results["confusion_matrix"]["th_0_5"]
+        exp = expected["confusion_matrix"]["th_0_5"]
+        self.assertEqual(res["AP_TP"], exp["AP_TP"])
+        self.assertEqual(res["AP_FP"], exp["AP_FP"])
+        self.assertEqual(res["AP_FN"], exp["AP_FN"])
+        self.assertEqual(res["false_split"], exp["false_split"])
+        self.assertEqual(res["false_merge"], exp["false_merge"])
+
+
+    def run_test_case(self, config, gt, pred, expected):
+
+        result_dict = evaluate_volume(gt, pred, config["outFn"],
+            config["localization_criterion"], config["assignment_strategy"],
+            config["evaluate_false_labels"], config["unique_false_labels"],
+            config["add_general_metrics"],
+            config["visualize"], config["visualize_type"],
+            config["overlapping_inst"], config["partly"])
+        print(result_dict)
+        self.check_results(result_dict, expected)
+
+    def test_2d_nuclei(self):
+        print("todo: test 2d nuclei")
+
+
+    def test_3d_nuclei(self):
+        print("todo: test 3d nuclei")
+
+
+    def test_3d_neuron(self):
+        gt = np.zeros((2, 30, 30, 30), dtype=np.int32)
+        gt[0, 14:17, 14:17, 5:25] = 1
+        gt[1, 14:17, 5:25, 14:17] = 2
+
+        print(np.sum(gt==1), np.sum(gt==2), np.sum(np.sum(gt>0, axis=0) > 1))
+
+        # set parameters
+        config = {
+            "outFn": None,
+            "localization_criterion": "cldice",
+            "assignment_strategy": "greedy",
+            "add_general_metrics": ["avg_gt_skel_coverage",
+                "avg_f1_cov_score", "avg_tp_skel_coverage"],
+            "evaluate_false_labels": True,
+            "unique_false_labels": False,
+            "visualize": False,
+            "visualize_type": None,
+            "overlapping_inst": True,
+            "partly": False
+        }
+
+        # test case 1: perfect segmentation
+        # (1.1) pred + gt overlaps
+        pred = gt.copy()
+        self.run_test_case(config, gt, pred,
+            self.set_expected(2, 2, 1.0, 1.0, 1.0, [1.0, 1.0], [1.0, 1.0],
+                1.0, 1.0, 1.0, 1.0,
+                2, 0, 0, 0, 0)
+        )
+
+        # (1.2) gt overlaps
+        pred = np.max(pred, axis=0)
+        config["overlapping_inst"] = False
+        # set expected values
+        gt_cov = np.array([15/18.0, 1.0], dtype=np.float32)
+        avg_gt_cov = np.mean(gt_cov)
+        avg_f1_cov = np.mean([avg_gt_cov, 1.0])
+        self.run_test_case(config, gt, pred,
+            self.set_expected(2, 2, avg_gt_cov, avg_gt_cov, avg_f1_cov,
+                gt_cov, gt_cov,
+                0.925, 1.0, 0.95, 1.0,
+                2, 0, 0, 0, 0)
+        )
+
+        # (1.3) no overlaps
+        gt = np.max(gt, axis=0)
+        self.run_test_case(config, gt, pred,
+            self.set_expected(2, 2, 1.0, 1.0, 1.0, [1.0, 1.0], [1.0, 1.0],
+                1.0, 1.0, 1.0, 1.0,
+                2, 0, 0, 0, 0)
+        )
+
+        # test case 2: erroneous segmentation
+        gt = np.zeros((3, 30, 30, 30), dtype=np.int32)
+        gt[0, 14:17, 14:17, 5:25] = 1
+        gt[1, 14:17, 5:25, 14:17] = 2
+        gt[2, 5:8, 5:25, 20:23] = 3
+
+        pred = np.zeros((5, 30, 30, 30), dtype=np.int32)
+        pred[0, 14:17, 14:17, 5:25] = 1
+        pred[0, 14:17, 5:20, 14:17] = 1
+        pred[1, 25:30, 25:30, 25:30] = 2
+        pred[2, 5:8, 5:11, 20:23] = 3
+        pred[3, 5:8, 19:22, 20:23] = 4
+        pred[4, 1:5, 1:5, 1:5] = 5
+
+        # (2.1) pred + gt overlaps
+        config["overlapping_inst"] = False
+        gt_cov = np.array([1.0, 14/18.0, 8/18.0], dtype=np.float32)
+        avg_gt_cov = np.mean(gt_cov)
+        tp_cov = np.array([1.0, 14/18.0], dtype=np.float32)
+        avg_tp_cov = np.mean(tp_cov)
+        ap19 = np.mean([4/15.0,] * 4 + [1/15.0,] * 3 + [0.0, 0.0])
+        ap59 = np.mean([1/15.0,] * 6 + [0.0,] * 4)
+        fscore19 = np.mean([0.5,] * 4 + [0.25,] * 3 + [0.0, 0.0])
+        fscore59 = np.mean([0.25,] * 6 + [0.0,] * 4)
+        avg_f1_cov = np.mean([avg_gt_cov, fscore19])
+        self.run_test_case(config, gt, pred,
+            self.set_expected(3, 5, avg_gt_cov, avg_tp_cov, avg_f1_cov,
+                gt_cov, tp_cov,
+                ap59, ap19, fscore59, fscore19,
+                1, 4, 2, 2, 1)
+        )
+
+
+if __name__ == '__main__':
+    unittest.main()
+
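Assuming the package is importable from the repository root, the new tests can be run with the standard unittest runner:

    python -m unittest tests.test_metrics -v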
