-
Notifications
You must be signed in to change notification settings - Fork 2
/
base.py
487 lines (401 loc) · 25 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
import argparse
import tensorflow as tf
import os
import time
from glob import glob
from collections import namedtuple
from utils import *
import tensorflow.contrib.slim as slim
# Graph-level seed: TF1 derives per-op seeds from it, making weight
# initialisation reproducible for a given graph construction order.
_RANDOM_SEED = 19
tf.set_random_seed(_RANDOM_SEED)
def batch_norm(x, name="batch_norm"):
    """Apply ``tf.contrib.layers.batch_norm`` with this file's fixed settings.

    Moving averages decay at 0.9, epsilon is 1e-5 and a learned scale
    (gamma) is included. ``updates_collections=None`` asks the layer to
    update its moving statistics in place rather than via a collection.

    Args:
        x: input tensor.
        name: variable scope for the layer's variables.

    Returns:
        The normalised tensor.
    """
    settings = {
        'decay': 0.9,
        'updates_collections': None,
        'epsilon': 1e-5,
        'scale': True,
    }
    return tf.contrib.layers.batch_norm(x, scope=name, **settings)
def instance_norm(input, name="instance_norm"):
    """Instance normalisation with a learned per-channel scale and offset.

    Each sample is normalised over axes 1 and 2 (mean/variance from
    ``tf.nn.moments`` with ``keep_dims=True``), then scaled and shifted.
    The channel count is read from axis 3, so a 4-D channels-last input is
    expected.

    Args:
        input: 4-D tensor with channels on the last axis.
        name: variable scope holding the ``scale`` and ``offset`` variables.

    Returns:
        Tensor with the same shape as ``input``.
    """
    with tf.variable_scope(name):
        channels = input.get_shape()[3]
        gamma_init = tf.random_normal_initializer(1.0, 0.02, dtype=tf.float32)
        gamma = tf.get_variable("scale", [channels], initializer=gamma_init)
        beta = tf.get_variable("offset", [channels], initializer=tf.constant_initializer(0.0))
        mu, sigma_sq = tf.nn.moments(input, axes=[1, 2], keep_dims=True)
        eps = 1e-5
        inv_std = tf.rsqrt(sigma_sq + eps)
        centered = input - mu
        return gamma * (centered * inv_std) + beta
def conv2d(input_, output_dim, ks=4, s=2, stddev=0.02, padding='SAME', name="conv2d"):
    """2-D convolution without bias or activation (``slim.conv2d`` wrapper).

    Args:
        input_: input tensor.
        output_dim: number of output channels.
        ks: kernel size.
        s: stride.
        stddev: stddev of the truncated-normal weight initialiser.
        padding: 'SAME' or 'VALID'.
        name: variable scope for the layer.

    Returns:
        The raw convolution output.
    """
    weight_init = tf.truncated_normal_initializer(stddev=stddev)
    with tf.variable_scope(name):
        return slim.conv2d(
            input_, output_dim, ks, s,
            padding=padding,
            activation_fn=None,
            weights_initializer=weight_init,
            biases_initializer=None,
        )
def deconv2d(input_, output_dim, ks=4, s=2, stddev=0.02, name="deconv2d"):
    """Transposed 2-D convolution, 'SAME' padding, no bias or activation.

    Args:
        input_: input tensor.
        output_dim: number of output channels.
        ks: kernel size.
        s: stride (upsampling factor).
        stddev: stddev of the truncated-normal weight initialiser.
        name: variable scope for the layer.

    Returns:
        The raw transposed-convolution output.
    """
    weight_init = tf.truncated_normal_initializer(stddev=stddev)
    with tf.variable_scope(name):
        return slim.conv2d_transpose(
            input_, output_dim, ks, s,
            padding='SAME',
            activation_fn=None,
            weights_initializer=weight_init,
            biases_initializer=None,
        )
