Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

start value of Discrete spaces #2054

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ New Features:

Bug Fixes:
^^^^^^^^^^
- Added support for non-zero ``start`` values in ``Discrete`` and ``MultiDiscrete`` action spaces

`SB3-Contrib`_
^^^^^^^^^^^^^^
Expand Down Expand Up @@ -1739,4 +1740,4 @@ And all the contributors:
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger
@marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean
@brn-dev @jmacglashan @kplers
@brn-dev @jmacglashan @kplers @JoshuaBluem
31 changes: 1 addition & 30 deletions stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,6 @@ def _is_numpy_array_space(space: spaces.Space) -> bool:
return not isinstance(space, (spaces.Dict, spaces.Tuple))


def _starts_at_zero(space: Union[spaces.Discrete, spaces.MultiDiscrete]) -> bool:
    """
    Tell whether the ``start`` of a (Multi)Discrete space is zero everywhere.

    :param space: Discrete or MultiDiscrete space to inspect
    :return: True when every entry of ``space.start`` is (numerically) zero,
        False as soon as one entry differs from zero
    """
    start = space.start
    # Compare against a zero reference with the same shape/dtype as ``start``
    zero_reference = np.zeros_like(start)
    return np.allclose(start, zero_reference)


def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None:
    """
    Emit a warning when a (Multi)Discrete space does not start at zero.

    Spaces of any other type, and (Multi)Discrete spaces starting at zero,
    are silently accepted.

    :param space: Observation or action space
    :param space_type: whether ``space`` is an "observation" or an "action" space
        (only used to build the warning message)
    :param key: When the space comes from a Dict space, the corresponding key,
        so the warning message can point at it. Defaults to "".
    """
    # Only Discrete and MultiDiscrete spaces carry a ``start`` to validate
    if not isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)):
        return
    if _starts_at_zero(space):
        return

    maybe_key = ""
    if key:
        maybe_key = f"(key='{key}')"

    message = (
        f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) "
        "is not supported by Stable-Baselines3. "
        f"You can use a wrapper or update your {space_type} space."
    )
    warnings.warn(message)


