Skip to content

Commit

Permalink
Merge branch 'master' into DEV_model_based
Browse files Browse the repository at this point in the history
revert #89
  • Loading branch information
Kismuz committed Jan 10, 2019
2 parents 92ce587 + 91db6e2 commit 0207d50
Showing 1 changed file with 6 additions and 15 deletions.
21 changes: 6 additions & 15 deletions btgym/algorithms/aac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,7 +1241,7 @@ def process_data(self, sess, data, is_train, pi, pi_prime=None):

return self._get_main_feeder(sess, on_policy_batch, off_policy_batch, rp_batch, is_train, pi, pi_prime)

def process_summary(self, sess, data, model_data=None, step=None, episode=None, run_metadata=None):
def process_summary(self, sess, data, model_data=None, step=None, episode=None):
"""
Fetches and writes summary data from `data` and `model_data`.
Args:
Expand All @@ -1250,7 +1250,6 @@ def process_summary(self, sess, data, model_data=None, step=None, episode=None,
model_data(dict): model summary data
step: int, global step or None
episode: int, global episode number or None
run_metadata(dict): model run statistics
"""
if step is None:
step = sess.run(self.global_step)
Expand Down Expand Up @@ -1321,8 +1320,7 @@ def process_summary(self, sess, data, model_data=None, step=None, episode=None,
self.summary_writer.flush()

# Every worker writes train episode summaries:
if model_data is not None and run_metadata is not None:
self.summary_writer.add_run_metadata(run_metadata, 'step%d' % step, global_step=step)
if model_data is not None:
self.summary_writer.add_summary(tf.Summary.FromString(model_data), step)
self.summary_writer.flush()

Expand Down Expand Up @@ -1383,13 +1381,13 @@ def _process(self, sess):
feed_dict = self.process_data(sess, data, is_train, self.local_network, self.local_network_prime)

# Say `No` to redundant summaries:
write_model_summary =\
wirte_model_summary =\
self.local_steps % self.model_summary_freq == 0

#fetches = [self.train_op, self.local_network.debug] # include policy debug shapes
fetches = [self.train_op]

if write_model_summary:
if wirte_model_summary:
fetches_last = fetches + [self.model_summary_op, self.inc_step]
else:
fetches_last = fetches + [self.inc_step]
Expand All @@ -1398,17 +1396,10 @@ def _process(self, sess):
# When doing more than one epoch, we actually use only last summary:
for i in range(self.num_epochs - 1):
fetched = sess.run(fetches, feed_dict=feed_dict)

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()

fetched = sess.run(fetches_last,
feed_dict=feed_dict,
options=run_options,
run_metadata=run_metadata
)
fetched = sess.run(fetches_last, feed_dict=feed_dict)

if write_model_summary:
if wirte_model_summary:
model_summary = fetched[-2]

else:
Expand Down

0 comments on commit 0207d50

Please sign in to comment.