Skip to content

Commit

Permalink
Call total_episode_reward_logger before incrementing num_timesteps
Browse files Browse the repository at this point in the history
Update changelog
  • Loading branch information
balintkozma committed Nov 14, 2019
1 parent a1ab7a1 commit b053747
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Bug Fixes:
- Fix seeding, so it is now possible to have deterministic results on cpu
- Fix a bug in DDPG where `predict` method with `deterministic=False` would fail
- Fix a bug in TRPO: mean_losses was not initialized causing the logger to crash when there was no gradients (@MarvineGothic)
- Fix a bug in PPO2: ``total_episode_reward_logger`` is now called before incrementing ``num_timesteps``, so episode rewards are logged at the correct timestep

Deprecations:
^^^^^^^^^^^^^
Expand Down
11 changes: 5 additions & 6 deletions stable_baselines/ppo2/ppo2.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,11 @@ def learn(self, total_timesteps, callback=None, log_interval=1, tb_log_name="PPO
cliprange_vf_now = cliprange_vf(frac)
# true_reward is the reward without discount
obs, returns, masks, actions, values, neglogpacs, states, ep_infos, true_reward = runner.run()
if writer is not None:
self.episode_reward = total_episode_reward_logger(self.episode_reward,
true_reward.reshape((self.n_envs, self.n_steps)),
masks.reshape((self.n_envs, self.n_steps)),
writer, self.num_timesteps)
self.num_timesteps += self.n_batch
ep_info_buf.extend(ep_infos)
mb_loss_vals = []
Expand Down Expand Up @@ -373,12 +378,6 @@ def learn(self, total_timesteps, callback=None, log_interval=1, tb_log_name="PPO
t_now = time.time()
fps = int(self.n_batch / (t_now - t_start))

if writer is not None:
self.episode_reward = total_episode_reward_logger(self.episode_reward,
true_reward.reshape((self.n_envs, self.n_steps)),
masks.reshape((self.n_envs, self.n_steps)),
writer, self.num_timesteps)

if self.verbose >= 1 and (update % log_interval == 0 or update == 1):
explained_var = explained_variance(values, returns)
logger.logkv("serial_timesteps", update * self.n_steps)
Expand Down

0 comments on commit b053747

Please sign in to comment.