-
Notifications
You must be signed in to change notification settings - Fork 2
/
run_train.py
executable file
·123 lines (102 loc) · 4.86 KB
/
run_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import print_function
from __future__ import division
import click
import json
import os
import numpy as np
from sklearn.utils import class_weight
from keras.optimizers import Adam
from keras.utils import to_categorical
from models.DRUNet32f import get_model
from run_test import get_eval_metrics
from tools.augmentation import augmentation
from metrics import weighted_categorical_crossentropy
# Expose only GPU 0 to CUDA-based libraries used by this process.
# NOTE(review): this assignment runs after the keras imports above; it presumably
# still takes effect because the backend initialises CUDA lazily - confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
def _save_history(history, output_path):
    """Write the evaluation history to ``output_path`` as JSON.

    Does nothing when ``output_path`` is the empty string (option not given).
    """
    if output_path == '':
        return
    # NOTE(review): json.dump needs plain Python numbers/lists; confirm that
    # get_eval_metrics does not return numpy arrays.
    with open(output_path, 'w') as outfile:
        json.dump(history, outfile)


@click.command()
@click.argument('train_imgs_np_file', type=click.STRING)
@click.argument('train_masks_np_file', type=click.STRING)
@click.argument('output_weights_file', type=click.STRING)
@click.option('--pretrained_model', type=click.STRING, default='', help='path to the pretrained model')
@click.option('--use_augmentation', type=click.BOOL, default=False, help='use data augmentation or not')
@click.option('--use_weighted_crossentropy', type=click.BOOL, default=False,
              help='use weighting of classes according to inbalance or not')
@click.option('--test_imgs_np_file', type=click.STRING, default='', help='path to the numpy file of test image')
@click.option('--test_masks_np_file', type=click.STRING, default='', help='path to the numpy file of the test masks')
@click.option('--output_test_eval', type=click.STRING, default='',
              help='path to save results on test case evaluated per epoch of training')
def main(train_imgs_np_file, train_masks_np_file, output_weights_file, pretrained_model='',
         use_augmentation=False, use_weighted_crossentropy=False,
         test_imgs_np_file='', test_masks_np_file='', output_test_eval=''):
    """Train the segmentation model and save its weights.

    Loads training images and masks from numpy files, optionally appends one
    augmented copy of every sample, and trains for 500 epochs (with
    augmentation) or 1000 epochs (without). When both test files are given,
    the model is evaluated on them every 100 epochs; the weights are saved at
    each of those checkpoints and again after the final epoch. The collected
    metrics are written as JSON to the --output_test_eval path when given.
    """
    num_classes = 11
    batch_size = 32
    learn_rate = 2e-4
    eval_interval = 100  # evaluate / checkpoint every this many epochs
    total_epochs = 500 if use_augmentation else 1000

    # Explicit checks instead of `assert`, which is removed under `python -O`.
    has_test_imgs = test_imgs_np_file != ''
    has_test_masks = test_masks_np_file != ''
    if has_test_imgs != has_test_masks:
        raise click.UsageError('Both test image file and test mask file must be given')
    if pretrained_model != '' and not os.path.isfile(pretrained_model):
        raise click.BadParameter('file does not exist: ' + pretrained_model,
                                 param_hint='--pretrained_model')

    eval_on_test = has_test_imgs and has_test_masks
    if eval_on_test:
        test_imgs = np.load(test_imgs_np_file)
        test_masks = np.load(test_masks_np_file)

    train_imgs = np.load(train_imgs_np_file)
    train_masks = np.load(train_masks_np_file)

    if use_weighted_crossentropy:
        # `classes` and `y` are keyword-only in scikit-learn >= 1.0; passing
        # them by keyword also works on older releases.
        # NOTE(review): weights are computed only for labels present in
        # train_masks (before augmentation); confirm all 11 classes occur.
        class_weights = class_weight.compute_class_weight(
            'balanced', classes=np.unique(train_masks), y=train_masks.flatten())

    # Input shape is taken from the training array: axes 1, 2 and the last axis.
    channels_num = train_imgs.shape[-1]
    img_shape = (train_imgs.shape[1], train_imgs.shape[2], channels_num)
    model = get_model(img_shape=img_shape, num_classes=num_classes)
    if pretrained_model != '':
        model.load_weights(pretrained_model)

    if use_augmentation:
        # Build one augmented counterpart per sample and append it, doubling
        # the training set.
        images_aug = np.zeros(train_imgs.shape, dtype=np.float32)
        masks_aug = np.zeros(train_masks.shape, dtype=np.float32)
        for idx, (img, mask) in enumerate(zip(train_imgs, train_masks)):
            images_aug[idx], masks_aug[idx] = augmentation(img, mask)
        train_imgs = np.concatenate((train_imgs, images_aug), axis=0)
        train_masks = np.concatenate((train_masks, masks_aug), axis=0)

    train_masks_cat = to_categorical(train_masks, num_classes)

    if use_weighted_crossentropy:
        loss = weighted_categorical_crossentropy(class_weights)
    else:
        loss = 'categorical_crossentropy'
    model.compile(optimizer=Adam(lr=learn_rate), loss=loss)

    # Metric names follow the values returned by get_eval_metrics.
    history = {'dsc': [], 'h95': [], 'vs': []}

    for current_epoch in range(1, total_epochs + 1):
        print('Epoch', str(current_epoch), '/', str(total_epochs))
        # One epoch per fit() call so evaluation can be interleaved.
        model.fit(train_imgs, train_masks_cat, batch_size=batch_size, epochs=1, verbose=True, shuffle=True)

        if eval_on_test and current_epoch % eval_interval == 0:
            model.save_weights(output_weights_file)
            # Collapse the per-class scores on axis 3 into a label map.
            pred_masks = model.predict(test_imgs).argmax(axis=3)
            dsc, h95, vs = get_eval_metrics(test_masks[:, :, :, 0], pred_masks)
            history['dsc'].append(dsc)
            history['h95'].append(h95)
            history['vs'].append(vs)
            print(dsc)
            print(h95)
            print(vs)
            # The history only changes here, so persist it here.
            _save_history(history, output_test_eval)

    model.save_weights(output_weights_file)
    # Final write so the file exists even if no evaluation ever ran.
    _save_history(history, output_test_eval)
# Script entry point: `main` is a click command, so calling it without
# arguments lets click parse the command line.
if __name__ == "__main__":
    main()