From c0aa574b6cdcdb012988973c45674020a214285e Mon Sep 17 00:00:00 2001 From: JoshuaBluem Date: Fri, 6 Dec 2024 04:09:43 +0100 Subject: [PATCH 1/5] Fixed start value of Discrete Action space The start value of Discrete Action-Space had no effect previously and it is now supported --- .../common/off_policy_algorithm.py | 14 ++++- .../common/on_policy_algorithm.py | 17 +++--- stable_baselines3/common/policies.py | 53 ++++++++++++++----- 3 files changed, 61 insertions(+), 23 deletions(-) diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 6a043e7ac..501625eec 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -389,8 +389,9 @@ def _sample_action( assert self._last_obs is not None, "self._last_obs was not set" unscaled_action, _ = self.predict(self._last_obs, deterministic=False) - # Rescale the action from [low, high] to [-1, 1] + # Transform action from its action_space bounds if isinstance(self.action_space, spaces.Box): + # Rescale the action from [low, high] to [-1, 1] scaled_action = self.policy.scale_action(unscaled_action) # Add noise to the action (improve exploration) @@ -400,10 +401,19 @@ def _sample_action( # We store the scaled action in the buffer buffer_action = scaled_action action = self.policy.unscale_action(scaled_action) + elif isinstance(self.action_space, spaces.Discrete): + # Discrete case: Shift action values so every action starts from zero + scaled_action = self.policy.scale_action(unscaled_action) + + # Use buffer action to store in buffer + buffer_action = scaled_action + + # Still use the original action for env, that is scaled to its action-space bounds + action = unscaled_action else: - # Discrete case, no need to normalize or clip buffer_action = unscaled_action action = buffer_action + return action, buffer_action def _dump_logs(self) -> None: diff --git a/stable_baselines3/common/on_policy_algorithm.py 
b/stable_baselines3/common/on_policy_algorithm.py index ac4c0970c..0e1ec1cc2 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -198,24 +198,27 @@ def collect_rollouts( with th.no_grad(): # Convert to pytorch tensor or to TensorDict - obs_tensor = obs_as_tensor(self._last_obs, self.device) + obs_tensor = obs_as_tensor(self._last_obs, self.device) # type: ignore[assignment, arg-type] actions, values, log_probs = self.policy(obs_tensor) actions = actions.cpu().numpy() - # Rescale and perform action - clipped_actions = actions - + # Rescale to action bounds + unscaled_actions = actions if isinstance(self.action_space, spaces.Box): if self.policy.squash_output: # Unscale the actions to match env bounds # if they were previously squashed (scaled in [-1, 1]) - clipped_actions = self.policy.unscale_action(clipped_actions) + unscaled_actions = self.policy.unscale_action(unscaled_actions) else: # Otherwise, clip the actions to avoid out of bound error # as we are sampling from an unbounded Gaussian distribution - clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high) + unscaled_actions = np.clip(actions, self.action_space.low, self.action_space.high) + elif isinstance(self.action_space, spaces.Discrete): + # Scale actions to match action-space bounds + unscaled_actions = self.policy.unscale_action(unscaled_actions) # type: ignore[assignment, arg-type] - new_obs, rewards, dones, infos = env.step(clipped_actions) + # Perform step + new_obs, rewards, dones, infos = env.step(unscaled_actions) self.num_timesteps += env.num_envs diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index e20256f0c..9a42d2371 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -377,6 +377,9 @@ def predict( # Actions could be on arbitrary scale, so clip the actions to avoid # out of bound error (e.g. 
if sampling from a Gaussian distribution) actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type] + elif isinstance(self.action_space, spaces.Discrete): + # transform action to its action-space bounds starting from its defined start value + actions = self.unscale_action(actions) # type: ignore[assignment, arg-type] # Remove batch dimension if needed if not vectorized_env: @@ -387,30 +390,52 @@ def predict( def scale_action(self, action: np.ndarray) -> np.ndarray: """ - Rescale the action from [low, high] to [-1, 1] - (no need for symmetric action space) + Rescale the action from action-space bounds + Box-Case: + Scale the action from [low, high] to [-1, 1] + (no need for symmetric action space) + Discrete-Case: + Shift start value from action_space.start to zero :param action: Action to scale :return: Scaled action """ - assert isinstance( - self.action_space, spaces.Box - ), f"Trying to scale an action using an action space that is not a Box(): {self.action_space}" - low, high = self.action_space.low, self.action_space.high - return 2.0 * ((action - low) / (high - low)) - 1.0 + scaled_action: np.ndarray + + if isinstance(self.action_space, spaces.Box): + # Box case + low, high = self.action_space.low, self.action_space.high + scaled_action = 2.0 * ((action - low) / (high - low)) - 1.0 + elif isinstance(self.action_space, spaces.Discrete): + # discrete actions case + scaled_action = np.subtract(action, self.action_space.start) + else: + raise AssertionError(f"Trying to scale an action using an action space that is not a Box() or Discrete(): {self.action_space}") + + return scaled_action def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray: """ - Rescale the action from [-1, 1] to [low, high] - (no need for symmetric action space) + Box-Case: + Rescale the action from [-1, 1] to [low, high] + (no need for symmetric action space) + Discrete-Case: + Reverse shift start value from zero back to 
action_space.start :param scaled_action: Action to un-scale """ - assert isinstance( - self.action_space, spaces.Box - ), f"Trying to unscale an action using an action space that is not a Box(): {self.action_space}" - low, high = self.action_space.low, self.action_space.high - return low + (0.5 * (scaled_action + 1.0) * (high - low)) + unscaled_action: np.ndarray + + if isinstance(self.action_space, spaces.Box): + low, high = self.action_space.low, self.action_space.high + unscaled_action = low + (0.5 * (scaled_action + 1.0) * (high - low)) + elif isinstance(self.action_space, spaces.Discrete): + # match discrete actions bounds + unscaled_action = np.add(scaled_action, self.action_space.start) + else: + raise AssertionError(f"Trying to unscale an action using an action space that is not a Box() or Discrete(): {self.action_space}") + + return unscaled_action class ActorCriticPolicy(BasePolicy): From d01691e8b2e2087797039379995dd8eda7fb52a0 Mon Sep 17 00:00:00 2001 From: JoshuaB <72627997+JoshuaBluem@users.noreply.github.com> Date: Fri, 6 Dec 2024 05:56:13 +0100 Subject: [PATCH 2/5] Added test cases for start value Added pytest test cases that confirm the actions passed to the env's step method and returned by the predict method stay within the action-space bounds --- tests/test_spaces.py | 47 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index cd38e1ecd..6ceb88db9 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -53,6 +53,15 @@ def __init__(self, nvec): BOX_SPACE_FLOAT32, ) +class ActionBoundsTestClass(DummyEnv): + def step(self, action): + if isinstance(self.action_space, spaces.Discrete): + assert np.all(action >= self.action_space.start), f"Discrete action {action} is below the lower bound {self.action_space.start}" + assert np.all(action <= self.action_space.start + self.action_space.n), f"Discrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.n}" + elif
isinstance(self.action_space, spaces.Box): + assert np.all(action >= self.action_space.low), f"Action {action} is below the lower bound {self.action_space.low}" + assert np.all(action <= self.action_space.high), f"Action {action} is above the upper bound {self.action_space.high}" + return self.observation_space.sample(), 0.0, False, False, {} @pytest.mark.parametrize( "env", @@ -170,6 +179,44 @@ def test_float64_action_space(model_class, obs_space, action_space): assert action.dtype == env.action_space.dtype +@pytest.mark.parametrize( + "model_class, action_space", + [ + # on-policy test + (PPO, spaces.Discrete(5, start=-6543)), + (PPO, spaces.Box(low=2344, high=2345, shape=(3,), dtype=np.float32)), + # off-policy test + (DQN, spaces.Discrete(2, start=9923)), + (SAC, spaces.Box(low=-123, high=-122, shape=(1,), dtype=np.float32)), + ], + ) +def test_space_bounds(model_class, action_space): + obs_space = BOX_SPACE_FLOAT32 + env = ActionBoundsTestClass(obs_space, action_space) + env = gym.wrappers.TimeLimit(env, max_episode_steps=200) + if isinstance(env.observation_space, spaces.Dict): + policy = "MultiInputPolicy" + else: + policy = "MlpPolicy" + + if model_class in [PPO, A2C]: + kwargs = dict(n_steps=64, policy_kwargs=dict(net_arch=[12])) + else: + kwargs = dict(learning_starts=60, policy_kwargs=dict(net_arch=[12])) + + model = model_class(policy, env, **kwargs) + model.learn(64) + initial_obs, _ = env.reset() + + action, _ = model.predict(initial_obs, deterministic=False) + if isinstance(action_space, spaces.Discrete): + assert np.all(action >= action_space.start), f"Discrete action {action} is below the lower bound {action_space.start}" + assert np.all(action <= action_space.start + action_space.n), f"Discrete action {action} is above the upper bound {action_space.start}+{action_space.n}" + elif isinstance(action_space, spaces.Box): + assert np.all(action >= action_space.low), f"Action {action} is below the lower bound {action_space.low}" + assert np.all(action 
<= action_space.high), f"Action {action} is above the upper bound {action_space.high}" + + def test_multidim_binary_not_supported(): env = DummyEnv(BOX_SPACE_FLOAT32, spaces.MultiBinary([2, 3])) with pytest.raises(AssertionError, match=r"Multi-dimensional MultiBinary\(.*\) action space is not supported"): From 9851d0ec28871d2c8dd99297555f170ac413e72d Mon Sep 17 00:00:00 2001 From: JoshuaBluem Date: Fri, 6 Dec 2024 07:23:43 +0100 Subject: [PATCH 3/5] MultiDiscrete start value + info The start value of MultiDiscrete action spaces is now supported as well. Also updated the changelog, added one test case, and removed the env_checker warning about using a non-zero start value --- docs/misc/changelog.rst | 1 + stable_baselines3/common/env_checker.py | 30 ---------------- .../common/off_policy_algorithm.py | 2 +- .../common/on_policy_algorithm.py | 2 +- stable_baselines3/common/policies.py | 6 ++-- tests/test_spaces.py | 35 ++++++++++++++----- 6 files changed, 33 insertions(+), 43 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 1ccfbb5a1..0fab8b56c 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -18,6 +18,7 @@ New Features: Bug Fixes: ^^^^^^^^^^ +- Support for start value in discrete action spaces `SB3-Contrib`_ ^^^^^^^^^^^^^^ diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 0310bcfe7..75b79cef8 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -16,31 +16,6 @@ def _is_numpy_array_space(space: spaces.Space) -> bool: """ return not isinstance(space, (spaces.Dict, spaces.Tuple)) - -def _starts_at_zero(space: Union[spaces.Discrete, spaces.MultiDiscrete]) -> bool: - """ - Return False if a (Multi)Discrete space has a non-zero start.
- """ - return np.allclose(space.start, np.zeros_like(space.start)) - - -def _check_non_zero_start(space: spaces.Space, space_type: str = "observation", key: str = "") -> None: - """ - :param space: Observation or action space - :param space_type: information about whether it is an observation or action space - (for the warning message) - :param key: When the observation space comes from a Dict space, we pass the - corresponding key to have more precise warning messages. Defaults to "". - """ - if isinstance(space, (spaces.Discrete, spaces.MultiDiscrete)) and not _starts_at_zero(space): - maybe_key = f"(key='{key}')" if key else "" - warnings.warn( - f"{type(space).__name__} {space_type} space {maybe_key} with a non-zero start (start={space.start}) " - "is not supported by Stable-Baselines3. " - f"You can use a wrapper or update your {space_type} space." - ) - - def _check_image_input(observation_space: spaces.Box, key: str = "") -> None: """ Check that the input will be compatible with Stable-Baselines @@ -87,7 +62,6 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act for key, space in observation_space.spaces.items(): if isinstance(space, spaces.Dict): nested_dict = True - _check_non_zero_start(space, "observation", key) if nested_dict: warnings.warn( @@ -115,8 +89,6 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "which is supported by SB3." ) - _check_non_zero_start(observation_space, "observation") - if isinstance(observation_space, spaces.Sequence): warnings.warn( "Sequence observation space is not supported by Stable-Baselines3. " @@ -124,8 +96,6 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "Note: The checks for returned values are skipped." ) - _check_non_zero_start(action_space, "action") - if not _is_numpy_array_space(action_space): warnings.warn( "The action space is not based off a numpy array. 
Typically this means it's either a Dict or Tuple space. " diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index 501625eec..5b2dc8ef6 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -401,7 +401,7 @@ def _sample_action( # We store the scaled action in the buffer buffer_action = scaled_action action = self.policy.unscale_action(scaled_action) - elif isinstance(self.action_space, spaces.Discrete): + elif isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)): # Discrete case: Shift action values so every action starts from zero scaled_action = self.policy.scale_action(unscaled_action) diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 0e1ec1cc2..cdf69a45f 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -213,7 +213,7 @@ def collect_rollouts( # Otherwise, clip the actions to avoid out of bound error # as we are sampling from an unbounded Gaussian distribution unscaled_actions = np.clip(actions, self.action_space.low, self.action_space.high) - elif isinstance(self.action_space, spaces.Discrete): + elif isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)): # Scale actions to match action-space bounds unscaled_actions = self.policy.unscale_action(unscaled_actions) # type: ignore[assignment, arg-type] diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 9a42d2371..8dfce3ec0 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -377,7 +377,7 @@ def predict( # Actions could be on arbitrary scale, so clip the actions to avoid # out of bound error (e.g. 
if sampling from a Gaussian distribution) actions = np.clip(actions, self.action_space.low, self.action_space.high) # type: ignore[assignment, arg-type] - elif isinstance(self.action_space, spaces.Discrete): + elif isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)): # transform action to its action-space bounds starting from its defined start value actions = self.unscale_action(actions) # type: ignore[assignment, arg-type] @@ -406,7 +406,7 @@ def scale_action(self, action: np.ndarray) -> np.ndarray: # Box case low, high = self.action_space.low, self.action_space.high scaled_action = 2.0 * ((action - low) / (high - low)) - 1.0 - elif isinstance(self.action_space, spaces.Discrete): + elif isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)): # discrete actions case scaled_action = np.subtract(action, self.action_space.start) else: @@ -429,7 +429,7 @@ def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray: if isinstance(self.action_space, spaces.Box): low, high = self.action_space.low, self.action_space.high unscaled_action = low + (0.5 * (scaled_action + 1.0) * (high - low)) - elif isinstance(self.action_space, spaces.Discrete): + elif isinstance(self.action_space, (spaces.Discrete, spaces.MultiDiscrete)): # match discrete actions bounds unscaled_action = np.add(scaled_action, self.action_space.start) else: diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 6ceb88db9..1e3b2dcb4 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -56,11 +56,20 @@ def __init__(self, nvec): class ActionBoundsTestClass(DummyEnv): def step(self, action): if isinstance(self.action_space, spaces.Discrete): - assert np.all(action >= self.action_space.start), f"Discrete action {action} is below the lower bound {self.action_space.start}" - assert np.all(action <= self.action_space.start + self.action_space.n), f"Discrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.n}" + assert 
np.all(action >= self.action_space.start), ( + f"Discrete action {action} is below the lower bound {self.action_space.start}") + assert np.all(action <= self.action_space.start + self.action_space.n), ( + f"Discrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.n}") + elif isinstance(self.action_space, spaces.MultiDiscrete): + assert np.all(action >= self.action_space.start), ( + f"MultiDiscrete action {action} is below the lower bound {self.action_space.start}") + assert np.all(action <= self.action_space.start + self.action_space.nvec), ( + f"MultiDiscrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.nvec}") elif isinstance(self.action_space, spaces.Box): - assert np.all(action >= self.action_space.low), f"Action {action} is below the lower bound {self.action_space.low}" - assert np.all(action <= self.action_space.high), f"Action {action} is above the upper bound {self.action_space.high}" + assert np.all(action >= self.action_space.low), ( + f"Action {action} is below the lower bound {self.action_space.low}") + assert np.all(action <= self.action_space.high), ( + f"Action {action} is above the upper bound {self.action_space.high}") return self.observation_space.sample(), 0.0, False, False, {} @pytest.mark.parametrize( @@ -184,6 +193,7 @@ def test_float64_action_space(model_class, obs_space, action_space): [ # on-policy test (PPO, spaces.Discrete(5, start=-6543)), + (PPO, spaces.MultiDiscrete([4, 3], start=[-6543, 11])), (PPO, spaces.Box(low=2344, high=2345, shape=(3,), dtype=np.float32)), # off-policy test (DQN, spaces.Discrete(2, start=9923)), @@ -210,11 +220,20 @@ def test_space_bounds(model_class, action_space): action, _ = model.predict(initial_obs, deterministic=False) if isinstance(action_space, spaces.Discrete): - assert np.all(action >= action_space.start), f"Discrete action {action} is below the lower bound {action_space.start}" - assert np.all(action <= action_space.start + 
action_space.n), f"Discrete action {action} is above the upper bound {action_space.start}+{action_space.n}" + assert np.all(action >= action_space.start), ( + f"Discrete action {action} is below the lower bound {action_space.start}") + assert np.all(action <= action_space.start + action_space.n), ( + f"Discrete action {action} is above the upper bound {action_space.start}+{action_space.n}") + elif isinstance(action_space, spaces.MultiDiscrete): + assert np.all(action >= action_space.start), ( + f"MultiDiscrete action {action} is below the lower bound {action_space.start}") + assert np.all(action <= action_space.start + action_space.nvec), ( + f"MultiDiscrete action {action} is above the upper bound {action_space.start}+{action_space.nvec}") elif isinstance(action_space, spaces.Box): - assert np.all(action >= action_space.low), f"Action {action} is below the lower bound {action_space.low}" - assert np.all(action <= action_space.high), f"Action {action} is above the upper bound {action_space.high}" + assert np.all(action >= action_space.low), ( + f"Action {action} is below the lower bound {action_space.low}") + assert np.all(action <= action_space.high), ( + f"Action {action} is above the upper bound {action_space.high}") def test_multidim_binary_not_supported(): From 7eb6632e83cd3921df7927219845cf140f07622f Mon Sep 17 00:00:00 2001 From: JoshuaBluem <72627997+JoshuaBluem@users.noreply.github.com> Date: Fri, 6 Dec 2024 07:41:38 +0100 Subject: [PATCH 4/5] Updated Exception Message Reduced long message line --- stable_baselines3/common/policies.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 8dfce3ec0..ffeb5ec28 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -410,7 +410,7 @@ def scale_action(self, action: np.ndarray) -> np.ndarray: # discrete actions case scaled_action = np.subtract(action, self.action_space.start) 
else: - raise AssertionError(f"Trying to scale an action using an action space that is not a Box() or Discrete(): {self.action_space}") + raise NotImplementedError(f"Trying to scale an action using action space: {self.action_space}") return scaled_action @@ -433,7 +433,7 @@ def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray: # match discrete actions bounds unscaled_action = np.add(scaled_action, self.action_space.start) else: - raise AssertionError(f"Trying to unscale an action using an action space that is not a Box() or Discrete(): {self.action_space}") + raise NotImplementedError(f"Trying to unscale an action using action space: {self.action_space}") return unscaled_action From e93a43cbeb763ce93cc654002b53aadf097dc9bc Mon Sep 17 00:00:00 2001 From: JoshuaB <72627997+JoshuaBluem@users.noreply.github.com> Date: Fri, 6 Dec 2024 09:35:19 +0100 Subject: [PATCH 5/5] Formalities reformatting --- docs/misc/changelog.rst | 2 +- stable_baselines3/common/env_checker.py | 3 +- tests/test_envs.py | 10 ---- tests/test_spaces.py | 76 +++++++++++++------------ 4 files changed, 44 insertions(+), 47 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 0fab8b56c..2af50ab6d 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -1740,4 +1740,4 @@ And all the contributors: @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto @lutogniew @lbergmann1 @lukashass @BertrandDecoster @pseudo-rnd-thoughts @stefanbschneider @kyle-he @PatrickHelm @corentinlger @marekm4 @stagoverflow @rushitnshah @markscsmith @NickLucche @cschindlbeck @peteole @jak3122 @will-maclean -@brn-dev @jmacglashan @kplers +@brn-dev @jmacglashan @kplers @JoshuaBluem diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 75b79cef8..a43fced0b 100644 --- a/stable_baselines3/common/env_checker.py +++ 
b/stable_baselines3/common/env_checker.py @@ -16,6 +16,7 @@ def _is_numpy_array_space(space: spaces.Space) -> bool: """ return not isinstance(space, (spaces.Dict, spaces.Tuple)) + def _check_image_input(observation_space: spaces.Box, key: str = "") -> None: """ Check that the input will be compatible with Stable-Baselines @@ -59,7 +60,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act if isinstance(observation_space, spaces.Dict): nested_dict = False - for key, space in observation_space.spaces.items(): + for _key, space in observation_space.spaces.items(): if isinstance(space, spaces.Dict): nested_dict = True diff --git a/tests/test_envs.py b/tests/test_envs.py index 2fbce120c..293a188fb 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -121,14 +121,8 @@ def patched_step(_action): spaces.Dict({"position": spaces.Dict({"abs": spaces.Discrete(5), "rel": spaces.Discrete(2)})}), # Small image inside a dict spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}), - # Non zero start index - spaces.Discrete(3, start=-1), # 2D MultiDiscrete spaces.MultiDiscrete(np.array([[4, 4], [2, 3]])), - # Non zero start index (MultiDiscrete) - spaces.MultiDiscrete([4, 4], start=[1, 0]), - # Non zero start index inside a Dict - spaces.Dict({"obs": spaces.Discrete(3, start=1)}), ], ) def test_non_default_spaces(new_obs_space): @@ -166,10 +160,6 @@ def patched_step(_action): spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32), # Almost good, except for one dim spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32), - # Non zero start index - spaces.Discrete(3, start=-1), - # Non zero start index (MultiDiscrete) - spaces.MultiDiscrete([4, 4], start=[1, 0]), ], ) def test_non_default_action_spaces(new_action_space): diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 1e3b2dcb4..59b546da6 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -53,25 +53,31 
@@ def __init__(self, nvec): BOX_SPACE_FLOAT32, ) + class ActionBoundsTestClass(DummyEnv): def step(self, action): if isinstance(self.action_space, spaces.Discrete): - assert np.all(action >= self.action_space.start), ( - f"Discrete action {action} is below the lower bound {self.action_space.start}") - assert np.all(action <= self.action_space.start + self.action_space.n), ( - f"Discrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.n}") + assert np.all( + action >= self.action_space.start + ), f"Discrete action {action} is below the lower bound {self.action_space.start}" + assert np.all( + action <= self.action_space.start + self.action_space.n + ), f"Discrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.n}" elif isinstance(self.action_space, spaces.MultiDiscrete): - assert np.all(action >= self.action_space.start), ( - f"MultiDiscrete action {action} is below the lower bound {self.action_space.start}") - assert np.all(action <= self.action_space.start + self.action_space.nvec), ( - f"MultiDiscrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.nvec}") + assert np.all( + action >= self.action_space.start + ), f"MultiDiscrete action {action} is below the lower bound {self.action_space.start}" + assert np.all( + action <= self.action_space.start + self.action_space.nvec + ), f"MultiDiscrete action {action} is above the upper bound {self.action_space.start}+{self.action_space.nvec}" elif isinstance(self.action_space, spaces.Box): - assert np.all(action >= self.action_space.low), ( - f"Action {action} is below the lower bound {self.action_space.low}") - assert np.all(action <= self.action_space.high), ( - f"Action {action} is above the upper bound {self.action_space.high}") + assert np.all(action >= self.action_space.low), f"Action {action} is below the lower bound {self.action_space.low}" + assert np.all( + action <= self.action_space.high + ), 
f"Action {action} is above the upper bound {self.action_space.high}" return self.observation_space.sample(), 0.0, False, False, {} + @pytest.mark.parametrize( "env", [ @@ -189,17 +195,17 @@ def test_float64_action_space(model_class, obs_space, action_space): @pytest.mark.parametrize( - "model_class, action_space", - [ - # on-policy test - (PPO, spaces.Discrete(5, start=-6543)), - (PPO, spaces.MultiDiscrete([4, 3], start=[-6543, 11])), - (PPO, spaces.Box(low=2344, high=2345, shape=(3,), dtype=np.float32)), - # off-policy test - (DQN, spaces.Discrete(2, start=9923)), - (SAC, spaces.Box(low=-123, high=-122, shape=(1,), dtype=np.float32)), - ], - ) + "model_class, action_space", + [ + # on-policy test + (PPO, spaces.Discrete(5, start=-6543)), + (PPO, spaces.MultiDiscrete([4, 3], start=[-6543, 11])), + (PPO, spaces.Box(low=2344, high=2345, shape=(3,), dtype=np.float32)), + # off-policy test + (DQN, spaces.Discrete(2, start=9923)), + (SAC, spaces.Box(low=-123, high=-122, shape=(1,), dtype=np.float32)), + ], +) def test_space_bounds(model_class, action_space): obs_space = BOX_SPACE_FLOAT32 env = ActionBoundsTestClass(obs_space, action_space) @@ -220,20 +226,20 @@ def test_space_bounds(model_class, action_space): action, _ = model.predict(initial_obs, deterministic=False) if isinstance(action_space, spaces.Discrete): - assert np.all(action >= action_space.start), ( - f"Discrete action {action} is below the lower bound {action_space.start}") - assert np.all(action <= action_space.start + action_space.n), ( - f"Discrete action {action} is above the upper bound {action_space.start}+{action_space.n}") + assert np.all(action >= action_space.start), f"Discrete action {action} is below the lower bound {action_space.start}" + assert np.all( + action <= action_space.start + action_space.n + ), f"Discrete action {action} is above the upper bound {action_space.start}+{action_space.n}" elif isinstance(action_space, spaces.MultiDiscrete): - assert np.all(action >= action_space.start), 
( - f"MultiDiscrete action {action} is below the lower bound {action_space.start}") - assert np.all(action <= action_space.start + action_space.nvec), ( - f"MultiDiscrete action {action} is above the upper bound {action_space.start}+{action_space.nvec}") + assert np.all( + action >= action_space.start + ), f"MultiDiscrete action {action} is below the lower bound {action_space.start}" + assert np.all( + action <= action_space.start + action_space.nvec + ), f"MultiDiscrete action {action} is above the upper bound {action_space.start}+{action_space.nvec}" elif isinstance(action_space, spaces.Box): - assert np.all(action >= action_space.low), ( - f"Action {action} is below the lower bound {action_space.low}") - assert np.all(action <= action_space.high), ( - f"Action {action} is above the upper bound {action_space.high}") + assert np.all(action >= action_space.low), f"Action {action} is below the lower bound {action_space.low}" + assert np.all(action <= action_space.high), f"Action {action} is above the upper bound {action_space.high}" def test_multidim_binary_not_supported():