Commit

update gpt

czyssrs committed Mar 24, 2019
1 parent 30d6074 commit 1c94012
Showing 24 changed files with 559 additions and 143 deletions.
25 changes: 15 additions & 10 deletions DataLoader.py
@@ -31,12 +31,12 @@ def __init__(self, data_dir, limits, eos, empty):
data_dir + '/valid/valid.box.rpos', data_dir + '/valid/valid_summary_field_id.txt',
data_dir + '/valid/valid_summary_pos.txt', data_dir + '/valid/valid_summary_rpos.txt']

- self.vocab_mask_path = data_dir + '/vocab_200.txt'
+ self.vocab_mask_path = data_dir + '/vocab_local.txt'


self.limits = limits
self.man_text_len = 150
- self.man_summary_len = 80
+ self.man_summary_len = 85
self.eos = eos
self.empty = empty
start_time = time.time()
@@ -72,9 +72,9 @@ def __init__(self, data_dir, limits, eos, empty):

self.gpt_out_mask[-1] = 1.0
self.target_vocab.append(eos)
- assert len(self.gpt_out_mask) == eos + 1
- print (len(self.gpt_out_mask))
- print (len(self.target_vocab))
+ # assert len(self.gpt_out_mask) == eos + 1
+ # print (len(self.gpt_out_mask))
+ # print (len(self.target_vocab))


def load_data(self, path):
@@ -267,7 +267,7 @@ def batch_iter(self, data, batch_size, shuffle, domain):

batch_data = {'enc_in':[], 'enc_fd':[], 'enc_pos':[], 'enc_rpos':[], 'enc_len':[],
'dec_in':[], 'dec_len':[], 'dec_out':[], 'oov_map':[], 'dec_field':[],
- 'dec_pos':[], 'dec_rpos':[], 'gpt_context':[]}
+ 'dec_pos':[], 'dec_rpos':[], 'gpt_context':[], 'context':[]}

for summary, text, field, pos, rpos, dec_field, dec_pos, dec_rpos in zip(summaries[start_index:end_index], texts[start_index:end_index],
fields[start_index:end_index], poses[start_index:end_index],
@@ -285,6 +285,7 @@ def batch_iter(self, data, batch_size, shuffle, domain):
assert len(dec_field) == len(summary)

gold = summary + [self.eos] * (max_summary_len - summary_len + 1)
+ context = [self.eos] * (max_summary_len - summary_len) + summary
summary = summary + [self.eos] * (max_summary_len - summary_len)

dec_field = dec_field + [self.empty] * (max_summary_len - summary_len)
@@ -309,18 +310,21 @@ def batch_iter(self, data, batch_size, shuffle, domain):
if max_summary_len > self.man_summary_len:
gold = gold[:self.man_summary_len + 1]
summary = summary[:self.man_summary_len]

+ context = context[-self.man_summary_len:]

dec_field = dec_field[:self.man_summary_len]
dec_pos = dec_pos[:self.man_summary_len]
dec_rpos = dec_rpos[:self.man_summary_len]
summary_len = min(summary_len, self.man_summary_len)


if domain == "humans":
- gpt_context = "biography"
+ gpt_context = " Biography :"
elif domain == "books":
- gpt_context = "book introduction"
- elif domain == "humans_original":
- gpt_context = "biography"
+ gpt_context = " book introduction"
+ elif domain == "humans_tune":
+ gpt_context = " Biography :"

gpt_context, _ = enc.encode(gpt_context)
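Note: the updated prompts carry a leading space (" Biography :", " book introduction"). GPT-2's BPE is whitespace-sensitive, so a space-prefixed word encodes to different ids than the same word at the start of a string, matching how words appear mid-text during pretraining. A small illustration, assuming only that this repo's modified encoder.encode returns a pair (as the tuple unpack above implies):

import encoder  # the BPE encoder module bundled with this repo

enc = encoder.get_encoder("117M")
with_space, _ = enc.encode(" Biography :")
no_space, _ = enc.encode("Biography :")
assert with_space != no_space  # the leading space changes the BPE segmentation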

@@ -343,6 +347,7 @@ def batch_iter(self, data, batch_size, shuffle, domain):
batch_data['dec_pos'].append(dec_pos)
batch_data['dec_rpos'].append(dec_rpos)
batch_data['gpt_context'].append(gpt_context)
+ batch_data['context'].append(context)


yield batch_data
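Note: the new `context` entry built above is the left-padded mirror of `summary`: padding goes on the left so the real tokens end flush at the right edge, and truncation keeps the most recent tokens. A minimal sketch with invented token ids:

eos = 6975                # local-vocab end-of-text id, as set in Main.py
man_summary_len = 85      # mirrors self.man_summary_len
max_summary_len = 6       # batch-level maximum (invented for illustration)

summary = [12, 99, 431]   # three hypothetical BPE ids
summary_len = len(summary)

# Left-pad with eos so the true tokens sit at the right edge.
context = [eos] * (max_summary_len - summary_len) + summary
assert context == [6975, 6975, 6975, 12, 99, 431]

# Truncation keeps the *last* man_summary_len tokens (the newest ones),
# unlike `summary`, which is right-padded and cut from the front.
context = context[-man_summary_len:]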
128 changes: 93 additions & 35 deletions Main.py
@@ -20,7 +20,7 @@


tf.app.flags.DEFINE_string("gpt_model_name",'117M','model name of gpt2')
- tf.app.flags.DEFINE_string("domain",'humans','domain name')
+ tf.app.flags.DEFINE_string("domain",'humans_tune','domain name')

tf.app.flags.DEFINE_boolean("use_coverage", True,'use coverage or not')
tf.app.flags.DEFINE_float("coverage_penalty", 2.0,'coverage loss penalty')
@@ -44,16 +44,17 @@
tf.app.flags.DEFINE_integer("emb_size", 768, "Size of embedding.") # embedding for gpt
tf.app.flags.DEFINE_integer("field_size", 768, "Size of embedding.")
tf.app.flags.DEFINE_integer("pos_size", 5, "Size of embedding.")
- tf.app.flags.DEFINE_integer("batch_size", 32, "Batch size of train set.")
+ tf.app.flags.DEFINE_integer("batch_size", 2, "Batch size of train set.")
+ tf.app.flags.DEFINE_integer("batch_update", 22, "apply gradients after steps") # multiply batch size is real batch size
tf.app.flags.DEFINE_integer("epoch", 5000, "Number of training epoch.")
- tf.app.flags.DEFINE_integer("source_vocab", 50257,'vocabulary size')
+ tf.app.flags.DEFINE_integer("source_vocab", 6976,'vocabulary size')
tf.app.flags.DEFINE_integer("field_vocab", 2756,'vocabulary size')
tf.app.flags.DEFINE_integer("position_vocab", 31,'vocabulary size')
- tf.app.flags.DEFINE_integer("target_vocab", 50257,'vocabulary size')
- tf.app.flags.DEFINE_integer("report", 500,'report valid results after some steps')
+ tf.app.flags.DEFINE_integer("target_vocab", 6976,'vocabulary size')
+ tf.app.flags.DEFINE_integer("report", 100,'report valid results after some steps')
tf.app.flags.DEFINE_float("learning_rate", 0.0003,'learning rate')

- tf.app.flags.DEFINE_integer("report_loss", 10,'report loss results after some steps')
+ tf.app.flags.DEFINE_integer("report_loss", 30,'report loss results after some steps')

FLAGS = tf.app.flags.FLAGS
last_best = 0.0
@@ -78,8 +79,10 @@
### bpe vocab
enc = encoder.get_encoder("117M")
# "<|endoftext|>": 50256
- eos = 50256
- empty = 28920
+ # eos = 50256
+ # empty = 28920
+ eos = 6975
+ empty = 5713


# test phase
@@ -123,43 +126,89 @@ def train(sess, dataloader, model):

k = 0
record_k = 0
record_loss_k = 0
loss, start_time = 0.0, time.time()
record_loss = 0.0
record_copy_loss = 0.0
record_cov_loss = 0.0


+ ### old training
+ # for _ in range(FLAGS.epoch):
+ # for x in dataloader.batch_iter(trainset, FLAGS.batch_size, True, domain=FLAGS.domain):
+ # this_loss, this_copy_gate_loss, this_cov_loss = model(x, sess)
+ # loss += this_loss
+ # record_loss += this_loss
+ # record_copy_loss += this_copy_gate_loss
+ # record_cov_loss += this_cov_loss
+ # k += 1
+ # record_k += 1
+ # progress_bar(k % FLAGS.report, FLAGS.report)
+
+ # # ksave_dir = save_model(model, sess, save_dir, k // FLAGS.report)
+ # # write_log(evaluate(sess, dataloader, model, ksave_dir, 'test'))
+
+ # if (record_k % FLAGS.report_loss == 0):
+ # write_log("%d : loss = %.3f, copyloss = %.3f, covloss = %.3f" % \
+ # (k, record_loss / record_k, record_copy_loss / record_k, record_cov_loss / record_k))
+ # record_k = 0
+ # record_loss = 0.0
+ # record_copy_loss = 0.0
+ # record_cov_loss = 0.0
+
+
+ # if (k % FLAGS.report == 0):
+ # print ("Round: ", k / FLAGS.report)
+ # cost_time = time.time() - start_time
+ # write_log("%d : loss = %.3f, time = %.3f " % (k // FLAGS.report, loss, cost_time))
+ # loss, start_time = 0.0, time.time()
+ # if k // FLAGS.report >= 1:
+ # ksave_dir = save_model(model, sess, save_dir, k // FLAGS.report)
+ # write_log(evaluate(sess, dataloader, model, ksave_dir, 'valid'))
+ # # write_log(evaluate(sess, dataloader, model, ksave_dir, 'test'))
+
for _ in range(FLAGS.epoch):
for x in dataloader.batch_iter(trainset, FLAGS.batch_size, True, domain=FLAGS.domain):
- this_loss, this_copy_gate_loss, this_cov_loss = model(x, sess)
- loss += this_loss
- record_loss += this_loss
- record_copy_loss += this_copy_gate_loss
- record_cov_loss += this_cov_loss
- k += 1
- record_k += 1
- progress_bar(k % FLAGS.report, FLAGS.report)
+ _, _, _, _ = model(x, sess, 0)
+ # loss += this_loss
+ # record_loss += this_loss
+ # record_copy_loss += this_copy_gate_loss
+ # record_cov_loss += this_cov_loss
+ # k += 1
+ # record_k += 1
+ # progress_bar(record_k % FLAGS.report, FLAGS.report)

# ksave_dir = save_model(model, sess, save_dir, k // FLAGS.report)
# write_log(evaluate(sess, dataloader, model, ksave_dir, 'test'))

- if (record_k % FLAGS.report_loss == 0):
- write_log("%d : loss = %.3f, copyloss = %.3f, covloss = %.3f" % \
- (k, record_loss / record_k, record_copy_loss / record_k, record_cov_loss / record_k))
- record_k = 0
- record_loss = 0.0
- record_copy_loss = 0.0
- record_cov_loss = 0.0


- if (k % FLAGS.report == 0):
- print ("Round: ", k / FLAGS.report)
- cost_time = time.time() - start_time
- write_log("%d : loss = %.3f, time = %.3f " % (k // FLAGS.report, loss, cost_time))
- loss, start_time = 0.0, time.time()
- if k // FLAGS.report >= 1:
- ksave_dir = save_model(model, sess, save_dir, k // FLAGS.report)
- write_log(evaluate(sess, dataloader, model, ksave_dir, 'valid'))
- # write_log(evaluate(sess, dataloader, model, ksave_dir, 'test'))
+ k += 1
+ if (k % FLAGS.batch_update == 0):
+ this_loss, this_copy_gate_loss, this_cov_loss = model(x, sess, 1)
+ record_loss += this_loss
+ record_copy_loss += this_copy_gate_loss
+ record_cov_loss += this_cov_loss
+ record_k += 1
+ record_loss_k += 1
+
+
+ if (record_loss_k > 1 and record_loss_k % FLAGS.report_loss == 0):
+ write_log("%d : loss = %.3f, copyloss = %.3f, covloss = %.3f" % \
+ (record_k, record_loss / record_loss_k, record_copy_loss / record_loss_k, record_cov_loss / record_loss_k))
+ record_loss = 0.0
+ record_copy_loss = 0.0
+ record_cov_loss = 0.0
+ record_loss_k = 0
+
+
+ if (record_k > 1 and record_k % FLAGS.report == 0):
+ print ("Round: ", record_k / FLAGS.report)
+ cost_time = time.time() - start_time
+ write_log("%d : time = %.3f " % (record_k // FLAGS.report, cost_time))
+ start_time = time.time()
+ if record_k // FLAGS.report >= 1:
+ ksave_dir = save_model(model, sess, save_dir, record_k // FLAGS.report)
+ write_log(evaluate(sess, dataloader, model, ksave_dir, 'valid'))
+ # write_log(evaluate(sess, dataloader, model, ksave_dir, 'test'))
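Note: the rewritten loop splits each training step in two: model(x, sess, 0) accumulates gradients for every batch, and model(x, sess, 1) applies them once every FLAGS.batch_update batches, so the effective batch size is batch_size * batch_update (2 * 22 = 44 with the flags above). A self-contained TF1 sketch of this accumulate-then-apply pattern, with a toy loss standing in for the real model:

import numpy as np
import tensorflow as tf

# Toy model: one trainable vector, squared-error loss.
w = tf.get_variable("w", shape=[4], initializer=tf.zeros_initializer())
x = tf.placeholder(tf.float32, [None, 4])
loss = tf.reduce_mean(tf.square(tf.reduce_sum(x * w, axis=1) - 1.0))

batch_update = 22  # mirrors FLAGS.batch_update
opt = tf.train.AdamOptimizer(learning_rate=3e-4)
tvars = tf.trainable_variables()
grads = tf.gradients(loss, tvars)

# One non-trainable accumulator per variable; dividing by batch_update
# makes the applied gradient the mean over the accumulation window.
accum = [tf.Variable(tf.zeros_like(v), trainable=False) for v in tvars]
accum_op = tf.group(*[a.assign_add(g / batch_update) for a, g in zip(accum, grads)])
apply_op = opt.apply_gradients(zip(accum, tvars))
zero_op = tf.group(*[a.assign(tf.zeros_like(a)) for a in accum])

batches = [np.random.randn(2, 4).astype(np.float32) for _ in range(66)]
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for k, batch in enumerate(batches, start=1):
        sess.run(accum_op, feed_dict={x: batch})  # like model(x, sess, 0)
        if k % batch_update == 0:
            sess.run(apply_op)                    # like model(x, sess, 1)
            sess.run(zero_op)                     # reset for the next window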



@@ -247,6 +296,11 @@ def main():
with open(os.path.join('../models', FLAGS.gpt_model_name, 'hparams.json')) as f:
hparams.override_from_dict(json.load(f))

+ vocab_ind = []
+ with open(os.path.join('../models', FLAGS.gpt_model_name, 'vocab_ind.txt')) as f:
+ for line in f:
+ vocab_ind.append(int(line.strip()))

dataloader = DataLoader(processed_data_dir, FLAGS.limits, eos, empty)
field_id2word = dataloader.fieldid2word
gpt_out_mask = dataloader.gpt_out_mask
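Note: together with source_vocab/target_vocab dropping from 50257 to 6976 and the new ids eos = 6975, empty = 5713, the vocab_ind.txt file read above plausibly lists, for each token of the reduced local vocabulary, its id in GPT-2's full 50257-token vocabulary. A hedged sketch of how such an index list can carve a small embedding table out of the full one (the variable names here are assumptions, not the repo's):

import tensorflow as tf

full_vocab, local_vocab, emb_size = 50257, 6976, 768

vocab_ind = list(range(local_vocab))   # stand-in for the ids read from vocab_ind.txt
ind = tf.constant(vocab_ind, dtype=tf.int32)           # shape [6976]

wte = tf.get_variable("wte", [full_vocab, emb_size])   # full GPT-2 token table
local_wte = tf.gather(wte, ind)                        # shape [6976, 768]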
@@ -260,7 +314,8 @@
encoder_add_pos=FLAGS.encoder_pos, learning_rate=FLAGS.learning_rate,
use_coverage = FLAGS.use_coverage, coverage_penalty=FLAGS.coverage_penalty,
fieldid2word = field_id2word, copy_gate_penalty=FLAGS.copy_gate_penalty,
- use_copy_gate=FLAGS.use_copy_gate, gpt_hparams=hparams, gpt_out_mask=gpt_out_mask)
+ use_copy_gate=FLAGS.use_copy_gate, gpt_hparams=hparams, gpt_out_mask=gpt_out_mask, vocab_ind=vocab_ind,
+ empty_token=empty, stop_token=eos)


if FLAGS.mode == 'train':
@@ -277,6 +332,8 @@
if "Adam" not in each_var.name:
gpt_var_load.append(each_var)

+ gpt_var_load.remove(model.embedding)


# print ([tmp.name for tmp in gpt_var_load])
saver = tf.train.Saver(var_list=gpt_var_load)
@@ -285,6 +342,7 @@

# ### init other vars
seq2seq_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='seq2seq')
+ seq2seq_var.append(model.embedding)
# print ([tmp.name for tmp in seq2seq_var])
# seq2seq_var += gpt_var_opt
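Note: because the embedding is now sized for the 6976-token local vocab, it can no longer be restored from the 117M checkpoint; the two one-line edits above drop it from the GPT-2 load list and initialize it together with the seq2seq variables instead. A sketch of that partitioned restore under assumed names (model.embedding, sess, and GPT-2's usual "model" variable scope):

import tensorflow as tf

all_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

# Restore every GPT-2 weight except optimizer slots and the resized embedding.
gpt_vars = [v for v in all_vars
            if v.name.startswith("model/") and "Adam" not in v.name
            and v is not model.embedding]
saver = tf.train.Saver(var_list=gpt_vars)
saver.restore(sess, tf.train.latest_checkpoint("../models/117M"))

# Everything else (the seq2seq scope plus the new embedding) starts fresh.
fresh = [v for v in all_vars if v not in gpt_vars]
sess.run(tf.variables_initializer(fresh))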

README.md: empty file, mode changed 100644 → 100755