def lrelu(x, leak=0.2, name="lrelu"):
    """Leaky ReLU computed as ``max(x, leak * x)``.

    This matches the usual leaky ReLU only when 0 <= leak <= 1. ``name`` is
    accepted for interface compatibility but is not used.
    """
    scaled = leak * x
    return tf.maximum(x, scaled)
def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False):
    """Fully connected layer: ``matmul(input_, Matrix) + bias``.

    Args:
        input_: 2-D tensor; its last dimension is the input width.
        output_size: number of output units.
        scope: variable scope name (defaults to "Linear").
        stddev: stddev of the normal initialiser for the weight matrix.
        bias_start: constant initial value of the bias.
        with_w: also return the weight matrix and bias variables.

    Returns:
        The output tensor, or ``(output, matrix, bias)`` when ``with_w`` is true.
    """
    in_dim = input_.get_shape()[-1]
    with tf.variable_scope(scope or "Linear"):
        weights = tf.get_variable("Matrix", [in_dim, output_size], dtype=tf.float32,
                                  initializer=tf.random_normal_initializer(stddev=stddev))
        bias = tf.get_variable("bias", [output_size],
                               initializer=tf.constant_initializer(bias_start))
        out = tf.matmul(input_, weights) + bias
        if not with_w:
            return out
        return out, weights, bias
def discriminator(image, options, reuse=False, name="discriminator"):
    """Deep ("coarse") conv discriminator returning a 1-channel score map.

    Six stride-2 4x4 convolutions (df_dim, x2, x4, x8, x8, x8 channels) with
    leaky ReLU, instance norm on all but the first, followed by a stride-1
    1-channel prediction conv. For a 256x256 input the score map is 4x4.

    Args:
        image: input batch.
        options: namedtuple providing ``df_dim``.
        reuse: reuse the variables already created under ``name``.
        name: variable scope for this discriminator.
    """
    with tf.variable_scope(name):
        scope = tf.get_variable_scope()
        if reuse:
            scope.reuse_variables()
        else:
            assert scope.reuse is False

        net = lrelu(conv2d(image, options.df_dim, name='d_h0_conv'))
        # Layer i halves the spatial size; the multiplier sets its width.
        for layer, mult in enumerate((2, 4, 8, 8, 8), start=1):
            net = conv2d(net, options.df_dim * mult, s=2, name='d_h%d_conv' % layer)
            net = lrelu(instance_norm(net, 'd_bn%d' % layer))
        # Scope name 'd_h3_pred' is kept as-is so checkpoints stay compatible.
        return conv2d(net, 1, s=1, name='d_h3_pred')
def fine_discriminator(image, options, reuse=False, name="discriminator_fine"):
    """Shallower ("fine") PatchGAN-style discriminator.

    Four stride-2 4x4 convolutions (df_dim, x2, x4, x8 channels) with leaky
    ReLU, instance norm on all but the first, then a stride-1 1-channel
    prediction conv. For a 256x256 input the score map is 16x16.

    Args:
        image: input batch.
        options: namedtuple providing ``df_dim``.
        reuse: reuse the variables already created under ``name``.
        name: variable scope for this discriminator.
    """
    with tf.variable_scope(name):
        scope = tf.get_variable_scope()
        if reuse:
            scope.reuse_variables()
        else:
            assert scope.reuse is False

        net = lrelu(conv2d(image, options.df_dim, name='d_h0_conv'))
        for layer, mult in enumerate((2, 4, 8), start=1):
            net = conv2d(net, options.df_dim * mult, s=2, name='d_h%d_conv' % layer)
            net = lrelu(instance_norm(net, 'd_bn%d' % layer))
        return conv2d(net, 1, s=1, name='d_h3_pred')
