diff --git a/model.py b/model.py
index a9115653..1beaea09 100644
--- a/model.py
+++ b/model.py
@@ -5,9 +5,9 @@ import numpy as np
 
 
 class Model():
-    def __init__(self, args, infer=False):
+    def __init__(self, args, training=True):
         self.args = args
-        if infer:
+        if not training:
             args.batch_size = 1
             args.seq_length = 1
 
@@ -22,6 +22,9 @@ def __init__(self, args, infer=False):
 
         cell = cell_fn(args.rnn_size)
 
+        if training and args.keep_prob < 1:
+            cell = rnn_cell.DropoutWrapper(cell, output_keep_prob=args.keep_prob)
+
         self.cell = cell = rnn_cell.MultiRNNCell([cell] * args.num_layers)
 
         self.input_data = tf.placeholder(tf.int32, [args.batch_size, args.seq_length])
@@ -33,12 +36,30 @@ def __init__(self, args, infer=False):
         self.batch_time = tf.Variable(0.0, name="batch_time", trainable=False)
         tf.summary.scalar("time_batch", self.batch_time)
 
+        def variable_summaries(var):
+            """Attach a lot of summaries to a Tensor (for TensorBoard visualization)."""
+            with tf.name_scope('summaries'):
+                mean = tf.reduce_mean(var)
+                tf.summary.scalar('mean', mean)
+                #with tf.name_scope('stddev'):
+                #    stddev = tf.sqrt(tf.reduce_mean(tf.square(var - mean)))
+                #tf.summary.scalar('stddev', stddev)
+                tf.summary.scalar('max', tf.reduce_max(var))
+                tf.summary.scalar('min', tf.reduce_min(var))
+                #tf.summary.histogram('histogram', var)
+
         with tf.variable_scope('rnnlm'):
             softmax_w = tf.get_variable("softmax_w", [args.rnn_size, args.vocab_size])
+            variable_summaries(softmax_w)
             softmax_b = tf.get_variable("softmax_b", [args.vocab_size])
+            variable_summaries(softmax_b)
             with tf.device("/cpu:0"):
                 embedding = tf.get_variable("embedding", [args.vocab_size, args.rnn_size])
-                inputs = tf.split(1, args.seq_length, tf.nn.embedding_lookup(embedding, self.input_data))
+                inputs = tf.nn.embedding_lookup(embedding, self.input_data)
+                if training and args.keep_prob < 1:
+                    inputs = tf.nn.dropout(inputs, args.keep_prob)
+
+                inputs = tf.split(1, args.seq_length, inputs)
                 inputs = [tf.squeeze(input_, [1]) for input_ in inputs]
 
         def loop(prev, _):
@@ -46,7 +67,7 @@ def loop(prev, _):
             prev_symbol = tf.stop_gradient(tf.argmax(prev, 1))
             return tf.nn.embedding_lookup(embedding, prev_symbol)
 
-        outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if infer else None, scope='rnnlm')
+        outputs, last_state = seq2seq.rnn_decoder(inputs, self.initial_state, cell, loop_function=loop if not training else None, scope='rnnlm')
         output = tf.reshape(tf.concat(1, outputs), [-1, args.rnn_size])
         self.logits = tf.matmul(output, softmax_w) + softmax_b
         self.probs = tf.nn.softmax(self.logits)
@@ -57,12 +78,15 @@ def loop(prev, _):
         self.cost = tf.reduce_sum(loss) / args.batch_size / args.seq_length
         tf.summary.scalar("cost", self.cost)
         self.final_state = last_state
-        self.lr = tf.Variable(0.0, trainable=False)
+        self.global_step = tf.Variable(0, name='global_step', trainable=False)
+        self.lr = tf.train.exponential_decay(args.learning_rate, self.global_step,
+                                             args.decay_step, args.decay_rate)
+        tf.summary.scalar("learning_rate", self.lr)
         tvars = tf.trainable_variables()
         grads, _ = tf.clip_by_global_norm(tf.gradients(self.cost, tvars),
                 args.grad_clip)
         optimizer = tf.train.AdamOptimizer(self.lr)
-        self.train_op = optimizer.apply_gradients(zip(grads, tvars))
+        self.train_op = optimizer.apply_gradients(zip(grads, tvars), global_step=self.global_step)
 
     def sample(self, sess, words, vocab, num=200, prime='first all', sampling_type=1):
         state = sess.run(self.cell.zero_state(1, tf.float32))
diff --git a/sample.py b/sample.py
index bef448e8..849975ba 100644
--- a/sample.py
+++ b/sample.py
@@ -29,7 +29,7 @@ def sample(args):
         saved_args = cPickle.load(f)
     with open(os.path.join(args.save_dir, 'words_vocab.pkl'), 'rb') as f:
         words, vocab = cPickle.load(f)
-    model = Model(saved_args, True)
+    model = Model(saved_args, training=False)
     with tf.Session() as sess:
         tf.global_variables_initializer().run()
         saver = tf.train.Saver(tf.global_variables())
diff --git a/train.py b/train.py
index 79ba44c7..04717f4d 100644
--- a/train.py
+++ b/train.py
@@ -36,6 +36,10 @@ def main():
                         help='learning rate')
     parser.add_argument('--decay_rate', type=float, default=0.97,
                         help='decay rate for rmsprop')
+    parser.add_argument('--keep_prob', type=float, default=1.0,
+                        help='probability of keeping weights in the dropout layer')
+    parser.add_argument('--gpu_mem', type=float, default=0.666,
+                        help='%% of gpu memory to be allocated to this process. Default is 66.6%%')
     parser.add_argument('--init_from', type=str, default=None,
                         help="""continue training from saved model at this path. Path must contain files saved by previous training process:
                             'config.pkl'        : configuration;
@@ -50,6 +54,7 @@ def main():
 def train(args):
     data_loader = TextLoader(args.data_dir, args.batch_size, args.seq_length)
     args.vocab_size = data_loader.vocab_size
+    args.decay_step = data_loader.num_batches
 
     # check compatibility if training is continued from previously saved model
     if args.init_from is not None:
@@ -83,8 +88,9 @@ def train(args):
 
     merged = tf.summary.merge_all()
     train_writer = tf.summary.FileWriter('logs')
+    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem)
 
-    with tf.Session() as sess:
+    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
         train_writer.add_graph(sess.graph)
         tf.global_variables_initializer().run()
         saver = tf.train.Saver(tf.global_variables())
@@ -92,7 +98,6 @@ def train(args):
         if args.init_from is not None:
            saver.restore(sess, ckpt.model_checkpoint_path)
         for e in range(model.epoch_pointer.eval(), args.num_epochs):
-            sess.run(tf.assign(model.lr, args.learning_rate * (args.decay_rate ** e)))
             data_loader.reset_batch_pointer()
             state = sess.run(model.initial_state)
             speed = 0
@@ -109,15 +114,15 @@ def train(args):
                 x, y = data_loader.next_batch()
                 feed = {model.input_data: x, model.targets: y, model.initial_state: state,
                         model.batch_time: speed}
-                summary, train_loss, state, _, _ = sess.run([merged, model.cost, model.final_state,
+                summary, train_loss, lr, state, _, _ = sess.run([merged, model.cost, model.lr, model.final_state,
                                                              model.train_op, model.inc_batch_pointer_op], feed)
                 train_writer.add_summary(summary, e * data_loader.num_batches + b)
                 speed = time.time() - start
                 if (e * data_loader.num_batches + b) % args.batch_size == 0:
-                    print("{}/{} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
-                        .format(e * data_loader.num_batches + b,
+                    print("{}/{} (epoch {}), lr = {:.6f}, train_loss = {:.3f}, time/batch = {:.3f}" \
+                        .format(e * data_loader.num_batches + b,
                             args.num_epochs * data_loader.num_batches,
-                            e, train_loss, speed))
+                            e, lr, train_loss, speed))
                 if (e * data_loader.num_batches + b) % args.save_every == 0 \
                         or (e==args.num_epochs-1 and b == data_loader.num_batches-1): # save for the last result
                     checkpoint_path = os.path.join(args.save_dir, 'model.ckpt')
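
Note on the learning-rate change above: the patch removes the per-epoch assignment sess.run(tf.assign(model.lr, ...)) in train.py and instead builds self.lr with tf.train.exponential_decay, driven by a global_step that apply_gradients now increments once per batch, with decay_steps set to the number of batches per epoch (args.decay_step = data_loader.num_batches). The short sketch below is an illustration only, not part of the patch: it evaluates the decay formula tf.train.exponential_decay uses with staircase=False, so you can see that the new per-step schedule agrees with the old per-epoch one at epoch boundaries while decaying smoothly in between. The learning_rate value and batch count in the example are hypothetical; decay_rate=0.97 is the existing --decay_rate default shown in the diff.

# Illustration only -- not part of the patch.
# With staircase=False (the default), tf.train.exponential_decay computes
#     lr = learning_rate * decay_rate ** (global_step / decay_steps)
# and here decay_steps is the number of batches per epoch.
def decayed_lr(learning_rate, decay_rate, global_step, num_batches):
    return learning_rate * decay_rate ** (global_step / float(num_batches))

# Hypothetical values: learning_rate=0.002, 1000 batches per epoch, decay_rate=0.97.
print(decayed_lr(0.002, 0.97, 0, 1000))      # 0.002    (start of epoch 0)
print(decayed_lr(0.002, 0.97, 1000, 1000))   # 0.00194  (start of epoch 1: 0.002 * 0.97)
print(decayed_lr(0.002, 0.97, 1500, 1000))   # ~0.00191 (halfway through epoch 1, smooth decay)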