Merge pull request #14 from mlberkeley/sep-classification
Adding separate classification network to the model
philkuz authored Jan 4, 2018
2 parents 249fead + 495cf3b commit 1f35058
Showing 102 changed files with 93,839 additions and 127 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -23,6 +23,10 @@ __pycache__/
 *.py[cod]
 *$py.class
 
+# Jupyter notebooks
+.ipynb_checkpoints/
+*/.ipynb_checkpoints/
+
 # C extensions
 *.so
 
@@ -88,3 +92,5 @@ Session.vim
 *~
 
 
+# philkuz random folders
+_*/
23 changes: 23 additions & 0 deletions experiments/wiki_external_can.sh
@@ -0,0 +1,23 @@
+# trains the GAN with an external CAN network instead of having the discriminator learn style classification
+export PYTHONPATH="slim/:$PYTHONPATH"
+export CUDA_VISIBLE_DEVICES=0
+BATCH_SIZE=16
+python3 main.py \
+  --epoch 25 \
+  --learning_rate .0001 \
+  --beta 0.5 \
+  --batch_size $BATCH_SIZE \
+  --sample_size $BATCH_SIZE \
+  --input_height 256 \
+  --output_height 256 \
+  --lambda_val 1.0 \
+  --smoothing 1.0 \
+  --use_resize True \
+  --dataset wikiart \
+  --input_fname_pattern */*.jpg \
+  --crop False \
+  --visualize False \
+  --use_s3 False \
+  --can True \
+  --train \
+  --style_net_checkpoint "slim/logs/wikiart/inception_resnet_v2/all/bs=16,lr=0.0001,epochs=100/smol_adam_fixedLR"
22 changes: 22 additions & 0 deletions experiments/wiki_exxternal_can_ht=128,bs=16.sh
@@ -0,0 +1,22 @@
+# trains the GAN with an external CAN network instead of having the discriminator learn style classification
+export PYTHONPATH="slim/:$PYTHONPATH"
+export CUDA_VISIBLE_DEVICES=1
+python3 main.py \
+  --epoch 25 \
+  --learning_rate .0001 \
+  --beta 0.5 \
+  --batch_size 16 \
+  --sample_size 72 \
+  --input_height 128 \
+  --output_height 128 \
+  --lambda_val 1.0 \
+  --smoothing 1.0 \
+  --use_resize True \
+  --dataset wikiart \
+  --input_fname_pattern */*.jpg \
+  --crop False \
+  --visualize False \
+  --use_s3 False \
+  --can True \
+  --train \
+  --style_net_checkpoint "slim/logs/wikiart/inception_resnet_v2/all/bs=16,lr=0.0001,epochs=100/smol_adam_fixedLR"
22 changes: 15 additions & 7 deletions losses.py
@@ -4,8 +4,8 @@
 def CAN_loss(model):
     #builds optimizers and losses
 
-    model.G = model.generator(model.z)
-    model.D, model.D_logits, model.D_c, model.D_c_logits = model.discriminator(
+    model.G = model.generator(model, model.z)
+    model.D, model.D_logits, model.D_c, model.D_c_logits = model.discriminator(model,
         model.inputs, reuse=False)
     if model.experience_flag:
         try:
@@ -14,7 +14,7 @@ def CAN_loss(model):
             model.experience_selection = tf.convert_to_tensor(model.experience_buffer)
         model.G = tf.concat([model.G, model.experience_selection], axis=0)
 
-    model.D_, model.D_logits_, model.D_c_, model.D_c_logits_ = model.discriminator(
+    model.D_, model.D_logits_, model.D_c_, model.D_c_logits_ = model.discriminator(model,
         model.G, reuse=True)
     model.d_sum = histogram_summary("d", model.D)
     model.d__sum = histogram_summary("d_", model.D_)
@@ -36,9 +36,17 @@ def CAN_loss(model):
 
     model.d_loss_class_real = tf.reduce_mean(
         tf.nn.softmax_cross_entropy_with_logits(logits=model.D_c_logits, labels=model.smoothing * model.y))
-    model.g_loss_class_fake = tf.reduce_mean(
-        tf.nn.softmax_cross_entropy_with_logits(logits=model.D_c_logits_,
-                                                labels=(1.0/model.y_dim)*tf.ones_like(model.D_c_)))
+
+    # if a classifier checkpoint is set, use the external classifier; otherwise use the classification layers in the discriminator
+    if model.style_net_checkpoint is None:
+        model.g_loss_class_fake = tf.reduce_mean(
+            tf.nn.softmax_cross_entropy_with_logits(logits=model.D_c_logits_,
+                                                    labels=(1.0/model.y_dim)*tf.ones_like(model.D_c_)))
+    else:
+        model.classifier = model.make_style_net(model.G)
+        model.g_loss_class_fake = tf.reduce_mean(
+            tf.nn.softmax_cross_entropy_with_logits(logits=model.classifier,
+                                                    labels=(1.0/model.y_dim)*tf.ones_like(model.D_c_)))
 
     model.g_loss_fake = -tf.reduce_mean(tf.log(model.D_))
 
@@ -139,7 +147,7 @@ def WGAN_loss(model):
     t_vars = tf.trainable_variables()
     model.d_vars = [var for var in t_vars if 'd_' in var.name]
    model.g_vars = [var for var in t_vars if 'g_' in var.name]
 
     g_update = model.g_opt.minimize(model.g_loss, var_list=model.g_vars)
     d_update = model.d_opt.minimize(model.d_loss, var_list=model.d_vars)
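In both branches of the new conditional, g_loss_class_fake is the CAN style-ambiguity term: the cross-entropy between the class logits on generated images and the uniform distribution over the y_dim style classes, which rewards the generator for images the classifier cannot pin to any single style. A minimal NumPy sketch of that quantity (illustrative only, not the repo's code):

import numpy as np

def style_ambiguity_loss(class_logits):
    # Cross-entropy between softmax(class_logits) and the uniform
    # distribution 1/y_dim, averaged over the batch.
    y_dim = class_logits.shape[-1]
    shifted = class_logits - class_logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -np.mean(np.sum((1.0 / y_dim) * log_probs, axis=-1))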
126 changes: 50 additions & 76 deletions main.py
@@ -10,12 +10,13 @@
 
 flags = tf.app.flags
 flags.DEFINE_integer("epoch", 25, "Epoch to train [25]")
-flags.DEFINE_float("learning_rate", 0.0002, "Learning rate of for adam [0.0002]")
+flags.DEFINE_float("learning_rate", 0.0002, "Learning rate for adam [0.0002]")
 flags.DEFINE_float("beta1", 0.5, "Momentum term of adam [0.5]")
 flags.DEFINE_float("smoothing", 0.9, "Smoothing term for discriminator real (class) loss [0.9]")
 flags.DEFINE_float("lambda_val", 1.0, "determines the relative importance of style ambiguity loss [1.0]")
 flags.DEFINE_integer("train_size", np.inf, "The size of train images [np.inf]")
 flags.DEFINE_integer("save_itr", 500, "The number of iterations to run for saving checkpoints")
+flags.DEFINE_integer("sample_itr", 500, "The number of iterations to run for sampling from the sampler")
 flags.DEFINE_integer("batch_size", 64, "The size of batch images [64]")
 flags.DEFINE_integer("sample_size", 64, "the size of sample images [64]")
 flags.DEFINE_integer("input_height", 108, "The size of image to use (will be center cropped). [108]")
@@ -37,6 +38,8 @@
 flags.DEFINE_boolean("replay", True, "True if using experience replay [True]")
 flags.DEFINE_boolean("use_resize", False, "True if resize conv for upsampling, False for fractionally strided conv [False]")
 flags.DEFINE_boolean("use_default_checkpoint", False, "True only if checkpoint_dir is None. Don't set this")
+flags.DEFINE_string("style_net_checkpoint", None, "The checkpoint for the style net. Leave as default to not use the style net")
+flags.DEFINE_boolean("allow_gpu_growth", False, "True if you want TensorFlow to allocate only the GPU memory it requires. Good for debugging, but can impact performance")
 FLAGS = flags.FLAGS
 
 def main(_):
@@ -59,21 +62,23 @@ def main(_):
 
 
     # configure the log_dir to match the params
-    log_dir = os.path.join(FLAGS.log_dir, "dataset={},isCan={},lr={},imsize={},batch_size={}".format(
+    log_dir = os.path.join(FLAGS.log_dir, "dataset={},isCan={},lr={},imsize={},hasStyleNet={},batch_size={}".format(
         FLAGS.dataset,
         FLAGS.can,
         FLAGS.learning_rate,
         FLAGS.input_height,
+        FLAGS.style_net_checkpoint is not None,
         FLAGS.batch_size))
     if not glob(log_dir + "*"):
         log_dir = os.path.join(log_dir, "000")
     else:
-        containing_dir=os.path.join(log_dir, "*")
-        print(containing_dir)
+        containing_dir=os.path.join(log_dir, "[0-9][0-9][0-9]")
         nums = [int(x[-3:]) for x in glob(containing_dir)] # TODO FIX THESE HACKS
-        print('nums', nums)
-        num = str(max(nums) + 1)
-        log_dir = os.path.join(log_dir,(3-len(num))*"0"+num)
+        if nums == []:
+            num = 0
+        else:
+            num = max(nums) + 1
+        log_dir = os.path.join(log_dir,"{:03d}".format(num))
     FLAGS.log_dir = log_dir
 
     if FLAGS.checkpoint_dir is None:
@@ -91,78 +96,47 @@ def main(_):
         os.makedirs(FLAGS.sample_dir)
     print('After processing flags')
     pp.pprint(flags.FLAGS.__flags)
+    if FLAGS.style_net_checkpoint:
+        from slim.nets import nets_factory
+        network_fn = nets_factory
 
-    run_config = tf.ConfigProto()
-    run_config.gpu_options.allow_growth=True
-
+    sess = None
+    if FLAGS.dataset == 'mnist':
+        y_dim = 10
+    elif FLAGS.dataset == 'wikiart':
+        y_dim = 27
+    else:
+        y_dim = None
+    dcgan = DCGAN(
+        sess,
+        input_width=FLAGS.input_width,
+        input_height=FLAGS.input_height,
+        output_width=FLAGS.output_width,
+        output_height=FLAGS.output_height,
+        batch_size=FLAGS.batch_size,
+        sample_num=FLAGS.sample_size,
+        use_resize=FLAGS.use_resize,
+        replay=FLAGS.replay,
+        y_dim=y_dim,
+        smoothing=FLAGS.smoothing,
+        lamb = FLAGS.lambda_val,
+        dataset_name=FLAGS.dataset,
+        input_fname_pattern=FLAGS.input_fname_pattern,
+        crop=FLAGS.crop,
+        checkpoint_dir=FLAGS.checkpoint_dir,
+        sample_dir=FLAGS.sample_dir,
+        wgan=FLAGS.wgan,
+        learning_rate = FLAGS.learning_rate,
+        style_net_checkpoint=FLAGS.style_net_checkpoint,
+        can=FLAGS.can)
+
+    run_config = tf.ConfigProto()
+    run_config.gpu_options.allow_growth=FLAGS.allow_gpu_growth
     with tf.Session(config=run_config) as sess:
-        if FLAGS.dataset == 'mnist':
-            dcgan = DCGAN(
-                sess,
-                input_width=FLAGS.input_width,
-                input_height=FLAGS.input_height,
-                output_width=FLAGS.output_width,
-                output_height=FLAGS.output_height,
-                batch_size=FLAGS.batch_size,
-                sample_num=FLAGS.sample_size,
-                use_resize=FLAGS.use_resize,
-                replay=FLAGS.replay,
-                y_dim=10,
-                learning_rate = FLAGS.learning_rate,
-                smoothing=FLAGS.smoothing,
-                lamb = FLAGS.lambda_val,
-                dataset_name=FLAGS.dataset,
-                input_fname_pattern=FLAGS.input_fname_pattern,
-                crop=FLAGS.crop,
-                checkpoint_dir=FLAGS.checkpoint_dir,
-                sample_dir=FLAGS.sample_dir,
-                wgan=FLAGS.wgan,
-                can=FLAGS.can)
-        elif FLAGS.dataset == 'wikiart':
-            dcgan = DCGAN(
-                sess,
-                input_width=FLAGS.input_width,
-                input_height=FLAGS.input_height,
-                output_width=FLAGS.output_width,
-                output_height=FLAGS.output_height,
-                batch_size=FLAGS.batch_size,
-                sample_num=FLAGS.sample_size,
-                use_resize=FLAGS.use_resize,
-                replay=FLAGS.replay,
-                y_dim=27,
-                learning_rate = FLAGS.learning_rate,
-                smoothing=FLAGS.smoothing,
-                lamb = FLAGS.lambda_val,
-                dataset_name=FLAGS.dataset,
-                input_fname_pattern=FLAGS.input_fname_pattern,
-                crop=FLAGS.crop,
-                checkpoint_dir=FLAGS.checkpoint_dir,
-                sample_dir=FLAGS.sample_dir,
-                wgan=FLAGS.wgan,
-                can=FLAGS.can)
-        else:
-            dcgan = DCGAN(
-                sess,
-                input_width=FLAGS.input_width,
-                input_height=FLAGS.input_height,
-                output_width=FLAGS.output_width,
-                output_height=FLAGS.output_height,
-                batch_size=FLAGS.batch_size,
-                sample_num=FLAGS.sample_size,
-                dataset_name=FLAGS.dataset,
-                replay=FLAGS.replay,
-                input_fname_pattern=FLAGS.input_fname_pattern,
-                use_resize=FLAGS.use_resize,
-                smoothing=FLAGS.smoothing,
-                learning_rate = FLAGS.learning_rate,
-                crop=FLAGS.crop,
-                lamb = FLAGS.lambda_val,
-                checkpoint_dir=FLAGS.checkpoint_dir,
-                sample_dir=FLAGS.sample_dir,
-                wgan=FLAGS.wgan,
-                can=FLAGS.can)
 
-        show_all_variables()
+        dcgan.set_sess(sess)
+        # show_all_variables()
 
         if FLAGS.train:
             dcgan.train(FLAGS)
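The log-directory logic near the top of main (the [0-9][0-9][0-9] glob) gives each run of a given configuration a zero-padded sequence number. A simplified, self-contained sketch of that scheme (the function name is ours, not the repo's):

import os
from glob import glob

def next_run_dir(log_dir):
    # First run gets 000; each later run gets one past the highest
    # existing three-digit subdirectory (000, 001, ..., 012, ...).
    runs = glob(os.path.join(log_dir, "[0-9][0-9][0-9]"))
    nums = [int(os.path.basename(r)) for r in runs]
    num = max(nums) + 1 if nums else 0
    return os.path.join(log_dir, "{:03d}".format(num))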
[diff truncated: the remaining changed files are not shown]
