-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathgenerate.py
117 lines (91 loc) · 4.59 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import numpy as np
import tensorflow as tf
import model
from data_reader import load_data, DataReader
flags = tf.flags
# data
flags.DEFINE_string('load_model', None, 'filename of the model to load')
# we need data only to compute vocabulary
flags.DEFINE_string('data_dir', 'data', 'data directory')
flags.DEFINE_integer('num_samples', 300, 'how many words to generate')
flags.DEFINE_float('temperature', 1.0, 'sampling temperature')
# model params
flags.DEFINE_integer('rnn_size', 650, 'size of LSTM internal state')
flags.DEFINE_integer('highway_layers', 2, 'number of highway layers')
flags.DEFINE_integer('char_embed_size', 15, 'dimensionality of character embeddings')
flags.DEFINE_string ('kernels', '[1,2,3,4,5,6,7]', 'CNN kernel widths')
flags.DEFINE_string ('kernel_features', '[50,100,150,200,200,200,200]', 'number of features in the CNN kernel')
flags.DEFINE_integer('rnn_layers', 2, 'number of layers in the LSTM')
flags.DEFINE_float ('dropout', 0.5, 'dropout. 0 = no dropout')
# optimization
flags.DEFINE_integer('max_word_length', 65, 'maximum word length')
# bookkeeping
flags.DEFINE_integer('seed', 3435, 'random number generator seed')
flags.DEFINE_string ('EOS', '+', '<EOS> symbol. should be a single unused character (like +) for PTB and blank for others')
FLAGS = flags.FLAGS
def main(_):
''' Loads trained model and evaluates it on test split '''
if FLAGS.load_model is None:
print('Please specify checkpoint file to load model from')
return -1
if not os.path.exists(FLAGS.load_model):
print('Checkpoint file not found', FLAGS.load_model)
return -1
word_vocab, char_vocab, word_tensors, char_tensors, max_word_length = \
load_data(FLAGS.data_dir, FLAGS.max_word_length, eos=FLAGS.EOS)
print('initialized test dataset reader')
with tf.Graph().as_default(), tf.Session() as session:
# tensorflow seed must be inside graph
tf.set_random_seed(FLAGS.seed)
np.random.seed(seed=FLAGS.seed)
''' build inference graph '''
with tf.variable_scope("Model"):
m = model.inference_graph(
char_vocab_size=char_vocab.size,
word_vocab_size=word_vocab.size,
char_embed_size=FLAGS.char_embed_size,
batch_size=1,
num_highway_layers=FLAGS.highway_layers,
num_rnn_layers=FLAGS.rnn_layers,
rnn_size=FLAGS.rnn_size,
max_word_length=max_word_length,
kernels=eval(FLAGS.kernels),
kernel_features=eval(FLAGS.kernel_features),
num_unroll_steps=1,
dropout=0)
# we need global step only because we want to read it from the model
global_step = tf.Variable(0, dtype=tf.int32, name='global_step')
saver = tf.train.Saver()
saver.restore(session, FLAGS.load_model)
print('Loaded model from', FLAGS.load_model, 'saved at global step', global_step.eval())
''' training starts here '''
rnn_state = session.run(m.initial_rnn_state)
logits = np.ones((word_vocab.size,))
rnn_state = session.run(m.initial_rnn_state)
for i in range(FLAGS.num_samples):
logits = logits / FLAGS.temperature
prob = np.exp(logits)
prob /= np.sum(prob)
prob = prob.ravel()
ix = np.random.choice(range(len(prob)), p=prob)
word = word_vocab.token(ix)
if word == '|': # EOS
print('<unk>', end=' ')
elif word == '+':
print('\n')
else:
print(word, end=' ')
char_input = np.zeros((1, 1, max_word_length))
for i,c in enumerate('{' + word + '}'):
char_input[0,0,i] = char_vocab[c]
logits, rnn_state = session.run([m.logits, m.final_rnn_state],
{m.input: char_input,
m.initial_rnn_state: rnn_state})
logits = np.array(logits)
if __name__ == "__main__":
tf.app.run()