@@ -35,10 +35,6 @@ class Config(object):
     word2vec_init = False
     embedding_init = np.sqrt(3)
 
-    # set to zero with strong supervision to only train gates
-    strong_supervision = False
-    beta = 1
-
     # NOTE not currently used hence non-sensical anneal_threshold
     anneal_threshold = 1000
     anneal_by = 1.5
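Note: `embedding_init = np.sqrt(3)` is the standard bound for unit-variance uniform initialization, since Var(Uniform(-a, a)) = a²/3. A minimal sketch of how such a config value is plausibly consumed when `word2vec_init` is off (sizes and variable names here are illustrative, not taken from the diff):

```python
import numpy as np

embed_size, vocab_size = 80, 40   # illustrative sizes, not from the repo
init = np.sqrt(3)                 # Var(Uniform(-sqrt(3), sqrt(3))) = 1
word_embedding = np.random.uniform(-init, init, (vocab_size, embed_size))
```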
@@ -84,9 +80,9 @@ class DMN_PLUS(object):
     def load_data(self, debug=False):
         """Loads train/valid/test data and sentence encoding"""
         if self.config.train_mode:
-            self.train, self.valid, self.word_embedding, self.max_q_len, self.max_sentences, self.max_sen_len, self.num_supporting_facts, self.vocab_size = babi_input.load_babi(self.config, split_sentences=True)
+            self.train, self.valid, self.word_embedding, self.max_q_len, self.max_sentences, self.max_sen_len, self.vocab_size = babi_input.load_babi(self.config, split_sentences=True)
         else:
-            self.test, self.word_embedding, self.max_q_len, self.max_sentences, self.max_sen_len, self.num_supporting_facts, self.vocab_size = babi_input.load_babi(self.config, split_sentences=True)
+            self.test, self.word_embedding, self.max_q_len, self.max_sentences, self.max_sen_len, self.vocab_size = babi_input.load_babi(self.config, split_sentences=True)
         self.encoding = _position_encoding(self.max_sen_len, self.config.embed_size)
 
     def add_placeholders(self):
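Note: `_position_encoding` (unchanged by this diff) weights each word of a sentence by its position before summing into a sentence vector, following the position encoding of Sukhbaatar et al. (2015) used by DMN+. A sketch of the implementation commonly found in MemN2N/DMN+ ports, assumed rather than copied from this file:

```python
import numpy as np

def _position_encoding(sentence_size, embedding_size):
    # l[j, d] grows linearly with both word position j and embedding
    # dimension d, centered so the mean weight is 1
    encoding = np.ones((embedding_size, sentence_size), dtype=np.float32)
    ls = sentence_size + 1
    le = embedding_size + 1
    for i in range(1, le):
        for j in range(1, ls):
            encoding[i - 1, j - 1] = (i - (le - 1) / 2) * (j - (ls - 1) / 2)
    encoding = 1 + 4 * encoding / embedding_size / sentence_size
    # shape (sentence_size, embedding_size): one weight vector per position
    return np.transpose(encoding)
```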
@@ -99,9 +95,6 @@ def add_placeholders(self):
 
         self.answer_placeholder = tf.placeholder(tf.int64, shape=(self.config.batch_size,))
 
-        # fact corresponding to answer. Useful for strong supervision
-        self.rel_label_placeholder = tf.placeholder(tf.int32, shape=(self.config.batch_size, self.num_supporting_facts))
-
         self.dropout_placeholder = tf.placeholder(tf.float32)
 
     def get_predictions(self, output):
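Note: `get_predictions` appears here only as hunk context. In DMN-style answer modules it is conventionally a softmax followed by an argmax over the answer logits; a sketch under that assumption, written as a standalone function:

```python
import tensorflow as tf

def get_predictions(output):
    # output: (batch_size, vocab_size) answer logits
    preds = tf.nn.softmax(output)   # answer distribution per example
    return tf.argmax(preds, 1)      # most likely answer id per example
```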
@@ -111,14 +104,7 @@ def get_predictions(self, output):
 
     def add_loss_op(self, output):
         """Calculate loss"""
-        # optional strong supervision of attention with supporting facts
-        gate_loss = 0
-        if self.config.strong_supervision:
-            for i, att in enumerate(self.attentions):
-                labels = tf.gather(tf.transpose(self.rel_label_placeholder), 0)
-                gate_loss += tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=att, labels=labels))
-
-        loss = self.config.beta * tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=output, labels=self.answer_placeholder)) + gate_loss
+        loss = tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(logits=output, labels=self.answer_placeholder))
 
         # add l2 regularization for all variables except biases
         for v in tf.trainable_variables():
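Note: with the gate loss gone, the objective reduces to the answer cross entropy summed over the batch. For intuition, a NumPy equivalent of `tf.reduce_sum(tf.nn.sparse_softmax_cross_entropy_with_logits(...))` (a reference sketch, not code from the repo):

```python
import numpy as np

def summed_sparse_xent(logits, labels):
    # logits: (batch, num_classes) floats; labels: (batch,) int class ids
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].sum()
```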
@@ -298,8 +284,8 @@ def run_epoch(self, session, data, num_epoch=0, train_writer=None, train_op=None
 
         # shuffle data
         p = np.random.permutation(len(data[0]))
-        qp, ip, ql, il, im, a, r = data
-        qp, ip, ql, il, im, a, r = qp[p], ip[p], ql[p], il[p], im[p], a[p], r[p]
+        qp, ip, ql, il, im, a = data
+        qp, ip, ql, il, im, a = qp[p], ip[p], ql[p], il[p], im[p], a[p]
 
         for step in range(total_steps):
             index = range(step * config.batch_size, (step + 1) * config.batch_size)
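Note: indexing every array with the same permutation `p` shuffles them in unison, so question, input, and answer rows stay paired. A tiny self-contained illustration (toy data, not from the repo):

```python
import numpy as np

questions = np.array([0, 1, 2, 3])
answers = questions * 10                    # answers[i] belongs to questions[i]
p = np.random.permutation(len(questions))
questions, answers = questions[p], answers[p]
assert np.all(answers == questions * 10)    # pairing survives the shuffle
```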
@@ -308,7 +294,6 @@ def run_epoch(self, session, data, num_epoch=0, train_writer=None, train_op=None
                     self.question_len_placeholder: ql[index],
                     self.input_len_placeholder: il[index],
                     self.answer_placeholder: a[index],
-                    self.rel_label_placeholder: r[index],
                     self.dropout_placeholder: dp}
             loss, pred, summary, _ = session.run(
                 [self.calculate_loss, self.pred, self.merged, train_op], feed_dict=feed)