-
Notifications
You must be signed in to change notification settings - Fork 0
/
store.py
89 lines (71 loc) · 2.65 KB
/
store.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import os
import argparse
import torch
import numpy as np
import matplotlib.pyplot as plt
from metrics import Metrics
# Keys under which each part is stored in the checkpoint dict.
KEY_WEIGHTS = 'weights'
KEY_OPTIMS = 'optims'
KEY_METRICS = 'metrics'


class Store:
    """Hold model weights, optimizer state and metrics, and (de)serialise them together.

    :meth:`save` writes one dict keyed by ``KEY_WEIGHTS``, ``KEY_OPTIMS`` and
    ``KEY_METRICS`` with ``torch.save``; :meth:`load` reads such a dict back.
    """

    def __init__(self):
        # Every slot is None until set_states() or load() fills it.
        self.weights = None
        self.optims = None
        self.metrics = None

    def load(self, path, map_location=None):
        """Populate the store from a checkpoint written by :meth:`save`.

        Args:
            path: File path or file-like object accepted by ``torch.load``.
            map_location: Forwarded to ``torch.load`` (e.g. ``'cpu'`` to load
                tensors that were saved on a GPU onto the CPU).

        A key missing from the checkpoint leaves the matching attribute ``None``.
        """
        # NOTE: torch.load unpickles its input; only load checkpoints you trust.
        data = torch.load(path, map_location=map_location)
        self.weights = data.get(KEY_WEIGHTS)
        self.optims = data.get(KEY_OPTIMS)
        self.metrics = data.get(KEY_METRICS)

    def set_states(self, weights, optims, metrics):
        """Replace all three stored parts at once.

        Raises:
            ValueError: if any argument is falsy (``None``, an empty container, ...).
                Nothing is assigned in that case.
        """
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        # All arguments are validated before any attribute is touched.
        for name, value in (('weights', weights), ('optims', optims), ('metrics', metrics)):
            if not value:
                raise ValueError(f"set_states: {name} must be non-empty, got {value!r}")
        self.weights = weights
        self.optims = optims
        self.metrics = metrics

    def save(self, path):
        """Write the three parts to *path* (path or file-like) as one dict via ``torch.save``."""
        torch.save({
            KEY_WEIGHTS: self.weights,
            KEY_OPTIMS: self.optims,
            KEY_METRICS: self.metrics,
        }, path)
# (metrics key, legend label, vertical annotation offset in points or None).
# A falsy offset draws the line without per-point value labels.
# Series that were plotted at some point but are currently disabled:
# ('jacs', 'IoU'), ('pdices', 'acc').
PLOT_SERIES = (
    ('losses', 'loss', 10),
    ('pjacs', 'IoU', -10),
    ('gsensis', 'gland sensitivity', None),
    ('gspecs', 'gland specificity', None),
    ('tsensis', 'tumor sensitivity', None),
    ('tspecs', 'tumor specificity', None),
)


def plot_line(ax, values, label, offset=None):
    """Plot *values* against their index and optionally label every point.

    Args:
        ax: Object exposing matplotlib's ``plot``/``annotate`` API
            (the ``matplotlib.pyplot`` module or an ``Axes``).
        values: Sequence of numbers to draw, one per x position.
        label: Legend label for the line.
        offset: Vertical distance, in points, between each point and its
            3-decimal text label. Falsy (``None`` or ``0``) skips labelling.
    """
    ax.plot(values, label=label)
    if not offset:
        return
    for i, value in enumerate(values):
        ax.annotate(
            f"{value:.3f}",
            (i, value),                  # data point being labelled
            textcoords="offset points",  # interpret xytext as an offset in points
            xytext=(0, offset),
            ha='center',
        )


def main(argv=None):
    """Load a checkpoint written by ``Store.save`` and plot its training metrics.

    Args:
        argv: Argument list for argparse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('path')
    args = parser.parse_args(argv)
    path = args.path

    store = Store()
    store.load(path, map_location='cpu')
    if store.metrics is None:
        # Store.load leaves metrics as None when the checkpoint lacks that entry.
        parser.error(f"{path} has no '{KEY_METRICS}' entry to plot")
    metrics = Metrics()
    metrics.load_state_dict(store.metrics)

    epoch = len(metrics.get('losses'))
    # Widen the figure with the number of epochs so point labels stay readable.
    plt.figure(figsize=(max(epoch // 1.5, 10), 10))
    plt.title(os.path.splitext(os.path.basename(path))[0])

    for key, label, offset in PLOT_SERIES:
        plot_line(plt, metrics.get(key), label, offset)
    print(metrics.last_coef())

    plt.yticks(np.arange(0, 11) / 10)
    plt.grid(True)
    # Place the legend outside the axes so it does not cover the curves.
    plt.legend(bbox_to_anchor=(1.01, 1), loc='upper left', borderaxespad=0)
    plt.show()


if __name__ == '__main__':
    main()