Commit 742600d: fixed merge conflicts

heavani committed Apr 4, 2019
2 parents 9926140 + 18f4298

Showing 26 changed files with 1,016 additions and 322 deletions.
225 changes: 87 additions & 138 deletions DataLoader.py

Large diffs are not rendered by default.

136 changes: 98 additions & 38 deletions Main.py
@@ -22,14 +22,14 @@
tf.app.flags.DEFINE_string("gpt_model_name",'117M','model name of gpt2')
tf.app.flags.DEFINE_string("domain",'humans','domain name')

tf.app.flags.DEFINE_boolean("use_coverage", True,'use coverage or not')
tf.app.flags.DEFINE_float("coverage_penalty", 2.0,'coverage loss penalty')
tf.app.flags.DEFINE_boolean("use_coverage", False,'use coverage or not')
tf.app.flags.DEFINE_float("coverage_penalty", 0.02,'coverage loss penalty')

tf.app.flags.DEFINE_boolean("use_copy_gate", True,'use copy gate or not')
tf.app.flags.DEFINE_float("copy_gate_penalty", 0.01, 'copy gate loss penalty')
tf.app.flags.DEFINE_float("copy_gate_penalty", 0.1, 'copy gate loss penalty')

tf.app.flags.DEFINE_string("mode",'train','train or test')
tf.app.flags.DEFINE_string("load",'0','load directory') # BBBBBESTOFAll
tf.app.flags.DEFINE_string("load",'0','load directory')
tf.app.flags.DEFINE_integer("limits", 0,'max data set size')

tf.app.flags.DEFINE_boolean("dual_attention", True,'dual attention layer or normal attention')
@@ -44,16 +44,17 @@
tf.app.flags.DEFINE_integer("emb_size", 768, "Size of embedding.") # embedding for gpt
tf.app.flags.DEFINE_integer("field_size", 768, "Size of embedding.")
tf.app.flags.DEFINE_integer("pos_size", 5, "Size of embedding.")
tf.app.flags.DEFINE_integer("batch_size", 32, "Batch size of train set.")
tf.app.flags.DEFINE_integer("epoch", 500, "Number of training epoch.")
tf.app.flags.DEFINE_integer("batch_size", 2, "Batch size of train set.")
tf.app.flags.DEFINE_integer("batch_update", 22, "apply gradients after steps") # multiply batch size is real batch size
tf.app.flags.DEFINE_integer("epoch", 5000, "Number of training epoch.")
tf.app.flags.DEFINE_integer("source_vocab", 50257,'vocabulary size')
tf.app.flags.DEFINE_integer("field_vocab", 2756,'vocabulary size')
tf.app.flags.DEFINE_integer("position_vocab", 31,'vocabulary size')
tf.app.flags.DEFINE_integer("target_vocab", 50257,'vocabulary size')
tf.app.flags.DEFINE_integer("report", 500,'report valid results after some steps')
tf.app.flags.DEFINE_integer("report", 50,'report valid results after some steps')
tf.app.flags.DEFINE_float("learning_rate", 0.0003,'learning rate')

tf.app.flags.DEFINE_integer("report_loss", 10,'report loss results after some steps')
tf.app.flags.DEFINE_integer("report_loss", 20,'report loss results after some steps')

FLAGS = tf.app.flags.FLAGS
last_best = 0.0
@@ -66,12 +67,12 @@


###
root_path = "../few_shot_gpt-2_data/"
root_path = "/scratch/home/zhiyu/wiki2bio/few_shot_gpt-2/"
gold_path_valid = root_path + FLAGS.domain + '/original_data/valid.summary'
gold_path_test = root_path + FLAGS.domain + '/original_data/test.summary'

field_vocab_file = root_path + "human_books_songs_films_field_vocab.txt"
vocab_file = root_path + "human_books_songs_films_word_vocab_2000.txt"
# vocab_file = root_path + "human_books_songs_films_word_vocab_2000.txt"

# word2vec_file = "/scratch/home/zhiyu/wiki2bio/other_data/glove.6B.300d.txt"
processed_data_dir = root_path + FLAGS.domain + "/processed_data"
@@ -81,6 +82,8 @@
# "<|endoftext|>": 50256
eos = 50256
empty = 28920
# eos = 6975
# empty = 5713


# test phase
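These hard-coded ids can be sanity-checked against the GPT-2 BPE vocabulary. A quick check, assuming the encoder module from the GPT-2 codebase is importable and its get_encoder call resolves the 117M vocab files (both are assumptions here):

import encoder  # from the GPT-2 codebase

enc = encoder.get_encoder('117M')
print(enc.encoder['<|endoftext|>'])  # 50256, matching eos above
print(enc.decode([28920]))           # the string this script treats as the empty token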
@@ -124,44 +127,89 @@ def train(sess, dataloader, model):

k = 0
record_k = 0
record_loss_k = 0
loss, start_time = 0.0, time.time()
record_loss = 0.0
record_copy_loss = 0.0
record_cov_loss = 0.0


### old training
# for _ in range(FLAGS.epoch):
# for x in dataloader.batch_iter(trainset, FLAGS.batch_size, True, domain=FLAGS.domain):
# this_loss, this_copy_gate_loss, this_cov_loss = model(x, sess)
# loss += this_loss
# record_loss += this_loss
# record_copy_loss += this_copy_gate_loss
# record_cov_loss += this_cov_loss
# k += 1
# record_k += 1
# progress_bar(k % FLAGS.report, FLAGS.report)

# # ksave_dir = save_model(model, sess, save_dir, k // FLAGS.report)
# # write_log(evaluate(sess, dataloader, model, ksave_dir, 'test'))

# if (record_k % FLAGS.report_loss == 0):
# write_log("%d : loss = %.3f, copyloss = %.3f, covloss = %.3f" % \
# (k, record_loss / record_k, record_copy_loss / record_k, record_cov_loss / record_k))
# record_k = 0
# record_loss = 0.0
# record_copy_loss = 0.0
# record_cov_loss = 0.0


# if (k % FLAGS.report == 0):
# print ("Round: ", k / FLAGS.report)
# cost_time = time.time() - start_time
# write_log("%d : loss = %.3f, time = %.3f " % (k // FLAGS.report, loss, cost_time))
# loss, start_time = 0.0, time.time()
# if k // FLAGS.report >= 1:
# ksave_dir = save_model(model, sess, save_dir, k // FLAGS.report)
# write_log(evaluate(sess, dataloader, model, ksave_dir, 'valid'))
# # write_log(evaluate(sess, dataloader, model, ksave_dir, 'test'))

for _ in range(FLAGS.epoch):
for x in dataloader.batch_iter(trainset, FLAGS.batch_size, True, domain=FLAGS.domain):
this_loss, this_copy_gate_loss, this_cov_loss = model(x, sess)
loss += this_loss
record_loss += this_loss
record_copy_loss += this_copy_gate_loss
record_cov_loss += this_cov_loss
k += 1
record_k += 1
progress_bar(k % FLAGS.report, FLAGS.report)
_, _, _, _ = model(x, sess, 0)
# loss += this_loss
# record_loss += this_loss
# record_copy_loss += this_copy_gate_loss
# record_cov_loss += this_cov_loss
# k += 1
# record_k += 1
# progress_bar(record_k % FLAGS.report, FLAGS.report)

# ksave_dir = save_model(model, sess, save_dir, k // FLAGS.report)
# write_log(evaluate(sess, dataloader, model, ksave_dir, 'test'))

if (record_k % FLAGS.report_loss == 0):
write_log("%d : loss = %.3f, copyloss = %.3f, covloss = %.3f" % \
(k, record_loss / record_k, record_copy_loss / record_k, record_cov_loss / record_k))
record_k = 0
record_loss = 0.0
record_copy_loss = 0.0
record_cov_loss = 0.0


if (k % FLAGS.report == 0):
print("Round: ", k / FLAGS.report)
cost_time = time.time() - start_time
write_log("%d : loss = %.3f, time = %.3f " % (k // FLAGS.report, loss, cost_time))
loss, start_time = 0.0, time.time()
if k // FLAGS.report >= 1:
ksave_dir = save_model(model, sess, save_dir, k // FLAGS.report)
# write_log(evaluate(sess, dataloader, model, ksave_dir, 'valid'))
write_log(evaluate(sess, dataloader, model, ksave_dir, 'test'))

k += 1
if (k % FLAGS.batch_update == 0):
this_loss, this_copy_gate_loss, this_cov_loss = model(x, sess, 1)
record_loss += this_loss
record_copy_loss += this_copy_gate_loss
record_cov_loss += this_cov_loss
record_k += 1
record_loss_k += 1


if (record_loss_k > 1 and record_loss_k % FLAGS.report_loss == 0):
write_log("%d : loss = %.3f, copyloss = %.3f, covloss = %.3f" % \
(record_k, record_loss / record_loss_k, record_copy_loss / record_loss_k, record_cov_loss / record_loss_k))
record_loss = 0.0
record_copy_loss = 0.0
record_cov_loss = 0.0
record_loss_k = 0


if (record_k > 1 and record_k % FLAGS.report == 0):
print("Round: ", record_k / FLAGS.report)
cost_time = time.time() - start_time
write_log("%d : time = %.3f " % (record_k // FLAGS.report, cost_time))
start_time = time.time()
if record_k // FLAGS.report >= 1:
ksave_dir = save_model(model, sess, save_dir, record_k // FLAGS.report)
write_log(evaluate(sess, dataloader, model, ksave_dir, 'valid'))
# write_log(evaluate(sess, dataloader, model, ksave_dir, 'test'))
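The rewritten loop spreads each optimizer update over FLAGS.batch_update mini-batches: model(x, sess, 0) appears to accumulate gradients for one batch of 2, and every 22nd step model(x, sess, 1) applies them, for an effective batch size of 2 * 22 = 44. A self-contained TF1 sketch of that accumulation pattern, with illustrative names rather than the actual SeqUnit internals:

import tensorflow as tf

# Toy parameter and loss so the sketch runs on its own.
w = tf.get_variable('w', [10], initializer=tf.zeros_initializer())
loss = tf.reduce_sum(tf.square(w - 1.0))

opt = tf.train.AdamOptimizer(learning_rate=0.0003)
params = tf.trainable_variables()
grads = tf.gradients(loss, params)

# One non-trainable accumulator per parameter.
accum = [tf.Variable(tf.zeros_like(p), trainable=False) for p in params]
accum_op = [a.assign_add(g) for a, g in zip(accum, grads)]  # the "0" call
zero_op = [a.assign(tf.zeros_like(a)) for a in accum]

batch_update = 22
apply_op = opt.apply_gradients(  # the "1" call, followed by zero_op
    [(a / float(batch_update), p) for a, p in zip(accum, params)])

Whether SeqUnit averages the accumulated gradients (as here) or sums them is not visible from this diff.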


def test(sess, dataloader, model):
@@ -207,6 +255,8 @@ def evaluate(sess, dataloader, model, ksave_dir, mode='valid'):
real_sum = enc.decode(summary)
bpe_sum = " ".join([enc.decoder[tmp] for tmp in summary])


real_sum = real_sum.replace("\n", " ")
sw.write(real_sum + '\n')
pred_list.append(real_sum)
pred_unk.append(bpe_sum)
@@ -247,8 +297,14 @@ def main():
with open(os.path.join('../models', FLAGS.gpt_model_name, 'hparams.json')) as f:
hparams.override_from_dict(json.load(f))

# vocab_ind = []
# with open(os.path.join('../models', FLAGS.gpt_model_name, 'vocab_ind.txt')) as f:
# for line in f:
# vocab_ind.append(int(line.strip()))

dataloader = DataLoader(processed_data_dir, FLAGS.limits, eos, empty)
field_id2word = dataloader.fieldid2word
gpt_out_mask = dataloader.gpt_out_mask

model = SeqUnit(batch_size=FLAGS.batch_size, hidden_size=FLAGS.hidden_size, emb_size=FLAGS.emb_size,
field_size=FLAGS.field_size, pos_size=FLAGS.pos_size, field_vocab=FLAGS.field_vocab,
@@ -259,7 +315,8 @@
encoder_add_pos=FLAGS.encoder_pos, learning_rate=FLAGS.learning_rate,
use_coverage = FLAGS.use_coverage, coverage_penalty=FLAGS.coverage_penalty,
fieldid2word = field_id2word, copy_gate_penalty=FLAGS.copy_gate_penalty,
use_copy_gate=FLAGS.use_copy_gate, gpt_hparams=hparams)
use_copy_gate=FLAGS.use_copy_gate, gpt_hparams=hparams, gpt_out_mask=gpt_out_mask, vocab_ind=None,
empty_token=empty, stop_token=eos)


if FLAGS.mode == 'train':
@@ -276,6 +333,8 @@
if "Adam" not in each_var.name:
gpt_var_load.append(each_var)

gpt_var_load.remove(model.embedding)


# print ([tmp.name for tmp in gpt_var_load])
saver = tf.train.Saver(var_list=gpt_var_load)
@@ -284,6 +343,7 @@

# ### init other vars
seq2seq_var = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='seq2seq')
seq2seq_var.append(model.embedding)
# print ([tmp.name for tmp in seq2seq_var])
# seq2seq_var += gpt_var_opt

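The variable bookkeeping above splits the graph into two groups: pretrained GPT weights restored from the checkpoint (minus Adam slots and the word embedding, which gpt_var_load.remove drops) and the seq2seq variables plus that embedding, which train from a fresh state on top of the pretrained stack. A minimal sketch of the restore pattern under those assumptions (the scope 'model', the 'wte' embedding name, and the checkpoint path follow GPT-2 conventions and are assumptions here):

import tensorflow as tf

gpt_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='model')

# Keep optimizer slots and the word embedding ('wte') out of the restore
# list, so the embedding is trained alongside the seq2seq variables.
to_load = [v for v in gpt_vars
           if 'Adam' not in v.name and 'wte' not in v.name]
saver = tf.train.Saver(var_list=to_load)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())       # initialize everything
    saver.restore(sess, '../models/117M/model.ckpt')  # then overwrite restored vars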
48 changes: 48 additions & 0 deletions OutputUnit_gpt.py
@@ -0,0 +1,48 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 17-4-27 8:36 PM
# @Author : Tianyu Liu

import tensorflow as tf
import pickle


class OutputUnit_gpt(object):
def __init__(self, input_size, output_size, project_init, scope_name):
self.input_size = input_size
self.output_size = output_size
self.scope_name = scope_name
self.params = {}

with tf.variable_scope(scope_name):
# self.W = tf.get_variable('W', [input_size, output_size])
# self.b = tf.get_variable('b', [output_size], initializer=tf.zeros_initializer(), dtype=tf.float32)

### use gpt word embedding init
self.W = tf.get_variable('W', initializer=project_init)

# self.params.update({'W': self.W, 'b': self.b})
self.params.update({'W': self.W})

def __call__(self, x, finished = None):
# out = tf.nn.xw_plus_b(x, self.W, self.b)

### use gpt word embedding init
out = tf.matmul(x, self.W)

if finished is not None:
out = tf.where(finished, tf.zeros_like(out), out)
#out = tf.multiply(1 - finished, out)
return out

def save(self, path):
param_values = {}
for param in self.params:
param_values[param] = self.params[param].eval()
with open(path, 'wb') as f:
pickle.dump(param_values, f, True)

def load(self, path):
param_values = pickle.load(open(path, 'rb'))
for param in param_values:
self.params[param].load(param_values[param])
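OutputUnit_gpt replaces the usual xw_plus_b output projection with a single weight matrix initialized from project_init, in this setup presumably the GPT word embedding transposed to [emb_size, vocab], so the output layer starts out tied to the pretrained vocabulary space. A hypothetical usage sketch (the random project_init stands in for the real pretrained matrix):

import numpy as np
import tensorflow as tf
from OutputUnit_gpt import OutputUnit_gpt

emb_size, vocab = 768, 50257
project_init = tf.constant(
    np.random.randn(emb_size, vocab).astype(np.float32))

unit = OutputUnit_gpt(emb_size, vocab, project_init, 'decoder_out')
hidden = tf.placeholder(tf.float32, [None, emb_size])
logits = unit(hidden)  # [batch, vocab] via tf.matmul(hidden, W)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    out = sess.run(logits, {hidden: np.zeros((2, emb_size), np.float32)})
    print(out.shape)  # (2, 50257)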
README.md: file mode changed 100644 → 100755 (no content changes).