def _check_image_input(observation_space: spaces.Box, key: str = "") -> None:
"""
Check that the input will be compatible with Stable-Baselines
Expand Down Expand Up @@ -84,10 +60,9 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act

if isinstance(observation_space, spaces.Dict):
nested_dict = False
for key, space in observation_space.spaces.items():
for _key, space in observation_space.spaces.items():
if isinstance(space, spaces.Dict):
nested_dict = True
_check_non_zero_start(space, "observation", key)

if nested_dict:
warnings.warn(
Expand Down Expand Up @@ -115,17 +90,13 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
"which is supported by SB3."
)

_check_non_zero_start(observation_space, "observation")

if isinstance(observation_space, spaces.Sequence):
warnings.warn(
"Sequence observation space is not supported by Stable-Baselines3. "
"You can pad your observation to have a fixed size instead.\n"
"Note: The checks for returned values are skipped."
)

_check_non_zero_start(action_space, "action")

if not _is_numpy_array_space(action_space):
warnings.warn(
"The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. "
Expand Down
14 changes: 12 additions & 2 deletions stable_baselines3/common/off_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,9 @@ def _sample_action(
assert self._last_obs is not None, "self._last_obs was not set"
unscaled_action, _ = self.predict(self._last_obs, deterministic=False)

# Rescale the action from [low, high] to [-1, 1]
# Transform action from its action_space bounds
if isinstance(self.action_space, spaces.Box):
# Rescale the action from [low, high] to [-1, 1]
scaled_action = self.policy.scale_action(unscaled_action)

# Add noise to the action (improve exploration)
Expand All @@ -400,10 +401,19 @@ def _sample_action(
# We store the scaled action in the buffer
buffer_action = scaled_action
action = self.policy.unscale_action(scaled_action)
elif isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)):
# Discrete case: Shift action values so every action starts from zero
scaled_action = self.policy.scale_action(unscaled_action)

# Use buffer action to store in buffer
buffer_action = scaled_action

# Still use the original action for env, that is scaled to its action-space bounds
action = unscaled_action
else:
# Remaining action spaces (neither Box nor (Multi)Discrete): no need to normalize or clip
buffer_action = unscaled_action
action = buffer_action

return action, buffer_action

def _dump_logs(self) -> None:
Expand Down
17 changes: 10 additions & 7 deletions stable_baselines3/common/on_policy_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,24 +198,27 @@ def collect_rollouts(

with th.no_grad():
# Convert to pytorch tensor or to TensorDict
obs_tensor = obs_as_tensor(self._last_obs, self.device)
obs_tensor = obs_as_tensor(self._last_obs, self.device) # type: ignore[assignment, arg-type]
actions, values, log_probs = self.policy(obs_tensor)
actions = actions.cpu().numpy()

# Rescale and perform action
clipped_actions = actions

# Rescale to action bounds
unscaled_actions = actions
if isinstance(self.action_space, spaces.Box):
if self.policy.squash_output:
# Unscale the actions to match env bounds
# if they were previously squashed (scaled in [-1, 1])
clipped_actions = self.policy.unscale_action(clipped_actions)
unscaled_actions = self.policy.unscale_action(unscaled_actions)
else:
# Otherwise, clip the actions to avoid out of bound error
# as we are sampling from an unbounded Gaussian distribution
clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)
unscaled_actions = np.clip(actions, self.action_space.low, self.action_space.high)
elif isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)):
# Scale actions to match action-space bounds
unscaled_actions = self.policy.unscale_action(unscaled_actions) # type: ignore[assignment, arg-type]

new_obs, rewards, dones, infos = env.step(clipped_actions)
# Perform step
new_obs, rewards, dones, infos = env.step(unscaled_actions)

self.num_timesteps += env.num_envs

Expand Down
53 changes: 39 additions & 14 deletions stable_baselines3/common/policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,9 @@ def predict(
# Actions could be on arbitrary scale, so clip the actions to avoid
# out of bound error (e.g. if sampling from a Gaussian distribution)
actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type]
elif isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)):
# transform action to its action-space bounds starting from its defined start value
actions = self.unscale_action(actions) # type: ignore[assignment, arg-type]

# Remove batch dimension if needed
if not vectorized_env:
Expand All @@ -387,30 +390,52 @@ def predict(

def scale_action(self, action: np.ndarray) -> np.ndarray:
    """
    Map an action expressed in the action-space bounds to its scaled form.

    Box case:
        Rescale the action from [low, high] to [-1, 1]
        (no need for symmetric action space)
    Discrete / MultiDiscrete case:
        Subtract ``action_space.start`` so that the action values start at zero

    :param action: Action to scale
    :return: Scaled action
    :raises NotImplementedError: if the action space is neither a Box,
        a Discrete nor a MultiDiscrete space
    """
    if isinstance(self.action_space, spaces.Box):
        low, high = self.action_space.low, self.action_space.high
        return 2.0 * ((action - low) / (high - low)) - 1.0

    if isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)):
        # Shift only: the values are moved by ``start``, no rescaling is applied
        return np.subtract(action, self.action_space.start)

    raise NotImplementedError(f"Trying to scale an action using action space: {self.action_space}")

def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray:
    """
    Map a scaled action back to the bounds of the action space.

    Box case:
        Rescale the action from [-1, 1] to [low, high]
        (no need for symmetric action space)
    Discrete / MultiDiscrete case:
        Add ``action_space.start`` back, reversing the shift to zero

    :param scaled_action: Action to un-scale
    :return: Action expressed in the action-space bounds
    :raises NotImplementedError: if the action space is neither a Box,
        a Discrete nor a MultiDiscrete space
    """
    if isinstance(self.action_space, spaces.Box):
        low, high = self.action_space.low, self.action_space.high
        return low + (0.5 * (scaled_action + 1.0) * (high - low))

    if isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)):
        # Shift only: the values are moved by ``start``, no rescaling is applied
        return np.add(scaled_action, self.action_space.start)

    raise NotImplementedError(f"Trying to unscale an action using action space: {self.action_space}")


class ActorCriticPolicy(BasePolicy):
Expand Down
10 changes: 0 additions & 10 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,14 +121,8 @@ def patched_step(_action):
spaces.Dict({"position": spaces.Dict({"abs": spaces.Discrete(5), "rel": spaces.Discrete(2)})}),
# Small image inside a dict
spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}),
# Non zero start index
spaces.Discrete(3, start=-1),
# 2D MultiDiscrete
spaces.MultiDiscrete(np.array([[4, 4], [2, 3]])),
# Non zero start index (MultiDiscrete)
spaces.MultiDiscrete([4, 4], start=[1, 0]),
# Non zero start index inside a Dict
spaces.Dict({"obs": spaces.Discrete(3, start=1)}),
],
)
def test_non_default_spaces(new_obs_space):
Expand Down Expand Up @@ -166,10 +160,6 @@ def patched_step(_action):
spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32),
# Almost good, except for one dim
spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32),
# Non zero start index
spaces.Discrete(3, start=-1),
# Non zero start index (MultiDiscrete)
spaces.MultiDiscrete([4, 4], start=[1, 0]),
],
)
def test_non_default_action_spaces(new_action_space):
Expand Down
72 changes: 72 additions & 0 deletions tests/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,30 @@ def __init__(self, nvec):
)


class ActionBoundsTestClass(DummyEnv):
    """
    Dummy environment whose ``step()`` asserts that the received action lies
    inside the bounds of ``self.action_space``.

    For Discrete and MultiDiscrete spaces the valid values are
    ``start, ..., start + n - 1`` (resp. ``start + nvec - 1``), so the upper
    bound ``start + n`` is exclusive. Box bounds (``low`` / ``high``) are inclusive.
    """

    def step(self, action):
        """
        Check ``action`` against the action-space bounds and return a dummy transition.

        :param action: Action received by the environment
        :return: A sampled observation, a reward of 0.0, two ``False`` flags
            and an empty info dict
        """
        if isinstance(self.action_space, spaces.Discrete):
            assert np.all(
                action >= self.action_space.start
            ), f"Discrete action {action} is below the lower bound {self.action_space.start}"
            # Exclusive upper bound: start + n itself is not a valid action
            assert np.all(
                action < self.action_space.start + self.action_space.n
            ), f"Discrete action {action} must be strictly below {self.action_space.start}+{self.action_space.n}"
        elif isinstance(self.action_space, spaces.MultiDiscrete):
            assert np.all(
                action >= self.action_space.start
            ), f"MultiDiscrete action {action} is below the lower bound {self.action_space.start}"
            # Exclusive upper bound, checked element-wise
            assert np.all(
                action < self.action_space.start + self.action_space.nvec
            ), f"MultiDiscrete action {action} must be strictly below {self.action_space.start}+{self.action_space.nvec}"
        elif isinstance(self.action_space, spaces.Box):
            assert np.all(action >= self.action_space.low), f"Action {action} is below the lower bound {self.action_space.low}"
            assert np.all(
                action <= self.action_space.high
            ), f"Action {action} is above the upper bound {self.action_space.high}"
        return self.observation_space.sample(), 0.0, False, False, {}


@pytest.mark.parametrize(
"env",
[
Expand Down Expand Up @@ -170,6 +194,54 @@ def test_float64_action_space(model_class, obs_space, action_space):
assert action.dtype == env.action_space.dtype


@pytest.mark.parametrize(
    "model_class, action_space",
    [
        # on-policy test
        (PPO, spaces.Discrete(5, start=-6543)),
        (PPO, spaces.MultiDiscrete([4, 3], start=[-6543, 11])),
        (PPO, spaces.Box(low=2344, high=2345, shape=(3,), dtype=np.float32)),
        # off-policy test
        (DQN, spaces.Discrete(2, start=9923)),
        (SAC, spaces.Box(low=-123, high=-122, shape=(1,), dtype=np.float32)),
    ],
)
def test_space_bounds(model_class, action_space):
    """
    Check that actions stay inside the action-space bounds, both during
    ``learn()`` (asserted inside ``ActionBoundsTestClass.step``) and when
    calling ``predict()``.

    Discrete / MultiDiscrete upper bounds are exclusive (``start + n`` and
    ``start + nvec`` are not valid actions); Box bounds are inclusive.
    """
    obs_space = BOX_SPACE_FLOAT32
    env = ActionBoundsTestClass(obs_space, action_space)
    env = gym.wrappers.TimeLimit(env, max_episode_steps=200)
    if isinstance(env.observation_space, spaces.Dict):
        policy = "MultiInputPolicy"
    else:
        policy = "MlpPolicy"

    if model_class in [PPO, A2C]:
        kwargs = dict(n_steps=64, policy_kwargs=dict(net_arch=[12]))
    else:
        kwargs = dict(learning_starts=60, policy_kwargs=dict(net_arch=[12]))

    model = model_class(policy, env, **kwargs)
    # The environment asserts the bounds of every action it receives while learning
    model.learn(64)
    initial_obs, _ = env.reset()

    action, _ = model.predict(initial_obs, deterministic=False)
    if isinstance(action_space, spaces.Discrete):
        assert np.all(action >= action_space.start), f"Discrete action {action} is below the lower bound {action_space.start}"
        # Exclusive upper bound: start + n itself is not a valid action
        assert np.all(
            action < action_space.start + action_space.n
        ), f"Discrete action {action} must be strictly below {action_space.start}+{action_space.n}"
    elif isinstance(action_space, spaces.MultiDiscrete):
        assert np.all(
            action >= action_space.start
        ), f"MultiDiscrete action {action} is below the lower bound {action_space.start}"
        # Exclusive upper bound, checked element-wise
        assert np.all(
            action < action_space.start + action_space.nvec
        ), f"MultiDiscrete action {action} must be strictly below {action_space.start}+{action_space.nvec}"
    elif isinstance(action_space, spaces.Box):
        assert np.all(action >= action_space.low), f"Action {action} is below the lower bound {action_space.low}"
        assert np.all(action <= action_space.high), f"Action {action} is above the upper bound {action_space.high}"


def test_multidim_binary_not_supported():
env = DummyEnv(BOX_SPACE_FLOAT32, spaces.MultiBinary([2, 3]))
with pytest.raises(AssertionError, match=r"Multi-dimensional MultiBinary\(.*\) action space is not supported"):
Expand Down
Loading