Skip to content

Commit 1d836fe

Browse files
author
u2205807031
committed
fix latent machine memory explosion problem
1 parent 9982c96 commit 1d836fe

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

dmn_plus.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,7 @@ def run_epoch(self, session, data, num_epoch=0, train_writer=None, train_op=None
276276
config = self.config
277277
dp = config.dropout
278278
if train_op is None:
279-
train_op = tf.no_op()
279+
# train_op = tf.no_op()
280280
dp = 1
281281
total_steps = len(data[0]) // config.batch_size
282282
total_loss = []
@@ -295,8 +295,13 @@ def run_epoch(self, session, data, num_epoch=0, train_writer=None, train_op=None
295295
self.input_len_placeholder: il[index],
296296
self.answer_placeholder: a[index],
297297
self.dropout_placeholder: dp}
298-
loss, pred, summary, _ = session.run(
299-
[self.calculate_loss, self.pred, self.merged, train_op], feed_dict=feed)
298+
299+
if train_op is None:
300+
loss, pred, summary, = session.run(
301+
[self.calculate_loss, self.pred, self.merged], feed_dict=feed)
302+
else:
303+
loss, pred, summary, _ = session.run(
304+
[self.calculate_loss, self.pred, self.merged, train_op], feed_dict=feed)
300305

301306
if train_writer is not None:
302307
train_writer.add_summary(summary, num_epoch*total_steps + step)

0 commit comments

Comments
 (0)