Skip to content

Commit

Permalink
Forward __getattr__ calls in atari wrappers.
Browse files (browse the repository at this point in the history)
Resolves issue #26.

PiperOrigin-RevId: 318244490
Change-Id: Id2e37e2d2032e738f97e45939cf2d2ed9b73d9b7
  • Loading branch information
aslanides authored and copybara-github committed Jun 25, 2020
1 parent 19a0a46 commit bd963d9
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 16 deletions.
20 changes: 7 additions & 13 deletions acme/wrappers/atari_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from typing import Tuple, List, Optional, Sequence, Union

from acme import types
from acme.wrappers import base
from acme.wrappers import frame_stacking
import dm_env
from dm_env import specs
Expand All @@ -29,7 +29,7 @@
NUM_COLOR_CHANNELS = 3 # Number of color channels in RGB data.


class AtariWrapper(dm_env.Environment):
class AtariWrapper(base.EnvironmentWrapper):
"""Standard "Nature Atari" wrapper for Python environments.
This assumes that the input environment is a dm_env.Environment instance in
Expand Down Expand Up @@ -107,9 +107,9 @@ def __init__(self,
pooled_frames, action_repeats))

if zero_discount_on_life_loss:
self._environment = _ZeroDiscountOnLifeLoss(environment)
super().__init__(_ZeroDiscountOnLifeLoss(environment))
else:
self._environment = environment
super().__init__(environment)

if not max_episode_len:
max_episode_len = np.inf
Expand Down Expand Up @@ -336,7 +336,7 @@ def raw_observation(self) -> np.ndarray:
return self._raw_observation


class _ZeroDiscountOnLifeLoss(dm_env.Environment):
class _ZeroDiscountOnLifeLoss(base.EnvironmentWrapper):
"""Implements soft-termination (zero discount) on life loss."""

def __init__(self, environment: dm_env.Environment):
Expand All @@ -348,7 +348,7 @@ def __init__(self, environment: dm_env.Environment):
Raises:
ValueError: If the environment does not expose a lives observation.
"""
self._env = environment
super().__init__(environment)
self._reset_next_step = True
self._last_num_lives = None

Expand All @@ -362,7 +362,7 @@ def step(self, action: int) -> dm_env.TimeStep:
if self._reset_next_step:
return self.reset()

timestep = self._env.step(action)
timestep = self._environment.step(action)
lives = timestep.observation[LIVES_INDEX]

is_life_loss = True
Expand All @@ -376,9 +376,3 @@ def step(self, action: int) -> dm_env.TimeStep:
if is_life_loss:
return timestep._replace(discount=0.0)
return timestep

def observation_spec(self) -> types.NestedSpec:
return self._env.observation_spec()

def action_spec(self) -> specs.DiscreteArray:
return self._env.action_spec()
3 changes: 3 additions & 0 deletions acme/wrappers/atari_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def test_pong(self):
self.assertEqual(action_spec.num_values, 18)
self.assertEqual(action_spec.dtype, np.dtype('int32'))

# Check that the `render` call gets delegated to the underlying Gym env.
env.render('rgb_array')

# Test step.
timestep = env.reset()
self.assertTrue(timestep.first())
Expand Down
8 changes: 6 additions & 2 deletions acme/wrappers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,14 @@ class EnvironmentWrapper(dm_env.Environment):
wrapped environment (and hence enabling duck-typing).
"""

_environment: dm_env.Environment

def __init__(self, environment: dm_env.Environment):
self._environment = environment

def __getattr__(self, name):
return getattr(self._environment, name)
def __getattr__(self, attr: str):
# Delegates attribute calls to the wrapped environment.
return getattr(self._environment, attr)

@property
def environment(self) -> dm_env.Environment:
Expand Down Expand Up @@ -66,6 +69,7 @@ def wrap_all(
environment: dm_env.Environment,
wrappers: Sequence[Callable[[dm_env.Environment], dm_env.Environment]],
) -> dm_env.Environment:
"""Given an environment, wrap it in a list of wrappers."""
for w in wrappers:
environment = w(environment)

Expand Down
2 changes: 1 addition & 1 deletion acme/wrappers/gym_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def environment(self) -> gym.Env:
"""Returns the wrapped environment."""
return self._environment

def __getattr__(self, name):
def __getattr__(self, name: str):
# Expose any other attributes of the underlying environment.
return getattr(self._environment, name)

Expand Down

0 comments on commit bd963d9

Please sign in to comment.