Skip to content

Commit

Permalink
entrypoint variable made public (#970) and Fix RuntimeError (#910) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
solaris33 authored and pzhokhov committed Nov 8, 2019
1 parent 517433f commit b99a73a
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions baselines/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
_game_envs = defaultdict(set)
for env in gym.envs.registry.all():
# TODO: solve this with regexes
env_type = env._entry_point.split(':')[0].split('.')[-1]
env_type = env.entry_point.split(':')[0].split('.')[-1]
_game_envs[env_type].add(env.id)

# reading benchmark names directly from retro requires
Expand Down Expand Up @@ -119,7 +119,7 @@ def get_env_type(args):

# Re-parse the gym registry, since we could have new envs since last time.
for env in gym.envs.registry.all():
env_type = env._entry_point.split(':')[0].split('.')[-1]
env_type = env.entry_point.split(':')[0].split('.')[-1]
_game_envs[env_type].add(env.id) # This is a set so add is idempotent

if env_id in _game_envs.keys():
Expand Down Expand Up @@ -222,7 +222,7 @@ def main(args):

state = model.initial_state if hasattr(model, 'initial_state') else None

episode_rew = 0
episode_rew = np.zeros(env.num_envs) if isinstance(env, VecEnv) else np.zeros(1)
while True:
if state is not None:
actions, _, state, _ = model.step(obs)
Expand All @@ -232,13 +232,13 @@ def main(args):
obs, rew, done, _ = env.step(actions.numpy())
if not isinstance(env, VecEnv):
obs = np.expand_dims(np.array(obs), axis=0)
episode_rew += rew[0] if isinstance(env, VecEnv) else rew
episode_rew += rew
env.render()
done = done.any() if isinstance(done, np.ndarray) else done
if done:
print('episode_rew={}'.format(episode_rew))
episode_rew = 0
obs = env.reset()
done_any = done.any() if isinstance(done, np.ndarray) else done
if done_any:
for i in np.nonzero(done)[0]:
print('episode_rew={}'.format(episode_rew[i]))
episode_rew[i] = 0

env.close()

Expand Down

0 comments on commit b99a73a

Please sign in to comment.