optimizer.py
import tensorflow as tf


def optimizer(beta_1, loss_gen, loss_dis, loss_type, learning_rate_input_g, learning_rate_input_d, beta_2=None, clipping=None, display=True):

    trainable_variables = tf.trainable_variables()
    generator_variables = [variable for variable in trainable_variables if variable.name.startswith('generator')]
    discriminator_variables = [variable for variable in trainable_variables if variable.name.startswith('discriminator')]

    # String used to track which optimizer is actually used.
    optimizer_print = ''

    # Handling Batch Normalization: run the update ops along with each train step.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):

        if ('wasserstein distance' in loss_type and 'gradient penalty' in loss_type) or ('hinge' in loss_type):
            train_discriminator = tf.train.AdamOptimizer(learning_rate_input_d, beta_1, beta_2).minimize(loss_dis, var_list=discriminator_variables)
            train_generator = tf.train.AdamOptimizer(learning_rate_input_g, beta_1, beta_2).minimize(loss_gen, var_list=generator_variables)
            optimizer_print += '%s - AdamOptimizer' % loss_type

        # TODO: Fix this for RMSProp.
        elif 'wasserstein distance' in loss_type and 'gradient penalty' not in loss_type:
            # Weight clipping on the discriminator, done to enforce the Lipschitz constraint.
            train_discriminator = tf.train.AdamOptimizer(learning_rate_input_d, beta_1, beta_2).minimize(loss_dis, var_list=discriminator_variables)
            dis_weight_clipping = [value.assign(tf.clip_by_value(value, -clipping, clipping)) for value in discriminator_variables]
            train_discriminator = tf.group(*[train_discriminator, dis_weight_clipping])
            train_generator = tf.train.AdamOptimizer(learning_rate_input_g, beta_1, beta_2).minimize(loss_gen, var_list=generator_variables)
            optimizer_print += '%s - AdamOptimizer' % loss_type
            '''
            RMS_optimizer_dis = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
            train_discriminator = RMS_optimizer_dis.minimize(loss_dis, var_list=discriminator_variables)
            # Weight clipping on the discriminator, done to enforce the Lipschitz constraint.
            dis_weight_clipping = [value.assign(tf.clip_by_value(value, -c, c)) for value in discriminator_variables]
            train_discriminator = tf.group(*[train_discriminator, dis_weight_clipping])
            # Generator.
            RMS_optimizer_gen = tf.train.RMSPropOptimizer(learning_rate=learning_rate)
            train_generator = RMS_optimizer_gen.minimize(loss_gen, var_list=generator_variables)
            optimizer_print += 'Wasserstein Distance - RMSPropOptimizer'
            '''

        elif 'standard' in loss_type or 'least square' in loss_type or 'relativistic' in loss_type:
            train_discriminator = tf.train.AdamOptimizer(learning_rate=learning_rate_input_d, beta1=beta_1).minimize(loss_dis, var_list=discriminator_variables)
            train_generator = tf.train.AdamOptimizer(learning_rate=learning_rate_input_g, beta1=beta_1).minimize(loss_gen, var_list=generator_variables)
            optimizer_print += '%s - AdamOptimizer' % loss_type

        else:
            print('Optimizer: Loss %s not defined' % loss_type)
            exit(1)

    if display:
        print('[Optimizer] Loss %s' % optimizer_print)
        print()

    return train_discriminator, train_generator
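
# Illustrative usage sketch (not from the original code): assumes `loss_gen` and
# `loss_dis` were built under variable scopes named 'generator' and 'discriminator',
# that `lr_g` / `lr_d` are learning-rate placeholders, and that `n_critic` and the
# data feeds (`data_feed`) come from the surrounding training script.
#
#   train_dis, train_gen = optimizer(beta_1=0.5, loss_gen=loss_gen, loss_dis=loss_dis,
#                                    loss_type='wasserstein distance gradient penalty',
#                                    learning_rate_input_g=lr_g,
#                                    learning_rate_input_d=lr_d, beta_2=0.9)
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       for _ in range(n_critic):
#           sess.run(train_dis, feed_dict={lr_d: 1e-4, **data_feed})
#       sess.run(train_gen, feed_dict={lr_g: 1e-4, **data_feed})
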
def vae_gan_optimizer(beta_1, loss_prior, loss_dist_likel, loss_gen, loss_dis, loss_type, learning_rate_input_g, learning_rate_input_d, beta_2=None, clipping=None,
                      display=True, gamma=1):

    trainable_variables = tf.trainable_variables()
    encoder_variables = [variable for variable in trainable_variables if variable.name.startswith('encoder')]
    generator_decoder_variables = [variable for variable in trainable_variables if variable.name.startswith('generator_decoder')]
    discriminator_variables = [variable for variable in trainable_variables if variable.name.startswith('discriminator')]

    # String used to track which optimizer is actually used.
    optimizer_print = ''

    # Handling Batch Normalization: run the update ops along with each train step.
    with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):

        if 'wasserstein distance' in loss_type and 'gradient penalty' in loss_type:
            train_encoder = tf.train.AdamOptimizer(learning_rate_input_d, beta_1, beta_2).minimize(loss_prior+loss_dist_likel, var_list=encoder_variables)
            train_gen_decod = tf.train.AdamOptimizer(learning_rate_input_d, beta_1, beta_2).minimize(loss_dist_likel+loss_gen, var_list=generator_decoder_variables)
            train_discriminator = tf.train.AdamOptimizer(learning_rate_input_d, beta_1, beta_2).minimize(loss_dis, var_list=discriminator_variables)
            optimizer_print += 'Wasserstein Distance Gradient Penalty - AdamOptimizer'

        elif 'relativistic' in loss_type:
            train_encoder = tf.train.AdamOptimizer(learning_rate=learning_rate_input_d, beta1=beta_1).minimize(loss_prior+loss_dist_likel, var_list=encoder_variables)
            train_gen_decod = tf.train.AdamOptimizer(learning_rate=learning_rate_input_g, beta1=beta_1).minimize((gamma*loss_dist_likel)+loss_gen, var_list=generator_decoder_variables)
            train_discriminator = tf.train.AdamOptimizer(learning_rate=learning_rate_input_d, beta1=beta_1).minimize(loss_dis, var_list=discriminator_variables)
            optimizer_print += '%s - AdamOptimizer' % loss_type

        else:
            print('Loss %s not defined' % loss_type)
            exit(1)

    if display:
        print('Optimizer: %s' % optimizer_print)
        print()

    return train_encoder, train_gen_decod, train_discriminator
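

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module). It builds a toy
    # graph whose variables live under the 'generator' and 'discriminator' scopes
    # that optimizer() filters on, with stand-in losses; real GAN losses, and the
    # 'encoder' / 'generator_decoder' scopes needed by vae_gan_optimizer(), would
    # come from the rest of the project.
    with tf.variable_scope('generator'):
        fake = tf.layers.dense(tf.random_normal([4, 16]), 16)
    with tf.variable_scope('discriminator'):
        logits = tf.layers.dense(fake, 1)
    toy_loss_gen = -tf.reduce_mean(logits)
    toy_loss_dis = tf.reduce_mean(logits)
    lr_g = tf.placeholder(tf.float32, shape=[], name='lr_g')
    lr_d = tf.placeholder(tf.float32, shape=[], name='lr_d')

    train_dis, train_gen = optimizer(beta_1=0.5, loss_gen=toy_loss_gen, loss_dis=toy_loss_dis,
                                     loss_type='standard', learning_rate_input_g=lr_g,
                                     learning_rate_input_d=lr_d)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_dis, feed_dict={lr_d: 1e-4})
        sess.run(train_gen, feed_dict={lr_g: 1e-4})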