def generator_resnet(image, options, reuse=False, name="generator"):
    """ResNet generator (Johnson et al. fast-neural-style layout, 9 blocks).

    Layout: reflect-pad + 7x7 conv, two stride-2 3x3 downsampling convs,
    nine residual blocks at gf_dim*4 channels, two stride-2 transposed
    convs, then reflect-pad + 7x7 conv with tanh to ``output_c_dim`` channels.
    See https://github.com/jcjohnson/fast-neural-style/.

    Args:
        image: input batch.
        options: namedtuple providing ``gf_dim`` and ``output_c_dim``.
        reuse: reuse the variables already created under ``name``.
        name: variable scope for this generator.

    Returns:
        The tanh output tensor.
    """
    with tf.variable_scope(name):
        if not reuse:
            assert tf.get_variable_scope().reuse is False
        else:
            tf.get_variable_scope().reuse_variables()

        def reflect_pad(t, amount):
            # Reflect-pad only the two spatial axes.
            return tf.pad(t, [[0, 0], [amount, amount], [amount, amount], [0, 0]], "REFLECT")

        def residule_block(x, dim, ks=3, s=1, name='res'):
            # Two reflect-padded VALID convs keep the spatial size, so x can be added back.
            p = int((ks - 1) / 2)
            y = instance_norm(conv2d(reflect_pad(x, p), dim, ks, s, padding='VALID', name=name + '_c1'),
                              name + '_bn1')
            y = instance_norm(conv2d(reflect_pad(tf.nn.relu(y), p), dim, ks, s, padding='VALID', name=name + '_c2'),
                              name + '_bn2')
            return y + x

        # Encoder.
        h = reflect_pad(image, 3)
        h = tf.nn.relu(instance_norm(conv2d(h, options.gf_dim, 7, 1, padding='VALID', name='g_e1_c'), 'g_e1_bn'))
        for mult, tag in ((2, 'g_e2'), (4, 'g_e3')):
            h = tf.nn.relu(instance_norm(conv2d(h, options.gf_dim * mult, 3, 2, name=tag + '_c'), tag + '_bn'))

        # Transformer: 9 residual blocks named g_r1 .. g_r9.
        for block_no in range(1, 10):
            h = residule_block(h, options.gf_dim * 4, name='g_r%d' % block_no)

        # Decoder.
        for mult, tag in ((2, 'g_d1'), (1, 'g_d2')):
            h = deconv2d(h, options.gf_dim * mult, 3, 2, name=tag + '_dc')
            h = tf.nn.relu(instance_norm(h, tag + '_bn'))
        h = reflect_pad(h, 3)
        return tf.nn.tanh(conv2d(h, options.output_c_dim, 7, 1, padding='VALID', name='g_pred_c'))
def abs_criterion(in_, target):
    """Mean absolute error (L1) between ``in_`` and ``target``."""
    abs_diff = tf.abs(in_ - target)
    return tf.reduce_mean(abs_diff)
def mae_criterion(in_, target):
    """Mean *squared* error between ``in_`` and ``target``.

    Despite the name, this squares the difference; it is the least-squares
    GAN loss selected when ``use_lsgan`` is set.
    """
    diff = in_ - target
    return tf.reduce_mean(diff ** 2)
