
Commit 9982c96

Author: Alex Barron
Message: remove strong supervision
Parent: 1ab411f

File tree: 2 files changed (+12, -35 lines)


babi_input.py

Lines changed: 7 additions & 15 deletions
@@ -141,7 +141,6 @@ def process_input(data_raw, floatX, word2vec, vocab, ivocab, embed_size, split_s
     inputs = []
     answers = []
     input_masks = []
-    relevant_labels = []
     for x in data_raw:
         if split_sentences:
             inp = x["C"].lower().split(' . ')
@@ -197,9 +196,7 @@ def process_input(data_raw, floatX, word2vec, vocab, ivocab, embed_size, split_s
         else:
             raise Exception("invalid input_mask_mode")
 
-        relevant_labels.append(x["S"])
-
-    return inputs, questions, answers, input_masks, relevant_labels
+    return inputs, questions, answers, input_masks
 
 def get_lens(inputs, split_sentences=False):
     lens = np.zeros((len(inputs)), dtype=int)
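
With this hunk, process_input returns a 4-tuple and no longer tracks which sentences support each answer. A minimal sketch of the new call, assuming the arguments were produced by the loader's earlier parsing steps (they are stand-ins here, not values shown in this diff):

    # Hedged sketch: unpacking the new 4-tuple return value of process_input.
    inputs, questions, answers, input_masks = process_input(
        data_raw, floatX, word2vec, vocab, ivocab, embed_size,
        split_sentences=True)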
@@ -280,7 +277,7 @@ def load_babi(config, split_sentences=False):
     else:
         word_embedding = np.random.uniform(-config.embedding_init, config.embedding_init, (len(ivocab), config.embed_size))
 
-    inputs, questions, answers, input_masks, rel_labels = train_data if config.train_mode else test_data
+    inputs, questions, answers, input_masks = train_data if config.train_mode else test_data
 
     if split_sentences:
         input_lens, sen_lens, max_sen_len = get_sentence_lens(inputs)
@@ -307,17 +304,12 @@ def load_babi(config, split_sentences=False):
 
     answers = np.stack(answers)
 
-    rel_labels = np.array(rel_labels)
-
     if config.train_mode:
-        train = questions[:config.num_train], inputs[:config.num_train], q_lens[:config.num_train], input_lens[:config.num_train], input_masks[:config.num_train], answers[:config.num_train], rel_labels[:config.num_train]
+        train = questions[:config.num_train], inputs[:config.num_train], q_lens[:config.num_train], input_lens[:config.num_train], input_masks[:config.num_train], answers[:config.num_train]
 
-        valid = questions[config.num_train:], inputs[config.num_train:], q_lens[config.num_train:], input_lens[config.num_train:], input_masks[config.num_train:], answers[config.num_train:], rel_labels[config.num_train:]
-        return train, valid, word_embedding, max_q_len, max_input_len, max_mask_len, rel_labels.shape[1], len(vocab)
+        valid = questions[config.num_train:], inputs[config.num_train:], q_lens[config.num_train:], input_lens[config.num_train:], input_masks[config.num_train:], answers[config.num_train:]
+        return train, valid, word_embedding, max_q_len, max_input_len, max_mask_len, len(vocab)
 
     else:
-        test = questions, inputs, q_lens, input_lens, input_masks, answers, rel_labels
-        return test, word_embedding, max_q_len, max_input_len, max_mask_len, rel_labels.shape[1], len(vocab)
-
-
-
+        test = questions, inputs, q_lens, input_lens, input_masks, answers
+        return test, word_embedding, max_q_len, max_input_len, max_mask_len, len(vocab)
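
Taken together, load_babi now omits both the rel_labels arrays and the supporting-fact count from its return values. A minimal sketch of the new train-mode contract, with config assumed to carry the fields the loader reads (train_mode, num_train, embed_size, and so on):

    # Train mode: two 6-element splits plus the embedding matrix and size metadata.
    train, valid, word_embedding, max_q_len, max_input_len, max_mask_len, vocab_size = \
        load_babi(config, split_sentences=True)
    questions, inputs, q_lens, input_lens, input_masks, answers = train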

dmn_plus.py

Lines changed: 5 additions & 20 deletions
@@ -35,10 +35,6 @@ class Config(object):
     word2vec_init = False
     embedding_init = np.sqrt(3)
 
-    # set to zero with strong supervision to only train gates
-    strong_supervision = False
-    beta = 1
-
     # NOTE not currently used hence non-sensical anneal_threshold
     anneal_threshold = 1000
     anneal_by = 1.5
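
(For context: per the removed add_loss_op code below, beta weighted the answer loss in loss = beta * answer_loss + gate_loss, so setting it to zero trained only the attention gates. With strong supervision gone, both knobs were dead configuration.)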
@@ -84,9 +80,9 @@ class DMN_PLUS(object):
     def load_data(self, debug=False):
         """Loads train/valid/test data and sentence encoding"""
         if self.config.train_mode:
-            self.train, self.valid, self.word_embedding, self.max_q_len, self.max_sentences, self.max_sen_len, self.num_supporting_facts, self.vocab_size = babi_input.load_babi(self.config, split_sentences=True)
+            self.train, self.valid, self.word_embedding, self.max_q_len, self.max_sentences, self.max_sen_len, self.vocab_size = babi_input.load_babi(self.config, split_sentences=True)
         else:
-            self.test, self.word_embedding, self.max_q_len, self.max_sentences, self.max_sen_len, self.num_supporting_facts, self.vocab_size = babi_input.load_babi(self.config, split_sentences=True)
+            self.test, self.word_embedding, self.max_q_len, self.max_sentences, self.max_sen_len, self.vocab_size = babi_input.load_babi(self.config, split_sentences=True)
         self.encoding = _position_encoding(self.max_sen_len, self.config.embed_size)
 
     def add_placeholders(self):
@@ -99,9 +95,6 @@ def add_placeholders(self):
 
         self.answer_placeholder = tf.placeholder(tf.int64, shape=(self.config.batch_size,))
 
-        # fact corresponding to answer. Useful for strong supervision
-        self.rel_label_placeholder = tf.placeholder(tf.int32, shape=(self.config.batch_size, self.num_supporting_facts))
-
        self.dropout_placeholder = tf.placeholder(tf.float32)
 
    def get_predictions(self, output):
@@ -111,14 +104,7 @@ def get_predictions(self, output):
 
     def add_loss_op(self, output):
         """Calculate loss"""
-        # optional strong supervision of attention with supporting facts
-        gate_loss = 0
-        if self.config.strong_supervision:
-            for i, att in enumerate(self.attentions):
-                labels = tf.gather(tf.transpose(self.rel_label_placeholder), 0)
-                gate_loss += tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=att, labels=labels))
-
-        loss = self.config.beta*tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=output, labels=self.answer_placeholder)) + gate_loss
+        loss = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=output, labels=self.answer_placeholder))
 
         # add l2 regularization for all variables except biases
         for v in tf.trainable_variables():
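
For context, the deleted branch was the optional strong-supervision term: each memory hop's attention logits were scored against the index of the first supporting fact. A standalone, hedged reconstruction of that computation (TF1-style; the shapes are inferred from the removed rel_label_placeholder definition):

    import tensorflow as tf

    def gate_loss(attentions, rel_labels):
        # attentions: list of [batch_size, max_sentences] logit tensors, one per hop.
        # rel_labels: [batch_size, num_supporting_facts] int32 fact indices (assumed).
        first_fact = tf.gather(tf.transpose(rel_labels), 0)  # first supporting fact per example
        loss = 0
        for att in attentions:
            loss += tf.reduce_sum(
                tf.nn.sparse_softmax_cross_entropy_with_logits(logits=att, labels=first_fact))
        return loss  # old total loss was config.beta * answer_loss + gate_loss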
@@ -298,8 +284,8 @@ def run_epoch(self, session, data, num_epoch=0, train_writer=None, train_op=None
 
         # shuffle data
         p = np.random.permutation(len(data[0]))
-        qp, ip, ql, il, im, a, r = data
-        qp, ip, ql, il, im, a, r = qp[p], ip[p], ql[p], il[p], im[p], a[p], r[p]
+        qp, ip, ql, il, im, a = data
+        qp, ip, ql, il, im, a = qp[p], ip[p], ql[p], il[p], im[p], a[p]
 
         for step in range(total_steps):
             index = range(step*config.batch_size,(step+1)*config.batch_size)
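
The shuffle itself is unchanged: one permutation reorders every parallel array so rows stay aligned, just without the r (supporting-fact) array. A toy sketch of the pattern:

    import numpy as np

    qp = np.array([0, 1, 2, 3])         # stand-in for the question batch
    a = np.array([10, 11, 12, 13])      # stand-in for the answer batch
    p = np.random.permutation(len(qp))  # one shared permutation
    qp, a = qp[p], a[p]                 # same reordering applied to both arrays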
@@ -308,7 +294,6 @@ def run_epoch(self, session, data, num_epoch=0, train_writer=None, train_op=None
                     self.question_len_placeholder: ql[index],
                     self.input_len_placeholder: il[index],
                     self.answer_placeholder: a[index],
-                    self.rel_label_placeholder: r[index],
                     self.dropout_placeholder: dp}
             loss, pred, summary, _ = session.run(
                 [self.calculate_loss, self.pred, self.merged, train_op], feed_dict=feed)
