#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
import os
import time
import cv2
import numpy as np
import tensorflow as tf
import pydensecrf.densecrf as dcrf
import vgg
from dataset import inputs
from pydensecrf.utils import (create_pairwise_bilateral,
                              create_pairwise_gaussian, unary_from_softmax)
from utils import (bilinear_upsample_weights, grayscale_to_voc_impl)
import logging
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s', level=logging.DEBUG)
def parse_args(check=True):
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint_path', type=str)
    parser.add_argument('--output_dir', type=str)
    parser.add_argument('--dataset_train', type=str)
    parser.add_argument('--dataset_val', type=str)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--max_steps', type=int, default=1500)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    FLAGS, unparsed = parser.parse_known_args()
    return FLAGS, unparsed
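# Example invocation (all paths below are illustrative, not part of the repo):
#   python train.py --checkpoint_path /path/to/vgg_16.ckpt \
#       --output_dir ./output --dataset_train /path/to/train_data \
#       --dataset_val /path/to/val_data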
FLAGS, unparsed = parse_args()
slim = tf.contrib.slim
tf.reset_default_graph()
is_training_placeholder = tf.placeholder(tf.bool)
batch_size = FLAGS.batch_size
image_tensor_train, orig_img_tensor_train, annotation_tensor_train = inputs(FLAGS.dataset_train, train=True, batch_size=batch_size, num_epochs=1e4)
image_tensor_val, orig_img_tensor_val, annotation_tensor_val = inputs(FLAGS.dataset_val, train=False, num_epochs=1e4)
image_tensor, orig_img_tensor, annotation_tensor = tf.cond(
    is_training_placeholder,
    true_fn=lambda: (image_tensor_train, orig_img_tensor_train, annotation_tensor_train),
    false_fn=lambda: (image_tensor_val, orig_img_tensor_val, annotation_tensor_val))
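# Note: tf.cond only guards ops created *inside* true_fn/false_fn. The train
# and val input tensors above are created outside the branches, so both input
# pipelines are wired into the graph and both are evaluated (dequeued) on every
# sess.run regardless of the predicate; the cond merely selects which values
# are returned.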
feed_dict_to_use = {is_training_placeholder: True}
upsample_factor = 8
number_of_classes = 21
log_folder = os.path.join(FLAGS.output_dir, 'train')
vgg_checkpoint_path = FLAGS.checkpoint_path
# Creates a variable to hold the global_step.
global_step = tf.Variable(0, trainable=False, name='global_step', dtype=tf.int64)
# Define the model that we want to use -- the last layer is resized to predict
# number_of_classes (21) classes instead of VGG's original 1000
with slim.arg_scope(vgg.vgg_arg_scope()):
    logits, end_points = vgg.vgg_16(image_tensor,
                                    num_classes=number_of_classes,
                                    is_training=is_training_placeholder,
                                    spatial_squeeze=False,
                                    fc_conv_padding='SAME')
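# With spatial_squeeze=False and fc_conv_padding='SAME', slim's vgg_16 keeps
# its fully connected layers as convolutions and returns a spatial logits map
# at 1/32 of the input resolution (the FCN-32s head) instead of a single
# per-image vector.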
downsampled_logits_shape = tf.shape(logits)
img_shape = tf.shape(image_tensor)
# Calculate the output size of the upsampled tensor
# The shape should be batch_size x height x width x num_classes
upsampled_logits_shape = tf.stack([
    downsampled_logits_shape[0],
    img_shape[1],
    img_shape[2],
    downsampled_logits_shape[3]
])
# Grab the feature map produced by pool4 of vgg16
# (pool4 downsamples the input image to 1/16 of its original size)
pool4_feature = end_points['vgg_16/pool4']
# Apply a 1x1xnumber_of_classes (1x1x21) convolution to the pool4 output,
# with no activation and with the kernel weights initialized to zero
with tf.variable_scope('vgg_16/fc8'):
    pool4_logits_16s = slim.conv2d(pool4_feature, number_of_classes, [1, 1],
                                   activation_fn=None,
                                   weights_initializer=tf.zeros_initializer,
                                   scope='conv_pool4')
# Perform the upsampling (x2)
# Upsample the final logits from 1/32 of the input size to 1/16
# so they can be fused with the pool4 logits.
# The transposed convolution uses stride 2 in height and width,
# i.e. both dimensions are doubled (the factor is given by upsample_factor).
upsample_factor = 2
upsample_filter_np_x2 = bilinear_upsample_weights(upsample_factor, number_of_classes)
upsample_filter_tensor_x2 = tf.Variable(upsample_filter_np_x2, name='vgg_16/fc8/t_conv_x2')
upsampled_logits_16s = tf.nn.conv2d_transpose(logits, upsample_filter_tensor_x2,
                                              output_shape=tf.shape(pool4_logits_16s),
                                              strides=[1, upsample_factor, upsample_factor, 1],
                                              padding='SAME')
# Combine pool4_logits_16s with upsampled_logits_16s
upsampled_logits_16s = upsampled_logits_16s + pool4_logits_16s
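# What bilinear_upsample_weights (imported from utils) is assumed to compute,
# based on how it is used here: the standard FCN initialization for a
# transposed-convolution kernel that performs bilinear interpolation. The
# function below is a minimal reference sketch for illustration only; it is
# never called, and the actual implementation lives in utils.
def _bilinear_upsample_weights_sketch(factor, num_classes):
    """Return a [size, size, num_classes, num_classes] bilinear kernel (sketch)."""
    filter_size = 2 * factor - factor % 2
    center = factor - 1 if filter_size % 2 == 1 else factor - 0.5
    og = np.ogrid[:filter_size, :filter_size]
    kernel = ((1 - abs(og[0] - center) / factor) *
              (1 - abs(og[1] - center) / factor))
    weights = np.zeros((filter_size, filter_size, num_classes, num_classes),
                       dtype=np.float32)
    for i in range(num_classes):
        weights[:, :, i, i] = kernel  # each class channel is upsampled independently
    return weights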
# Grab the feature map produced by pool3 of vgg16
# (pool3 downsamples the input image to 1/8 of its original size)
pool3_feature = end_points['vgg_16/pool3']
# Apply a 1x1xnumber_of_classes (1x1x21) convolution to the pool3 output,
# with no activation and with the kernel weights initialized to zero
with tf.variable_scope('vgg_16/fc8'):
    pool3_logits_8s = slim.conv2d(pool3_feature, number_of_classes, [1, 1],
                                  activation_fn=None,
                                  weights_initializer=tf.zeros_initializer,
                                  scope='conv_pool3')
# Perform the upsampling (x2)
# Upsample upsampled_logits_16s from 1/16 of the input size to 1/8
# so it can be fused with the pool3 logits.
# The transposed convolution again uses stride 2 in height and width.
upsample_factor = 2
upsample_filter_np_x2 = bilinear_upsample_weights(upsample_factor, number_of_classes)
upsample_filter_tensor_x2 = tf.Variable(upsample_filter_np_x2, name='vgg_16/fc8/t_conv_x2_x2')
upsampled_logits_8s = tf.nn.conv2d_transpose(upsampled_logits_16s, upsample_filter_tensor_x2,
                                             output_shape=tf.shape(pool3_logits_8s),
                                             strides=[1, upsample_factor, upsample_factor, 1],
                                             padding='SAME')
# Combine pool3_logits with upsampled_logits
upsampled_logits = upsampled_logits_8s + pool3_logits_8s
# Perform the upsampling (x8)
# Upsample upsampled_logits from 1/8 of the input size back to full resolution.
# The transposed convolution uses stride 8 in height and width,
# i.e. both dimensions grow 8x (the factor is given by upsample_factor).
upsample_factor = 8
upsample_filter_np_x8 = bilinear_upsample_weights(upsample_factor, number_of_classes)
upsample_filter_tensor_x8 = tf.Variable(upsample_filter_np_x8, name='vgg_16/fc8/t_conv_x8')
upsampled_logits = tf.nn.conv2d_transpose(upsampled_logits, upsample_filter_tensor_x8,
                                          output_shape=upsampled_logits_shape,
                                          strides=[1, upsample_factor, upsample_factor, 1],
                                          padding='SAME')
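# Decoder recap (FCN-8s): the 1/32 logits are upsampled x2 and fused with the
# pool4 skip (1/16), upsampled x2 again and fused with the pool3 skip (1/8),
# then upsampled x8 back to full resolution. Because the 1x1 skip convolutions
# are zero-initialized, both skips start as no-ops and training begins from
# the plain FCN-32s prediction.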
# One-hot encode the labels
lbl_onehot = tf.one_hot(annotation_tensor, number_of_classes)
# Compute the cross entropy between the upsampled logits and the labels
cross_entropies = tf.nn.softmax_cross_entropy_with_logits(logits=upsampled_logits, labels=lbl_onehot)
# Reduce the cross entropies to a scalar loss value
cross_entropy_loss = tf.reduce_mean(tf.reduce_sum(cross_entropies, axis=-1))
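# Note: softmax_cross_entropy_with_logits already reduces over the class axis,
# so cross_entropies has shape [batch, height, width] and the reduce_sum above
# runs over image width. The loss is therefore the mean per-pixel cross entropy
# scaled by a constant (the image width); tf.reduce_mean(cross_entropies) would
# give the unscaled mean.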
# Tensor to get the final prediction for each pixel -- note that we
# don't need softmax in this case because we only need the final
# decision. If we also needed the respective probabilities we would
# have to apply softmax.
pred = tf.argmax(upsampled_logits, axis=3)
probabilities = tf.nn.softmax(upsampled_logits)
# Here we define an optimizer and put all the variables
# that it creates under the 'adam_vars' namespace,
# so that we can easily access them later.
# Those variables are used by the adam optimizer and are not
# related to the variables of the vgg model.
# We also retrieve gradient tensors for each of our variables
# so we can later visualize them in tensorboard.
# optimizer.compute_gradients followed by optimizer.apply_gradients
# is equivalent to running:
# train_step = tf.train.AdamOptimizer(learning_rate=0.0001).minimize(cross_entropy_loss)
with tf.variable_scope("adam_vars"):
    optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
    gradients = optimizer.compute_gradients(loss=cross_entropy_loss)
    for grad_var_pair in gradients:
        current_variable = grad_var_pair[1]
        current_gradient = grad_var_pair[0]
        # Replace some characters in the original variable name:
        # tensorboard doesn't accept the ':' symbol
        gradient_name_to_save = current_variable.name.replace(":", "_")
        # Collect a histogram of the gradients for each layer so we can
        # visualize them later in tensorboard
        tf.summary.histogram(gradient_name_to_save, current_gradient)
    train_step = optimizer.apply_gradients(grads_and_vars=gradients, global_step=global_step)
# Now we define a function that will load the weights from the VGG checkpoint
# into our variables when we call it. We exclude the weights of the last layer,
# which is responsible for class predictions, because we predict a different
# number of classes and can't reuse the old weights as an initialization.
vgg_except_fc8_weights = slim.get_variables_to_restore(exclude=['vgg_16/fc8', 'adam_vars'])
# Here we get the variables that belong to the last layer of the network.
# The number of classes that VGG was originally trained on (1000, ImageNet)
# differs from ours -- in our case it is 21 (PASCAL VOC).
vgg_fc8_weights = slim.get_variables_to_restore(include=['vgg_16/fc8'])
adam_optimizer_variables = slim.get_variables_to_restore(include=['adam_vars'])
# Add summary op for the loss -- to be able to see it in
# tensorboard.
tf.summary.scalar('cross_entropy_loss', cross_entropy_loss)
# Put all summary ops into one op. Produces a serialized string when
# you run it.
merged_summary_op = tf.summary.merge_all()
# Create the summary writer -- to write all the logs
# into a specified file. This file can be later read
# by tensorboard.
summary_string_writer = tf.summary.FileWriter(log_folder)
# Create the log folder if it doesn't exist yet
if not os.path.exists(log_folder):
    os.makedirs(log_folder)
checkpoint_path = tf.train.latest_checkpoint(log_folder)
continue_train = False
if checkpoint_path:
    tf.logging.info(
        'Ignoring --checkpoint_path because a checkpoint already exists in %s'
        % log_folder)
    variables_to_restore = slim.get_model_variables()
    continue_train = True
else:
    # Create an op that initializes the variable values from the VGG checkpoint.
    read_vgg_weights_except_fc8_func = slim.assign_from_checkpoint_fn(
        vgg_checkpoint_path,
        vgg_except_fc8_weights)
    # Initializer for the new fc8 weights -- for the 21 classes.
    vgg_fc8_weights_initializer = tf.variables_initializer(vgg_fc8_weights)
    # Initializer for the adam variables
    optimization_variables_initializer = tf.variables_initializer(adam_optimizer_variables)
sess_config = tf.ConfigProto()
sess_config.gpu_options.allow_growth = True
sess = tf.Session(config=sess_config)
init_op = tf.global_variables_initializer()
init_local_op = tf.local_variables_initializer()
saver = tf.train.Saver(max_to_keep=5)
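# allow_growth makes TensorFlow claim GPU memory on demand instead of reserving
# it all upfront; max_to_keep=5 retains only the five most recent checkpoints.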
def perform_crf(image, probabilities):
    image = image.squeeze()
    softmax = probabilities.squeeze().transpose((2, 0, 1))
    # The input should be the negative of the logarithm of the probability values
    # Look up the definition of unary_from_softmax for more information
    unary = unary_from_softmax(softmax)
    # The inputs should be C-contiguous -- we are using a Cython wrapper
    unary = np.ascontiguousarray(unary)
    d = dcrf.DenseCRF(image.shape[0] * image.shape[1], number_of_classes)
    d.setUnaryEnergy(unary)
    # This potential penalizes small pieces of segmentation that are
    # spatially isolated -- enforces more spatially consistent segmentations
    feats = create_pairwise_gaussian(sdims=(10, 10), shape=image.shape[:2])
    d.addPairwiseEnergy(feats, compat=3,
                        kernel=dcrf.DIAG_KERNEL,
                        normalization=dcrf.NORMALIZE_SYMMETRIC)
    # This creates the color-dependent features --
    # the segmentations we get from the CNN are too coarse,
    # so we use local color features to refine them
    feats = create_pairwise_bilateral(sdims=(50, 50), schan=(20, 20, 20),
                                      img=image, chdim=2)
    d.addPairwiseEnergy(feats, compat=10,
                        kernel=dcrf.DIAG_KERNEL,
                        normalization=dcrf.NORMALIZE_SYMMETRIC)
    Q = d.inference(5)
    res = np.argmax(Q, axis=0).reshape((image.shape[0], image.shape[1]))
    return res
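# perform_crf follows the standard pydensecrf recipe: unary terms from the
# network's softmax, a Gaussian pairwise term for spatial smoothness, and a
# bilateral term for color-aware refinement. d.inference(5) runs five
# mean-field iterations and returns per-class marginals of shape
# [number_of_classes, height * width], from which we take the argmax label map.
# The pairwise parameters (sdims, schan, compat) are common example values and
# may need tuning for other data.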
with sess:
    # Run the initializers.
    sess.run(init_op)
    sess.run(init_local_op)
    if continue_train:
        saver.restore(sess, checkpoint_path)
        logging.debug('checkpoint restored from [{0}]'.format(checkpoint_path))
    else:
        sess.run(vgg_fc8_weights_initializer)
        sess.run(optimization_variables_initializer)
        read_vgg_weights_except_fc8_func(sess)
        logging.debug('values initialized...')
    # Start the data reader threads
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
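    # start_queue_runners launches the background threads that feed the input
    # queues built by dataset.inputs; without them, sess.run on the input
    # tensors would block indefinitely.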
    start = time.time()
    for i in range(FLAGS.max_steps):
        feed_dict_to_use[is_training_placeholder] = True
        gs, _ = sess.run([global_step, train_step], feed_dict=feed_dict_to_use)
        if gs % 10 == 0:
            gs, loss, summary_string = sess.run(
                [global_step, cross_entropy_loss, merged_summary_op],
                feed_dict=feed_dict_to_use)
            logging.debug("step {0} Current Loss: {1}".format(gs, loss))
            end = time.time()
            logging.debug("[{0:.2f}] imgs/s".format(10 * batch_size / (end - start)))
            start = end
            summary_string_writer.add_summary(summary_string, i)
        if gs % 100 == 0:
            save_path = saver.save(sess, os.path.join(log_folder, "model.ckpt"), global_step=gs)
            logging.debug("Model saved in file: %s" % save_path)
        if gs % 200 == 0:
            eval_folder = os.path.join(FLAGS.output_dir, 'eval')
            if not os.path.exists(eval_folder):
                os.makedirs(eval_folder)
            logging.debug("validation generated at step [{0}]".format(gs))
            feed_dict_to_use[is_training_placeholder] = False
            val_pred, val_orig_image, val_annot, val_poss = sess.run(
                [pred, orig_img_tensor, annotation_tensor, probabilities],
                feed_dict=feed_dict_to_use)
            cv2.imwrite(os.path.join(eval_folder, 'val_{0}_img.jpg'.format(gs)),
                        cv2.cvtColor(np.squeeze(val_orig_image), cv2.COLOR_RGB2BGR))
            cv2.imwrite(os.path.join(eval_folder, 'val_{0}_annotation.jpg'.format(gs)),
                        cv2.cvtColor(grayscale_to_voc_impl(np.squeeze(val_annot)), cv2.COLOR_RGB2BGR))
            cv2.imwrite(os.path.join(eval_folder, 'val_{0}_prediction.jpg'.format(gs)),
                        cv2.cvtColor(grayscale_to_voc_impl(np.squeeze(val_pred)), cv2.COLOR_RGB2BGR))
            crf_ed = perform_crf(val_orig_image, val_poss)
            cv2.imwrite(os.path.join(eval_folder, 'val_{0}_prediction_crfed.jpg'.format(gs)),
                        cv2.cvtColor(grayscale_to_voc_impl(np.squeeze(crf_ed)), cv2.COLOR_RGB2BGR))
            overlay = cv2.addWeighted(cv2.cvtColor(np.squeeze(val_orig_image), cv2.COLOR_RGB2BGR), 1,
                                      cv2.cvtColor(grayscale_to_voc_impl(np.squeeze(crf_ed)), cv2.COLOR_RGB2BGR), 0.8, 0)
            cv2.imwrite(os.path.join(eval_folder, 'val_{0}_overlay.jpg'.format(gs)), overlay)
    coord.request_stop()
    coord.join(threads)
    save_path = saver.save(sess, os.path.join(log_folder, "model.ckpt"), global_step=gs)
    logging.debug("Model saved in file: %s" % save_path)
summary_string_writer.close()