def sce_criterion(logits, labels):
    """Mean sigmoid cross-entropy between ``logits`` and ``labels``."""
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(per_element)
class cyclegan(object):
    """CycleGAN with a coarse and a fine discriminator for each domain.

    The TF1 graph is built once in ``__init__``. ``train`` optimises the two
    generators against both discriminator pairs plus an L1 cycle-consistency
    term; ``test`` translates every image of one test split and writes an
    HTML index for side-by-side viewing.
    """

    def __init__(self, sess, args):
        """Read hyper-parameters from ``args``, build the graph and the saver.

        Args:
            sess: the ``tf.Session`` every method runs the graph in.
            args: parsed command-line namespace; fields read here are
                batch_size, fine_size, input_nc, output_nc, L1_lambda,
                dataset_dir, use_lsgan, ngf, ndf, phase and max_size.
        """
        self.sess = sess
        self.batch_size = args.batch_size
        self.image_size = args.fine_size
        self.input_c_dim = args.input_nc
        self.output_c_dim = args.output_nc
        self.L1_lambda = args.L1_lambda
        self.dataset_dir = args.dataset_dir

        self.discriminator = discriminator
        self.fine_discriminator = fine_discriminator
        self.generator = generator_resnet
        # LSGAN -> squared-error adversarial loss; otherwise sigmoid cross-entropy.
        if args.use_lsgan:
            self.criterionGAN = mae_criterion
        else:
            self.criterionGAN = sce_criterion

        OPTIONS = namedtuple('OPTIONS', 'batch_size image_size gf_dim df_dim output_c_dim is_training')
        self.options = OPTIONS._make((args.batch_size, args.fine_size,
                                      args.ngf, args.ndf, args.output_nc,
                                      args.phase == 'train'))

        self._build_model()
        # tf.train.Saver() captures the variables that exist now (the model
        # variables); the Adam slots that train() creates later are not saved.
        self.saver = tf.train.Saver()
        self.pool = ImagePool(args.max_size)

    def _build_model(self):
        """Create placeholders, generators, discriminators, losses and summaries."""
        # One placeholder holds both domains stacked on the channel axis:
        # the first input_c_dim channels are A, the next output_c_dim are B.
        self.real_data = tf.placeholder(tf.float32,
                                        [None, self.image_size, self.image_size,
                                         self.input_c_dim + self.output_c_dim],
                                        name='real_A_and_B_images')

        self.real_A = self.real_data[:, :, :, :self.input_c_dim]
        self.real_B = self.real_data[:, :, :, self.input_c_dim:self.input_c_dim + self.output_c_dim]

        # Forward and cycle translations; the first call under each scope
        # creates the generator's variables, the second one reuses them.
        self.fake_B = self.generator(self.real_A, self.options, False, name="generatorA2B")
        self.fake_A_ = self.generator(self.fake_B, self.options, False, name="generatorB2A")
        self.fake_A = self.generator(self.real_B, self.options, True, name="generatorB2A")
        self.fake_B_ = self.generator(self.fake_A, self.options, True, name="generatorA2B")

        self.DB_fake = self.discriminator(self.fake_B, self.options, reuse=False, name="discriminatorB_coarse")
        self.DA_fake = self.discriminator(self.fake_A, self.options, reuse=False, name="discriminatorA_coarse")
        self.DB_fake_fine = self.fine_discriminator(self.fake_B, self.options, reuse=False, name="discriminatorB_fine")
        self.DA_fake_fine = self.fine_discriminator(self.fake_A, self.options, reuse=False, name="discriminatorA_fine")

        # Per-direction losses are only logged; g_loss is what g_optim minimises.
        self.g_loss_a2b = self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake)) \
            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
        self.g_loss_b2a = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake)) \
            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)
        self.g_loss = self.criterionGAN(self.DA_fake, tf.ones_like(self.DA_fake)) \
            + self.criterionGAN(self.DB_fake, tf.ones_like(self.DB_fake)) \
            + self.criterionGAN(self.DA_fake_fine, tf.ones_like(self.DA_fake_fine)) \
            + self.criterionGAN(self.DB_fake_fine, tf.ones_like(self.DB_fake_fine)) \
            + self.L1_lambda * abs_criterion(self.real_A, self.fake_A_) \
            + self.L1_lambda * abs_criterion(self.real_B, self.fake_B_)

        # Fakes drawn from the image pool are fed back in for the D update.
        self.fake_A_sample = tf.placeholder(tf.float32,
                                            [None, self.image_size, self.image_size,
                                             self.input_c_dim], name='fake_A_sample')
        self.fake_B_sample = tf.placeholder(tf.float32,
                                            [None, self.image_size, self.image_size,
                                             self.output_c_dim], name='fake_B_sample')

        self.DB_real = self.discriminator(self.real_B, self.options, reuse=True, name="discriminatorB_coarse")
        self.DA_real = self.discriminator(self.real_A, self.options, reuse=True, name="discriminatorA_coarse")
        self.DB_fake_sample = self.discriminator(self.fake_B_sample, self.options, reuse=True, name="discriminatorB_coarse")
        self.DA_fake_sample = self.discriminator(self.fake_A_sample, self.options, reuse=True, name="discriminatorA_coarse")

        self.DB_real_fine = self.fine_discriminator(self.real_B, self.options, reuse=True, name="discriminatorB_fine")
        self.DA_real_fine = self.fine_discriminator(self.real_A, self.options, reuse=True, name="discriminatorA_fine")
        self.DB_fake_sample_fine = self.fine_discriminator(self.fake_B_sample, self.options, reuse=True, name="discriminatorB_fine")
        # Domain-A fakes must be scored by domain A's fine discriminator
        # (previously routed to discriminatorB_fine by mistake).
        self.DA_fake_sample_fine = self.fine_discriminator(self.fake_A_sample, self.options, reuse=True, name="discriminatorA_fine")

        self.db_loss_real = self.criterionGAN(self.DB_real, tf.ones_like(self.DB_real))
        self.db_loss_fake = self.criterionGAN(self.DB_fake_sample, tf.zeros_like(self.DB_fake_sample))
        self.db_loss = (self.db_loss_real + self.db_loss_fake) / 2
        self.da_loss_real = self.criterionGAN(self.DA_real, tf.ones_like(self.DA_real))
        self.da_loss_fake = self.criterionGAN(self.DA_fake_sample, tf.zeros_like(self.DA_fake_sample))
        self.da_loss = (self.da_loss_real + self.da_loss_fake) / 2
        self.d_loss = (self.da_loss + self.db_loss)

        self.db_loss_real_fine = self.criterionGAN(self.DB_real_fine, tf.ones_like(self.DB_real_fine))
        self.db_loss_fake_fine = self.criterionGAN(self.DB_fake_sample_fine, tf.zeros_like(self.DB_fake_sample_fine))
        self.db_loss_fine = (self.db_loss_real_fine + self.db_loss_fake_fine) / 2
        self.da_loss_real_fine = self.criterionGAN(self.DA_real_fine, tf.ones_like(self.DA_real_fine))
        self.da_loss_fake_fine = self.criterionGAN(self.DA_fake_sample_fine, tf.zeros_like(self.DA_fake_sample_fine))
        self.da_loss_fine = (self.da_loss_real_fine + self.da_loss_fake_fine) / 2
        self.d_loss_fine = self.da_loss_fine + self.db_loss_fine

        self.g_loss_a2b_sum = tf.summary.scalar("g_loss_a2b", self.g_loss_a2b)
        self.g_loss_b2a_sum = tf.summary.scalar("g_loss_b2a", self.g_loss_b2a)
        self.g_loss_sum = tf.summary.scalar("g_loss", self.g_loss)
        self.g_sum = tf.summary.merge([self.g_loss_a2b_sum, self.g_loss_b2a_sum, self.g_loss_sum])

        self.db_loss_sum = tf.summary.scalar("db_loss", self.db_loss)
        self.da_loss_sum = tf.summary.scalar("da_loss", self.da_loss)
        self.d_loss_sum = tf.summary.scalar("d_loss", self.d_loss)
        self.db_loss_real_sum = tf.summary.scalar("db_loss_real", self.db_loss_real)
        self.db_loss_fake_sum = tf.summary.scalar("db_loss_fake", self.db_loss_fake)
        self.da_loss_real_sum = tf.summary.scalar("da_loss_real", self.da_loss_real)
        self.da_loss_fake_sum = tf.summary.scalar("da_loss_fake", self.da_loss_fake)
        # The fine discriminators are trained too, so log their losses as well.
        self.db_loss_fine_sum = tf.summary.scalar("db_loss_fine", self.db_loss_fine)
        self.da_loss_fine_sum = tf.summary.scalar("da_loss_fine", self.da_loss_fine)
        self.d_loss_fine_sum = tf.summary.scalar("d_loss_fine", self.d_loss_fine)
        self.d_sum = tf.summary.merge(
            [self.da_loss_sum, self.da_loss_real_sum, self.da_loss_fake_sum,
             self.db_loss_sum, self.db_loss_real_sum, self.db_loss_fake_sum,
             self.d_loss_sum,
             self.da_loss_fine_sum, self.db_loss_fine_sum, self.d_loss_fine_sum]
        )

        self.test_A = tf.placeholder(tf.float32,
                                     [None, self.image_size, self.image_size,
                                      self.input_c_dim], name='test_A')
        self.test_B = tf.placeholder(tf.float32,
                                     [None, self.image_size, self.image_size,
                                      self.output_c_dim], name='test_B')
        self.testB = self.generator(self.test_A, self.options, True, name="generatorA2B")
        self.testA = self.generator(self.test_B, self.options, True, name="generatorB2A")

        # Partition trainable variables by scope name. The coarse and fine
        # discriminators have separate optimisers, so keep their lists disjoint.
        t_vars = tf.trainable_variables()
        self.d_vars = [var for var in t_vars if 'discriminator' in var.name and 'coarse' in var.name]
        self.d_vars_fine = [var for var in t_vars if 'discriminator' in var.name and 'fine' in var.name]
        self.g_vars = [var for var in t_vars if 'generator' in var.name]
        for var in t_vars:
            print(var.name)

    def train(self, args):
        """Train cyclegan.

        Creates one Adam optimiser each for the coarse discriminators, the
        fine discriminators and the generators. It then initialises all
        variables and optionally restores the latest checkpoint. For every
        batch it runs one generator step, followed by one step of both
        discriminator optimisers on image-pool fakes.
        """
        self.lr = tf.placeholder(tf.float32, None, name='learning_rate')
        self.d_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.d_loss, var_list=self.d_vars)
        self.d_optim_fine = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.d_loss_fine, var_list=self.d_vars_fine)
        self.g_optim = tf.train.AdamOptimizer(self.lr, beta1=args.beta1) \
            .minimize(self.g_loss, var_list=self.g_vars)

        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)
        self.writer = tf.summary.FileWriter("./logs", self.sess.graph)

        counter = 1
        start_time = time.time()

        if args.continue_train:
            if self.load(args.checkpoint_dir):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")

        for epoch in range(args.epoch):
            dataA = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/trainA'))
            dataB = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/trainB'))
            np.random.shuffle(dataA)
            np.random.shuffle(dataB)
            batch_idxs = min(min(len(dataA), len(dataB)), args.train_size) // self.batch_size
            # Constant lr for the first epoch_step epochs, then linear decay towards 0.
            lr = args.lr if epoch < args.epoch_step else args.lr * (args.epoch - epoch) / (args.epoch - args.epoch_step)

            for idx in range(0, batch_idxs):
                batch_files = list(zip(dataA[idx * self.batch_size:(idx + 1) * self.batch_size],
                                       dataB[idx * self.batch_size:(idx + 1) * self.batch_size]))
                batch_images = [load_train_data(batch_file, args.load_size, args.fine_size) for batch_file in batch_files]
                batch_images = np.array(batch_images).astype(np.float32)

                # Update G network and record fake outputs
                fake_A, fake_B, _, summary_str = self.sess.run(
                    [self.fake_A, self.fake_B, self.g_optim, self.g_sum],
                    feed_dict={self.real_data: batch_images, self.lr: lr})
                self.writer.add_summary(summary_str, counter)
                [fake_A, fake_B] = self.pool([fake_A, fake_B])

                # Update D networks (coarse and fine) on the pooled fakes
                _, _, summary_str = self.sess.run(
                    [self.d_optim, self.d_optim_fine, self.d_sum],
                    feed_dict={self.real_data: batch_images,
                               self.fake_A_sample: fake_A,
                               self.fake_B_sample: fake_B,
                               self.lr: lr})
                self.writer.add_summary(summary_str, counter)

                counter += 1
                print(("Epoch: [%2d] [%4d/%4d] time: %4.4f" % (
                    epoch, idx, batch_idxs, time.time() - start_time)))

                if np.mod(counter, args.print_freq) == 1:
                    self.sample_model(args.sample_dir, epoch, idx)

                if np.mod(counter, args.save_freq) == 2:
                    self.save(args.checkpoint_dir, counter)

    def save(self, checkpoint_dir, step):
        """Save a checkpoint under ``<checkpoint_dir>/<dataset_dir>_<image_size>``."""
        model_name = "cyclegan.model"
        model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        if not os.path.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)

        self.saver.save(self.sess,
                        os.path.join(checkpoint_dir, model_name),
                        global_step=step)

    def load(self, checkpoint_dir):
        """Restore the latest checkpoint written by ``save``.

        Returns:
            True if a checkpoint was found and restored, False otherwise.
        """
        print(" [*] Reading checkpoint...")

        model_dir = "%s_%s" % (self.dataset_dir, self.image_size)
        checkpoint_dir = os.path.join(checkpoint_dir, model_dir)

        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            # Re-anchor on checkpoint_dir so a moved checkpoint folder still loads.
            ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
            self.saver.restore(self.sess, os.path.join(checkpoint_dir, ckpt_name))
            return True
        else:
            return False

    def sample_model(self, sample_dir, epoch, idx):
        """Translate a random batch of test images and save both directions to ``sample_dir``."""
        dataA = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testA'))
        dataB = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testB'))
        np.random.shuffle(dataA)
        np.random.shuffle(dataB)
        batch_files = list(zip(dataA[:self.batch_size], dataB[:self.batch_size]))
        sample_images = [load_train_data(batch_file, is_testing=True) for batch_file in batch_files]
        sample_images = np.array(sample_images).astype(np.float32)

        fake_A, fake_B = self.sess.run(
            [self.fake_A, self.fake_B],
            feed_dict={self.real_data: sample_images}
        )
        save_images(fake_A, [self.batch_size, 1],
                    './{}/A_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx))
        save_images(fake_B, [self.batch_size, 1],
                    './{}/B_{:02d}_{:04d}.jpg'.format(sample_dir, epoch, idx))

    def test(self, args):
        """Test cyclegan.

        Translates every image in testA (AtoB) or testB (BtoA). Each result
        is saved under ``args.test_dir``, and ``<direction>_index.html`` is
        written there with name / input / output columns.

        Raises:
            Exception: if ``args.which_direction`` is neither 'AtoB' nor 'BtoA'.
        """
        init_op = tf.global_variables_initializer()
        self.sess.run(init_op)
        if args.which_direction == 'AtoB':
            sample_files = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testA'))
        elif args.which_direction == 'BtoA':
            sample_files = glob('./datasets/{}/*.*'.format(self.dataset_dir + '/testB'))
        else:
            raise Exception('--which_direction must be AtoB or BtoA')

        if self.load(args.checkpoint_dir):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        index_path = os.path.join(args.test_dir, '{0}_index.html'.format(args.which_direction))
        out_var, in_var = (self.testB, self.test_A) if args.which_direction == 'AtoB' else (
            self.testA, self.test_B)

        # `with` closes the index even if translating an image raises.
        with open(index_path, "w") as index:
            index.write("<html><body><table><tr>")
            index.write("<th>name</th><th>input</th><th>output</th></tr>")

            for sample_file in sample_files:
                print('Processing image: ' + sample_file)
                sample_image = [load_test_data(sample_file, args.fine_size)]
                sample_image = np.array(sample_image).astype(np.float32)
                image_path = os.path.join(args.test_dir,
                                          '{0}_{1}'.format(args.which_direction, os.path.basename(sample_file)))
                fake_img = self.sess.run(out_var, feed_dict={in_var: sample_image})
                save_images(fake_img, [1, 1], image_path)
                # Relative paths get a '..' prefix so they resolve from inside
                # test_dir (assumes test_dir is one level below the cwd).
                index.write("<tr>")
                index.write("<td>%s</td>" % os.path.basename(image_path))
                index.write("<td><img src='%s'></td>" % (sample_file if os.path.isabs(sample_file) else (
                    '..' + os.path.sep + sample_file)))
                index.write("<td><img src='%s'></td>" % (image_path if os.path.isabs(image_path) else (
                    '..' + os.path.sep + image_path)))
                index.write("</tr>")

            index.write("</table></body></html>")
def _str2bool(value):
    """Parse a command-line boolean.

    argparse's ``type=bool`` treats every non-empty string (including
    "False" and "0") as True. This accepts 1/0, true/false, t/f, yes/no, y/n
    (any case) instead.

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)


parser = argparse.ArgumentParser(description='')
parser.add_argument('--dataset_dir', dest='dataset_dir', default='sample', help='path of the dataset')
parser.add_argument('--epoch', dest='epoch', type=int, default=200, help='# of epoch')
parser.add_argument('--epoch_step', dest='epoch_step', type=int, default=100, help='# of epoch to decay lr')
parser.add_argument('--batch_size', dest='batch_size', type=int, default=1, help='# images in batch')
# argparse does not apply `type` to non-string defaults, so make the default an int itself.
parser.add_argument('--train_size', dest='train_size', type=int, default=int(1e8), help='# images used to train')
parser.add_argument('--load_size', dest='load_size', type=int, default=286, help='scale images to this size')
parser.add_argument('--fine_size', dest='fine_size', type=int, default=256, help='then crop to this size')
parser.add_argument('--ngf', dest='ngf', type=int, default=64, help='# of gen filters in first conv layer')
parser.add_argument('--ndf', dest='ndf', type=int, default=64, help='# of discri filters in first conv layer')
parser.add_argument('--input_nc', dest='input_nc', type=int, default=3, help='# of input image channels')
parser.add_argument('--output_nc', dest='output_nc', type=int, default=3, help='# of output image channels')
parser.add_argument('--lr', dest='lr', type=float, default=0.0002, help='initial learning rate for adam')
parser.add_argument('--beta1', dest='beta1', type=float, default=0.5, help='momentum term of adam')
parser.add_argument('--which_direction', dest='which_direction', default='AtoB', help='AtoB or BtoA')
parser.add_argument('--phase', dest='phase', default='train', help='train, test')
parser.add_argument('--save_freq', dest='save_freq', type=int, default=1000, help='save a model every save_freq iterations')
parser.add_argument('--print_freq', dest='print_freq', type=int, default=100, help='print the debug information every print_freq iterations')
parser.add_argument('--continue_train', dest='continue_train', type=_str2bool, default=False, help='if continue training, load the latest model: 1: true, 0: false')
parser.add_argument('--checkpoint_dir', dest='checkpoint_dir', default='./checkpoint', help='models are saved here')
parser.add_argument('--sample_dir', dest='sample_dir', default='./sample', help='sample are saved here')
parser.add_argument('--test_dir', dest='test_dir', default='./test', help='test sample are saved here')
parser.add_argument('--L1_lambda', dest='L1_lambda', type=float, default=10.0, help='weight on L1 term in objective')
parser.add_argument('--use_lsgan', dest='use_lsgan', type=_str2bool, default=True, help='gan loss defined in lsgan')
parser.add_argument('--max_size', dest='max_size', type=int, default=50, help='max size of image pool, 0 means do not use image pool')
parser.add_argument('--gpu', dest='gpu', type=int, default=0, help='# of gpu index')

args = parser.parse_args()
# Restrict visible GPUs before any tf.Session is created in main().
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
def main(_):
    """Entry point for ``tf.app.run``: create output dirs, build the model, train or test.

    Reads the module-level ``args``. Any ``--phase`` other than 'train' runs
    the test pass.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    for directory in (args.checkpoint_dir, args.sample_dir, args.test_dir):
        os.makedirs(directory, exist_ok=True)

    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session(config=tfconfig) as sess:
        model = cyclegan(sess, args)
        if args.phase == 'train':
            model.train(args)
        else:
            model.test(args)
if __name__ == '__main__':
    # Pass main explicitly rather than letting tf.app.run look it up on __main__.
    tf.app.run(main=main)