diff --git a/.github/ISSUE_TEMPLATE/custom_env.md b/.github/ISSUE_TEMPLATE/custom_env.md index 0a12a68bb..f28d370b8 100644 --- a/.github/ISSUE_TEMPLATE/custom_env.md +++ b/.github/ISSUE_TEMPLATE/custom_env.md @@ -44,19 +44,20 @@ from stable_baselines3.common.env_checker import check_env class CustomEnv(gym.Env): def __init__(self): - super(CustomEnv, self).__init__() + super().__init__() self.observation_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(14,)) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(6,)) def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): obs = self.observation_space.sample() reward = 1.0 done = False + truncated = False info = {} - return obs, reward, done, info + return obs, reward, done, truncated, info env = CustomEnv() check_env(env) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 20953d2f6..d9a3f7120 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,4 +1,4 @@ -image: stablebaselines/stable-baselines3-cpu:1.4.1a0 +image: stablebaselines/stable-baselines3-cpu:1.5.1a6 type-check: script: diff --git a/Dockerfile b/Dockerfile index 8dfbbbf4c..96588ef91 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,9 @@ FROM $PARENT_IMAGE ARG PYTORCH_DEPS=cpuonly ARG PYTHON_VERSION=3.7 +# for tzdata +ENV DEBIAN_FRONTEND="noninteractive" TZ="Europe/Paris" + RUN apt-get update && apt-get install -y --no-install-recommends \ build-essential \ cmake \ @@ -20,7 +23,7 @@ RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest ~/miniconda.sh -b -p /opt/conda && \ rm ~/miniconda.sh && \ /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include && \ - /opt/conda/bin/conda install -y pytorch $PYTORCH_DEPS -c pytorch && \ + /opt/conda/bin/conda install -y pytorch=1.11 $PYTORCH_DEPS -c pytorch && \ /opt/conda/bin/conda clean -ya ENV PATH /opt/conda/bin:$PATH diff --git a/Makefile b/Makefile index 9954c7d7b..02851cf45 100644 --- a/Makefile +++ b/Makefile @@ -29,7 +29,8 @@ check-codestyle: commit-checks: format type lint doc: - cd docs && make html + # Prevent weird error due to protobuf + cd docs && PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp make html spelling: cd docs && make spelling diff --git a/README.md b/README.md index 973a180da..07872a661 100644 --- a/README.md +++ b/README.md @@ -124,12 +124,12 @@ env = gym.make("CartPole-v1") model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) -obs = env.reset() +obs, info = env.reset() for i in range(1000): action, _states = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) + obs, reward, done, truncated, info = env.step(action) env.render() - if done: + if done or truncated: obs = env.reset() env.close() diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 98a550820..7b89ba92b 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -4,11 +4,11 @@ channels: - defaults dependencies: - cpuonly=1.0=0 - - pip=21.1 + - pip=22.1.1 - python=3.7 - - pytorch=1.11=py3.7_cpu_0 + - pytorch=1.11.0=py3.7_cpu_0 - pip: - - gym==0.21 + - gym==0.26 - cloudpickle - opencv-python-headless - pandas diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst index 1b8f9fb7f..1a3ae34f4 100644 --- a/docs/guide/custom_policy.rst +++ b/docs/guide/custom_policy.rst @@ -95,9 +95,10 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t .. 
note:: By default the feature extractor is shared between the actor and the critic to save computation (when applicable). - However, this can be changed by defining a custom policy for on-policy algorithms or setting - ``share_features_extractor=False`` in the ``policy_kwargs`` for off-policy algorithms - (and when applicable). + However, this can be changed by defining a custom policy for on-policy algorithms + (see `issue #1066 `_ + for more information) or setting ``share_features_extractor=False`` in the + ``policy_kwargs`` for off-policy algorithms (and when applicable). .. code-block:: python diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 247c86cd9..426cd8fff 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -94,11 +94,12 @@ In the following example, we will train, save and load a DQN model on the Lunar mean_reward, std_reward = evaluate_policy(model, model.get_env(), n_eval_episodes=10) # Enjoy trained agent - obs = env.reset() + vec_env = model.get_env() + obs = vec_env.reset() for i in range(1000): action, _states = model.predict(obs, deterministic=True) - obs, rewards, dones, info = env.step(action) - env.render() + obs, rewards, dones, info = vec_env.step(action) + vec_env.render() Multiprocessing: Unleashing the Power of Vectorized Environments @@ -470,19 +471,19 @@ The parking env is a goal-conditioned continuous control task, in which the vehi # HER must be loaded with the env model = SAC.load("her_sac_highway", env=env) - obs = env.reset() + obs, info = env.reset() # Evaluate the agent episode_reward = 0 for _ in range(100): action, _ = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) + obs, reward, done, truncated, info = env.step(action) env.render() episode_reward += reward - if done or info.get("is_success", False): + if done or truncated or info.get("is_success", False): print("Reward:", episode_reward, "Success?", info.get("is_success", False)) episode_reward = 0.0 - obs = env.reset() + obs, info = env.reset() Learning Rate Schedule diff --git a/docs/guide/export.rst b/docs/guide/export.rst index b6884c19d..3a2174979 100644 --- a/docs/guide/export.rst +++ b/docs/guide/export.rst @@ -46,29 +46,40 @@ For PPO, assuming a shared feature extactor. .. code-block:: python + import torch as th + from stable_baselines3 import PPO - import torch - class OnnxablePolicy(torch.nn.Module): - def __init__(self, extractor, action_net, value_net): - super(OnnxablePolicy, self).__init__() - self.extractor = extractor - self.action_net = action_net - self.value_net = value_net - def forward(self, observation): - # NOTE: You may have to process (normalize) observation in the correct - # way before using this. See `common.preprocessing.preprocess_obs` - action_hidden, value_hidden = self.extractor(observation) - return self.action_net(action_hidden), self.value_net(value_hidden) + class OnnxablePolicy(th.nn.Module): + def __init__(self, extractor, action_net, value_net): + super().__init__() + self.extractor = extractor + self.action_net = action_net + self.value_net = value_net + + def forward(self, observation): + # NOTE: You may have to process (normalize) observation in the correct + # way before using this. 
See `common.preprocessing.preprocess_obs` + action_hidden, value_hidden = self.extractor(observation) + return self.action_net(action_hidden), self.value_net(value_hidden) - # Example: model = PPO("MlpPolicy", "Pendulum-v1") - model = PPO.load("PathToTrainedModel.zip") - model.policy.to("cpu") - onnxable_model = OnnxablePolicy(model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net) - dummy_input = torch.randn(1, observation_size) - torch.onnx.export(onnxable_model, dummy_input, "my_ppo_model.onnx", opset_version=9) + # Example: model = PPO("MlpPolicy", "Pendulum-v1") + model = PPO.load("PathToTrainedModel.zip", device="cpu") + onnxable_model = OnnxablePolicy( + model.policy.mlp_extractor, model.policy.action_net, model.policy.value_net + ) + + observation_size = model.observation_space.shape + dummy_input = th.randn(1, *observation_size) + th.onnx.export( + onnxable_model, + dummy_input, + "my_ppo_model.onnx", + opset_version=9, + input_names=["input"], + ) ##### Load and test with onnx @@ -76,48 +87,97 @@ For PPO, assuming a shared feature extactor. import onnxruntime as ort import numpy as np + onnx_path = "my_ppo_model.onnx" onnx_model = onnx.load(onnx_path) onnx.checker.check_model(onnx_model) - observation = np.zeros((1, observation_size)).astype(np.float32) + observation = np.zeros((1, *observation_size)).astype(np.float32) ort_sess = ort.InferenceSession(onnx_path) - action, value = ort_sess.run(None, {'input.1': observation}) + action, value = ort_sess.run(None, {"input": observation}) For SAC the procedure is similar. The example shown only exports the actor network as the actor is sufficient to roll out the trained policies. .. code-block:: python + import torch as th + from stable_baselines3 import SAC - import torch - class OnnxablePolicy(torch.nn.Module): - def __init__(self, actor): - super(OnnxablePolicy, self).__init__() - # Removing the flatten layer because it can't be onnxed - self.actor = torch.nn.Sequential(actor.latent_pi, actor.mu) + class OnnxablePolicy(th.nn.Module): + def __init__(self, actor: th.nn.Module): + super().__init__() + # Removing the flatten layer because it can't be onnxed + self.actor = th.nn.Sequential( + actor.latent_pi, + actor.mu, + # For gSDE + # th.nn.Hardtanh(min_val=-actor.clip_mean, max_val=actor.clip_mean), + # Squash the output + th.nn.Tanh(), + ) + + def forward(self, observation: th.Tensor) -> th.Tensor: + # NOTE: You may have to process (normalize) observation in the correct + # way before using this. See `common.preprocessing.preprocess_obs` + return self.actor(observation) - def forward(self, observation): - # NOTE: You may have to process (normalize) observation in the correct - # way before using this. 
See `common.preprocessing.preprocess_obs` - return self.actor(observation) - model = SAC.load("PathToTrainedModel.zip") + # Example: model = SAC("MlpPolicy", "Pendulum-v1") + model = SAC.load("PathToTrainedModel.zip", device="cpu") onnxable_model = OnnxablePolicy(model.policy.actor) - dummy_input = torch.randn(1, observation_size) - onnxable_model.policy.to("cpu") - torch.onnx.export(onnxable_model, dummy_input, "my_sac_actor.onnx", opset_version=9) + observation_size = model.observation_space.shape + dummy_input = th.randn(1, *observation_size) + th.onnx.export( + onnxable_model, + dummy_input, + "my_sac_actor.onnx", + opset_version=9, + input_names=["input"], + ) + + ##### Load and test with onnx + + import onnxruntime as ort + import numpy as np + + onnx_path = "my_sac_actor.onnx" + + observation = np.zeros((1, *observation_size)).astype(np.float32) + ort_sess = ort.InferenceSession(onnx_path) + action = ort_sess.run(None, {"input": observation}) For more discussion around the topic refer to this `issue. `_ -Export to C++ ------------------ +Trace/Export to C++ +------------------- + +You can use PyTorch JIT to trace and save a trained model that can be re-used in other applications +(for instance inference code written in C++). + +There is a draft PR in the RL Zoo about C++ export: https://github.com/DLR-RM/rl-baselines3-zoo/pull/228 + +.. code-block:: python + + # See "ONNX export" for imports and OnnxablePolicy + jit_path = "sac_traced.pt" + + # Trace and optimize the module + traced_module = th.jit.trace(onnxable_model.eval(), dummy_input) + frozen_module = th.jit.freeze(traced_module) + frozen_module = th.jit.optimize_for_inference(frozen_module) + th.jit.save(frozen_module, jit_path) + + ##### Load and test with torch + + import torch as th -(using PyTorch JIT) -TODO: help is welcomed! + dummy_input = th.randn(1, *observation_size) + loaded_module = th.jit.load(jit_path) + action_jit = loaded_module(dummy_input) Export to tensorflowjs / ONNX-JS diff --git a/docs/guide/quickstart.rst b/docs/guide/quickstart.rst index 064139d25..a1c547344 100644 --- a/docs/guide/quickstart.rst +++ b/docs/guide/quickstart.rst @@ -14,18 +14,24 @@ Here is a quick example of how to train and run A2C on a CartPole environment: from stable_baselines3 import A2C - env = gym.make('CartPole-v1') + env = gym.make("CartPole-v1") - model = A2C('MlpPolicy', env, verbose=1) + model = A2C("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10000) - obs = env.reset() + # Note: Gym 0.26+ reset() returns a tuple + # where SB3 VecEnv only return an observation + obs, info = env.reset() for i in range(1000): action, _state = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) + # Note: Gym 0.26+ step() returns an additional boolean + # "truncated" where SB3 store truncation information + # in info["TimeLimit.truncated"] + obs, reward, done, truncated, info = env.step(action) env.render() - if done: - obs = env.reset() + # Note: reset is automated in SB3 VecEnv + if done or truncated: + obs, info = env.reset() .. 
note:: @@ -40,4 +46,4 @@ the policy is registered: from stable_baselines3 import A2C - model = A2C('MlpPolicy', 'CartPole-v1').learn(10000) + model = A2C("MlpPolicy", "CartPole-v1").learn(10000) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 0d745e5cf..8ad6c2952 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,9 +3,11 @@ Changelog ========== -Release 1.6.1a4 (WIP) +Release 1.6.1 (2022-09-29) --------------------------- +**Bug fix release** + Breaking Changes: ^^^^^^^^^^^^^^^^^ - Switched minimum tensorboard version to 2.9.1 @@ -19,6 +21,7 @@ New Features: SB3-Contrib ^^^^^^^^^^^ +- Fixed the issue of wrongly passing policy arguments when using ``CnnLstmPolicy`` or ``MultiInputLstmPolicy`` with ``RecurrentPPO`` (@mlodel) Bug Fixes: ^^^^^^^^^^ @@ -29,7 +32,10 @@ Bug Fixes: - Fixed missing verbose parameter passing in the ``EvalCallback`` constructor (@burakdmb) - Fixed the issue that when updating the target network in DQN, SAC, TD3, the ``running_mean`` and ``running_var`` properties of batch norm layers are not updated (@honglu2875) - Fixed incorrect type annotation of the replay_buffer_class argument in ``common.OffPolicyAlgorithm`` initializer, where an instance instead of a class was required (@Rocamonde) +- Fixed loading saved model with different number of envrionments - Removed ``forward()`` abstract method declaration from ``common.policies.BaseModel`` (already defined in ``torch.nn.Module``) to fix type errors in subclasses (@Rocamonde) +- Fixed the return type of ``.load()`` and ``.learn()`` methods in ``BaseAlgorithm`` so that they now use ``TypeVar`` (@Rocamonde) +- Fixed an issue where keys with different tags but the same key raised an error in ``common.logger.HumanOutputFormat`` (@Rocamonde and @AdamGleave) Deprecations: ^^^^^^^^^^^^^ @@ -37,8 +43,8 @@ Deprecations: Others: ^^^^^^^ - Fixed ``DictReplayBuffer.next_observations`` typing (@qgallouedec) - - Added support for ``device="auto"`` in buffers and made it default (@qgallouedec) +- Updated ``ResultsWriter` (used internally by ``Monitor`` wrapper) to automatically create missing directories when ``filename`` is a path (@dominicgkerr) Documentation: ^^^^^^^^^^^^^^ @@ -48,7 +54,9 @@ Documentation: - Fixed typo in ppo doc (@francescoluciano) - Fixed typo in install doc(@jlp-ue) - Clarified and standardized verbosity documentation - +- Added link to a GitHub issue in the custom policy documentation (@AlexPasqua) +- Update doc on exporting models (fixes and added torch jit) +- Fixed typos (@Akhilez) Release 1.6.0 (2022-07-11) --------------------------- @@ -57,6 +65,7 @@ Release 1.6.0 (2022-07-11) Breaking Changes: ^^^^^^^^^^^^^^^^^ +- Switched minimum Gym version to 0.24 (@carlosluis) - Changed the way policy "aliases" are handled ("MlpPolicy", "CnnPolicy", ...), removing the former ``register_policy`` helper, ``policy_base`` parameter and using ``policy_aliases`` static attributes instead (@Gregwar) - SB3 now requires PyTorch >= 1.11 @@ -65,6 +74,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- ``noop_max`` and ``frame_skip`` are now allowed to be equal to zero when using ``AtariWrapper`` SB3-Contrib ^^^^^^^^^^^ @@ -90,6 +100,7 @@ Deprecations: Others: ^^^^^^^ - Upgraded to Python 3.7+ syntax using ``pyupgrade`` +- Updated docker base image to Ubuntu 20.04 and cuda 11.3 - Removed redundant double-check for nested observations from ``BaseAlgorithm._wrap_env`` (@TibiGG) Documentation: @@ -99,6 +110,7 @@ Documentation: - Added link to PPO ICLR blog post - Added remark about 
breaking Markov assumption and timeout handling - Added doc about MLFlow integration via custom logger (@git-thor) +- Updated tutorials to work with Gym 0.23 (@arjun-kg) - Updated Huggingface integration doc - Added copy button for code snippets - Added doc about EnvPool and Isaac Gym support @@ -111,7 +123,7 @@ Release 1.5.0 (2022-03-25) Breaking Changes: ^^^^^^^^^^^^^^^^^ -- Switched minimum Gym version to 0.21.0. +- Switched minimum Gym version to 0.21.0 New Features: ^^^^^^^^^^^^^ @@ -1035,5 +1047,6 @@ And all the contributors: @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede +@carlosluis @arjun-kg @tlpss @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 -@anand-bala @hughperkins @sidney-tio +@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde diff --git a/docs/modules/her.rst b/docs/modules/her.rst index 82bf745d6..7f3feefbc 100644 --- a/docs/modules/her.rst +++ b/docs/modules/her.rst @@ -22,7 +22,10 @@ It creates "virtual" transitions by relabeling transitions (changing the desired .. warning:: - HER requires the environment to inherits from `gym.GoalEnv `_ + HER requires the environment to follow the legacy `gym.GoalEnv interface `_ + In short, the ``gym.Env`` must have: + - a vectorized implementation of ``compute_reward()`` + - a dictionary observation space with three keys: ``observation``, ``achieved_goal`` and ``desired_goal`` .. warning:: diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh index 13ac86b17..3f0d5ae7c 100755 --- a/scripts/build_docker.sh +++ b/scripts/build_docker.sh @@ -1,14 +1,14 @@ #!/bin/bash -CPU_PARENT=ubuntu:18.04 -GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 +CPU_PARENT=ubuntu:20.04 +GPU_PARENT=nvidia/cuda:11.3.1-base-ubuntu20.04 TAG=stablebaselines/stable-baselines3 VERSION=$(cat ./stable_baselines3/version.txt) if [[ ${USE_GPU} == "True" ]]; then PARENT=${GPU_PARENT} - PYTORCH_DEPS="cudatoolkit=10.1" + PYTORCH_DEPS="cudatoolkit=11.3" else PARENT=${CPU_PARENT} PYTORCH_DEPS="cpuonly" diff --git a/setup.cfg b/setup.cfg index 5bc66c20c..bd04ac9e9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -10,11 +10,12 @@ filterwarnings = # Tensorboard warnings ignore::DeprecationWarning:tensorboard # Gym warnings - ignore:Parameters to load are deprecated.:DeprecationWarning - ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning + ; ignore:Parameters to load are deprecated.:DeprecationWarning + ; ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning ignore::UserWarning:gym - ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning - ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning + ; ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning + ; ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning + ignore:.*step API:DeprecationWarning:gym markers = expensive: marks tests as expensive (deselect with '-m "not expensive"') diff --git a/setup.py b/setup.py index 8c410a3fa..a4365a570 100644 --- a/setup.py +++ b/setup.py @@ -48,13 +48,13 @@ model = PPO("MlpPolicy", env, verbose=1) model.learn(total_timesteps=10_000) -obs = env.reset() +obs, info = env.reset() for i in range(1000): action, _states = 
model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) + obs, reward, done, truncated, info = env.step(action) env.render() - if done: - obs = env.reset() + if done or truncated: + obs, info = env.reset() ``` Or just train a model with a one liner if [the environment is registered in Gym](https://www.gymlibrary.ml/content/environment_creation/) and if [the policy is registered](https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html): @@ -73,7 +73,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym==0.21", # Fixed version due to breaking changes in 0.22 + "gym==0.26", "numpy", "torch>=1.11", # For saving models @@ -100,11 +100,9 @@ "isort>=5.0", # Reformat "black", - # For toy text Gym envs - "scipy>=1.4.1", ], "docs": [ - "sphinx", + "sphinx~=4.5.0", "sphinx-autobuild", "sphinx-rtd-theme", # For spelling @@ -117,8 +115,9 @@ "extra": [ # For render "opencv-python", + "pygame", # For atari games, - "ale-py==0.7.4", + "ale-py~=0.8.0", "autorom[accept-rom-license]~=0.4.2", "pillow", # Tensorboard support diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py index d73f5f095..4792f6c15 100644 --- a/stable_baselines3/__init__.py +++ b/stable_baselines3/__init__.py @@ -1,4 +1,5 @@ import os +import warnings from stable_baselines3.a2c import A2C from stable_baselines3.common.utils import get_system_info @@ -14,6 +15,9 @@ with open(version_file) as file_handler: __version__ = file_handler.read().strip() +# Silence Gym warnings due to new API +warnings.filterwarnings("ignore", message=r".*step API", module="gym") + def HER(*args, **kwargs): raise ImportError( diff --git a/stable_baselines3/a2c/a2c.py b/stable_baselines3/a2c/a2c.py index 8058f5264..8b8cecbfb 100644 --- a/stable_baselines3/a2c/a2c.py +++ b/stable_baselines3/a2c/a2c.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Type, TypeVar, Union import torch as th from gym import spaces @@ -9,6 +9,8 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance +A2CSelf = TypeVar("A2CSelf", bound="A2C") + class A2C(OnPolicyAlgorithm): """ @@ -183,7 +185,7 @@ def train(self) -> None: self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) def learn( - self, + self: A2CSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 100, @@ -193,7 +195,7 @@ def learn( tb_log_name: str = "A2C", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> "A2C": + ) -> A2CSelf: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index a9b2eca1f..ba08da7b5 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -1,3 +1,5 @@ +from typing import Dict, Tuple + import gym import numpy as np from gym import spaces @@ -9,7 +11,7 @@ except ImportError: cv2 = None -from stable_baselines3.common.type_aliases import GymObs, GymStepReturn +from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn class NoopResetEnv(gym.Wrapper): @@ -28,19 +30,20 @@ def __init__(self, env: gym.Env, noop_max: int = 30): self.noop_action = 0 assert env.unwrapped.get_action_meanings()[0] == "NOOP" - def 
reset(self, **kwargs) -> np.ndarray: + def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: self.env.reset(**kwargs) if self.override_num_noops is not None: noops = self.override_num_noops else: - noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) + noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) assert noops > 0 obs = np.zeros(0) + info = {} for _ in range(noops): - obs, _, done, _ = self.env.step(self.noop_action) - if done: - obs = self.env.reset(**kwargs) - return obs + obs, _, done, truncated, info = self.env.step(self.noop_action) + if done or truncated: + obs, info = self.env.reset(**kwargs) + return obs, info class FireResetEnv(gym.Wrapper): @@ -55,15 +58,15 @@ def __init__(self, env: gym.Env): assert env.unwrapped.get_action_meanings()[1] == "FIRE" assert len(env.unwrapped.get_action_meanings()) >= 3 - def reset(self, **kwargs) -> np.ndarray: + def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: self.env.reset(**kwargs) - obs, _, done, _ = self.env.step(1) - if done: + obs, _, done, truncated, _ = self.env.step(1) + if done or truncated: self.env.reset(**kwargs) - obs, _, done, _ = self.env.step(2) - if done: + obs, _, done, truncated, _ = self.env.step(2) + if done or truncated: self.env.reset(**kwargs) - return obs + return obs, {} class EpisodicLifeEnv(gym.Wrapper): @@ -79,21 +82,21 @@ def __init__(self, env: gym.Env): self.lives = 0 self.was_real_done = True - def step(self, action: int) -> GymStepReturn: - obs, reward, done, info = self.env.step(action) + def step(self, action: int) -> Gym26StepReturn: + obs, reward, done, truncated, info = self.env.step(action) self.was_real_done = done # check current lives, make loss of life terminal, # then update lives to handle bonus lives lives = self.env.unwrapped.ale.lives() if 0 < lives < self.lives: - # for Qbert sometimes we stay in lives == 0 condtion for a few frames + # for Qbert sometimes we stay in lives == 0 condition for a few frames # so its important to keep lives > 0, so that we only reset once # the environment advertises done. done = True self.lives = lives - return obs, reward, done, info + return obs, reward, done, truncated, info - def reset(self, **kwargs) -> np.ndarray: + def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: """ Calls the Gym environment reset, only when lives are exhausted. This way all states are still reachable even though lives are episodic, @@ -103,12 +106,12 @@ def reset(self, **kwargs) -> np.ndarray: :return: the first observation of the environment """ if self.was_real_done: - obs = self.env.reset(**kwargs) + obs, info = self.env.reset(**kwargs) else: # no-op step to advance from terminal/lost life state - obs, _, _, _ = self.env.step(0) + obs, _, _, info = self.env.step(0) self.lives = self.env.unwrapped.ale.lives() - return obs + return obs, info class MaxAndSkipEnv(gym.Wrapper): @@ -125,7 +128,7 @@ def __init__(self, env: gym.Env, skip: int = 4): self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=env.observation_space.dtype) self._skip = skip - def step(self, action: int) -> GymStepReturn: + def step(self, action: int) -> Gym26StepReturn: """ Step the environment with the given action Repeat action, sum reward, and max over last observations. 
@@ -134,9 +137,10 @@ def step(self, action: int) -> GymStepReturn: :return: observation, reward, done, information """ total_reward = 0.0 - done = None + terminated = truncated = False for i in range(self._skip): - obs, reward, done, info = self.env.step(action) + obs, reward, terminated, truncated, info = self.env.step(action) + done = terminated or truncated if i == self._skip - 2: self._obs_buffer[0] = obs if i == self._skip - 1: @@ -148,9 +152,9 @@ def step(self, action: int) -> GymStepReturn: # doesn't matter max_frame = self._obs_buffer.max(axis=0) - return max_frame, total_reward, done, info + return max_frame, total_reward, terminated, truncated, info - def reset(self, **kwargs) -> GymObs: + def reset(self, **kwargs) -> Gym26ResetReturn: return self.env.reset(**kwargs) @@ -235,8 +239,10 @@ def __init__( terminal_on_life_loss: bool = True, clip_reward: bool = True, ): - env = NoopResetEnv(env, noop_max=noop_max) - env = MaxAndSkipEnv(env, skip=frame_skip) + if noop_max > 0: + env = NoopResetEnv(env, noop_max=noop_max) + if frame_skip > 0: + env = MaxAndSkipEnv(env, skip=frame_skip) if terminal_on_life_loss: env = EpisodicLifeEnv(env) if "FIRE" in env.unwrapped.get_action_meanings(): diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 72428e07b..202bfa65b 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -5,7 +5,7 @@ import time from abc import ABC, abstractmethod from collections import deque -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -23,6 +23,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import ( check_for_correct_spaces, + compat_gym_seed, get_device, get_schedule_fn, get_system_info, @@ -53,6 +54,9 @@ def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymE return env +BaseAlgorithmSelf = TypeVar("BaseAlgorithmSelf", bound="BaseAlgorithm") + + class BaseAlgorithm(ABC): """ The base of RL algorithms @@ -516,6 +520,11 @@ def set_env(self, env: GymEnv, force_reset: bool = True) -> None: # if it is not a VecEnv, make it a VecEnv # and do other transformations (dict obs, image transpose) if needed env = self._wrap_env(env, self.verbose) + assert env.num_envs == self.n_envs, ( + "The number of environments to be set is different from the number of environments in the model: " + f"({env.num_envs} != {self.n_envs}), whereas `set_env` requires them to be the same. To load a model with " + f"a different number of environments, you must use `{self.__class__.__name__}.load(path, env)` instead" + ) # Check that the observation spaces match check_for_correct_spaces(env, self.observation_space, self.action_space) # Update VecNormalize object @@ -532,7 +541,7 @@ def set_env(self, env: GymEnv, force_reset: bool = True) -> None: @abstractmethod def learn( - self, + self: BaseAlgorithmSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 100, @@ -542,7 +551,7 @@ def learn( n_eval_episodes: int = 5, eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> "BaseAlgorithm": + ) -> BaseAlgorithmSelf: """ Return a trained model. 
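Note on the `set_env()` assertion above: together with the `n_envs` update in `load()` further below, it settles how a saved model interacts with a different number of environments. A minimal usage sketch (the file name is hypothetical, assuming the updated API from this patch):

```python
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env

# Train and save with 2 environments ("ppo_cartpole" is a hypothetical file name)
model = PPO("MlpPolicy", make_vec_env("CartPole-v1", n_envs=2))
model.learn(total_timesteps=1_000)
model.save("ppo_cartpole")

# Changing the env count on the live model is now rejected by the assertion above:
# model.set_env(make_vec_env("CartPole-v1", n_envs=4))  # AssertionError: 4 != 2

# The supported route is to reload the model with the new env,
# which also updates `n_envs` (see the `load()` change below)
model = PPO.load("ppo_cartpole", env=make_vec_env("CartPole-v1", n_envs=4))
model.learn(total_timesteps=1_000)
```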
@@ -591,10 +600,12 @@ def set_random_seed(self, seed: Optional[int] = None) -> None: return set_random_seed(seed, using_cuda=self.device.type == th.device("cuda").type) self.action_space.seed(seed) + # self.env is always a VecEnv if self.env is not None: self.env.seed(seed) + # Eval env may be a gym.Env, hence the call to compat_gym_seed() if self.eval_env is not None: - self.eval_env.seed(seed) + compat_gym_seed(self.eval_env, seed=seed) def set_parameters( self, @@ -666,7 +677,7 @@ def set_parameters( @classmethod def load( - cls, + cls: Type[BaseAlgorithmSelf], path: Union[str, pathlib.Path, io.BufferedIOBase], env: Optional[GymEnv] = None, device: Union[th.device, str] = "auto", @@ -674,7 +685,7 @@ def load( print_system_info: bool = False, force_reset: bool = True, **kwargs, - ) -> "BaseAlgorithm": + ) -> BaseAlgorithmSelf: """ Load the model from a zip-file. Warning: ``load`` re-creates the model from scratch, it does not update it in-place! @@ -704,7 +715,10 @@ def load( get_system_info() data, params, pytorch_variables = load_from_zip_file( - path, device=device, custom_objects=custom_objects, print_system_info=print_system_info + path, + device=device, + custom_objects=custom_objects, + print_system_info=print_system_info, ) # Remove stored device information and replace with ours @@ -730,6 +744,9 @@ def load( # See issue https://github.com/DLR-RM/stable-baselines3/issues/597 if force_reset and data is not None: data["_last_obs"] = None + # `n_envs` must be updated. See issue https://github.com/DLR-RM/stable-baselines3/issues/1018 + if data is not None: + data["n_envs"] = env.num_envs else: # Use stored env, if one exists. If not, continue as is (can be used for predict) if "env" in data: diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index 59725312b..18969242e 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -574,7 +574,7 @@ def add( reward: np.ndarray, done: np.ndarray, infos: List[Dict[str, Any]], - ) -> None: + ) -> None: # pytype: disable=signature-mismatch # Copy to avoid modification by reference for key in self.observations.keys(): # Reshape needed when using multiple envs with discrete observations @@ -711,7 +711,7 @@ def add( episode_start: np.ndarray, value: th.Tensor, log_prob: th.Tensor, - ) -> None: + ) -> None: # pytype: disable=signature-mismatch """ :param obs: Observation :param action: Action diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index efc05e307..c383c8eda 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -1,5 +1,5 @@ import warnings -from typing import Union +from typing import Any, Dict, Union import gym import numpy as np @@ -93,7 +93,65 @@ def _check_nan(env: gym.Env) -> None: _, _, _, _ = vec_env.step(action) -def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spaces.Space, method_name: str) -> None: +def _is_goal_env(env: gym.Env) -> bool: + """ + Check if the env uses the convention for goal-conditioned envs (previously, the gym.GoalEnv interface) + """ + return hasattr(env, "compute_reward") + + +def _check_goal_env_obs(obs: dict, observation_space: spaces.Space, method_name: str) -> None: + """ + Check that an environment implementing the `compute_rewards()` method + (previously known as GoalEnv in gym) contains three elements, + namely `observation`, `desired_goal`, and `achieved_goal`. 
+ """ + assert len(observation_space.spaces) == 3, ( + "A goal conditioned env must contain 3 observation keys: `observation`, `desired_goal`, and `achieved_goal`." + f"The current observation contains {len(observation_space.spaces)} keys: {list(observation_space.spaces.keys())}" + ) + + for key in ["observation", "achieved_goal", "desired_goal"]: + if key not in observation_space.spaces: + raise AssertionError( + f"The observation returned by the `{method_name}()` method of a goal-conditioned env requires the '{key}' " + "key to be part of the observation dictionary. " + f"Current keys are {list(observation_space.spaces.keys())}" + ) + + +def _check_goal_env_compute_reward( + obs: Dict[str, Union[np.ndarray, int]], + env: gym.Env, + reward: float, + info: Dict[str, Any], +): + """ + Check that reward is computed with `compute_reward` + and that the implementation is vectorized. + """ + achieved_goal, desired_goal = obs["achieved_goal"], obs["desired_goal"] + assert reward == env.compute_reward( + achieved_goal, desired_goal, info + ), "The reward was not computed with `compute_reward()`" + + achieved_goal, desired_goal = np.array(achieved_goal), np.array(desired_goal) + batch_achieved_goals = np.array([achieved_goal, achieved_goal]) + batch_desired_goals = np.array([desired_goal, desired_goal]) + if isinstance(achieved_goal, int) or len(achieved_goal.shape) == 0: + batch_achieved_goals = batch_achieved_goals.reshape(2, 1) + batch_desired_goals = batch_desired_goals.reshape(2, 1) + batch_infos = np.array([info, info]) + rewards = env.compute_reward(batch_achieved_goals, batch_desired_goals, batch_infos) + assert rewards.shape == (2,), f"Unexpected shape for vectorized computation of reward: {rewards.shape} != (2,)" + assert rewards[0] == reward, f"Vectorized computation of reward differs from single computation: {rewards[0]} != {reward}" + + +def _check_obs( + obs: Union[tuple, dict, np.ndarray, int], + observation_space: spaces.Space, + method_name: str, +) -> None: """ Check that the observation returned by the environment correspond to the declared one. @@ -139,9 +197,15 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action Check the returned values by the env when calling `.reset()` or `.step()` methods. 
""" # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists - obs = env.reset() - - if isinstance(observation_space, spaces.Dict): + reset_returns = env.reset() + assert isinstance(reset_returns, tuple), "`reset()` must return a tuple (obs, info)" + assert len(reset_returns) == 2, f"`reset()` must return a tuple of size 2 (obs, info), not {len(reset_returns)}" + obs, info = reset_returns + assert isinstance(info, dict), "The second element of the tuple return by `reset()` must be a dictionary" + + if _is_goal_env(env): + _check_goal_env_obs(obs, observation_space, "reset") + elif isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), "The observation returned by `reset()` must be a dictionary" if not obs.keys() == observation_space.spaces.keys(): @@ -162,12 +226,15 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action action = action_space.sample() data = env.step(action) - assert len(data) == 4, "The `step()` method must return four values: obs, reward, done, info" + assert len(data) == 5, "The `step()` method must return four values: obs, reward, terminated, truncated, info" # Unpack - obs, reward, done, info = data + obs, reward, terminated, truncated, info = data - if isinstance(observation_space, spaces.Dict): + if _is_goal_env(env): + _check_goal_env_obs(obs, observation_space, "step") + _check_goal_env_compute_reward(obs, env, reward, info) + elif isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary" if not obs.keys() == observation_space.spaces.keys(): @@ -187,18 +254,20 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # We also allow int because the reward will be cast to float assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float" - assert isinstance(done, bool), "The `done` signal must be a boolean" + assert isinstance(terminated, bool), "The `terminated` signal must be a boolean" + assert isinstance(truncated, bool), "The `truncated` signal must be a boolean" assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary" - if isinstance(env, gym.GoalEnv): - # For a GoalEnv, the keys are checked at reset + # Goal conditioned env + if hasattr(env, "compute_reward"): assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info) def _check_spaces(env: gym.Env) -> None: """ - Check that the observation and action spaces are defined - and inherit from gym.spaces.Space. + Check that the observation and action spaces are defined and inherit from gym.spaces.Space. 
For + envs that follow the goal-conditioned standard (previously, the gym.GoalEnv interface) we check + the observation space is gym.spaces.Dict """ # Helper to link to the code, because gym has no proper documentation gym_spaces = " cf https://github.com/openai/gym/blob/master/gym/spaces/" @@ -209,6 +278,11 @@ def _check_spaces(env: gym.Env) -> None: assert isinstance(env.observation_space, spaces.Space), "The observation space must inherit from gym.spaces" + gym_spaces assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces + if _is_goal_env(env): + assert isinstance( + env.observation_space, spaces.Dict + ), "Goal conditioned envs (previously gym.GoalEnv) require the observation space to be gym.spaces.Dict" + # Check render cannot be covered by CI def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> None: # pragma: no cover diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index 520c50a5f..82cfc3e96 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -5,6 +5,7 @@ from stable_baselines3.common.atari_wrappers import AtariWrapper from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.utils import compat_gym_seed from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv @@ -81,7 +82,7 @@ def _init(): else: env = env_id(**env_kwargs) if seed is not None: - env.seed(seed + rank) + compat_gym_seed(env, seed=seed + rank) env.action_space.seed(seed + rank) # Wrap the env in a Monitor wrapper # to have additional training information diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index a881b32c9..0fc93a6cf 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -1,14 +1,14 @@ from collections import OrderedDict -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import numpy as np -from gym import GoalEnv, spaces +from gym import Env, spaces from gym.envs.registration import EnvSpec -from stable_baselines3.common.type_aliases import GymStepReturn +from stable_baselines3.common.type_aliases import Gym26StepReturn -class BitFlippingEnv(GoalEnv): +class BitFlippingEnv(Env): """ Simple bit flipping env, useful to test HER. The goal is to flip all the bits to get a vector of ones. @@ -25,7 +25,7 @@ class BitFlippingEnv(GoalEnv): :param channel_first: Whether to use channel-first or last image. """ - spec = EnvSpec("BitFlippingEnv-v0") + spec = EnvSpec("BitFlippingEnv-v0", "no-entry-point") def __init__( self, @@ -157,12 +157,20 @@ def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]: ] ) - def reset(self) -> Dict[str, Union[int, np.ndarray]]: + def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]: + if seed is not None: + self.obs_space.seed(seed) self.current_step = 0 self.state = self.obs_space.sample() - return self._get_obs() + return self._get_obs(), {} - def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: + """ + Step into the env. 
+ + :param action: + :return: + """ if self.continuous: self.state[action > 0] = 1 - self.state[action > 0] else: @@ -173,8 +181,9 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: self.current_step += 1 # Episode terminate when we reached the goal or the max number of steps info = {"is_success": done} - done = done or self.current_step >= self.max_steps - return obs, reward, done, info + truncated = self.current_step >= self.max_steps + done = done or truncated + return obs, reward, done, truncated, info def compute_reward( self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]] diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py index 8f6ccd2dc..aa7169611 100644 --- a/stable_baselines3/common/envs/identity_env.py +++ b/stable_baselines3/common/envs/identity_env.py @@ -1,10 +1,10 @@ -from typing import Optional, Union +from typing import Dict, Optional, Tuple, Union import numpy as np from gym import Env, Space from gym.spaces import Box, Discrete, MultiBinary, MultiDiscrete -from stable_baselines3.common.type_aliases import GymObs, GymStepReturn +from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn class IdentityEnv(Env): @@ -32,18 +32,20 @@ def __init__(self, dim: Optional[int] = None, space: Optional[Space] = None, ep_ self.num_resets = -1 # Becomes 0 after __init__ exits. self.reset() - def reset(self) -> GymObs: + def reset(self, seed: Optional[int] = None) -> Gym26ResetReturn: + if seed is not None: + super().reset(seed=seed) self.current_step = 0 self.num_resets += 1 self._choose_next_state() - return self.state + return self.state, {} - def step(self, action: Union[int, np.ndarray]) -> GymStepReturn: + def step(self, action: Union[int, np.ndarray]) -> Gym26StepReturn: reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 - done = self.current_step >= self.ep_length - return self.state, reward, done, {} + done = truncated = self.current_step >= self.ep_length + return self.state, reward, done, truncated, {} def _choose_next_state(self) -> None: self.state = self.action_space.sample() @@ -69,12 +71,12 @@ def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_l super().__init__(ep_length=ep_length, space=space) self.eps = eps - def step(self, action: np.ndarray) -> GymStepReturn: + def step(self, action: np.ndarray) -> Gym26StepReturn: reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 - done = self.current_step >= self.ep_length - return self.state, reward, done, {} + done = truncated = self.current_step >= self.ep_length + return self.state, reward, done, truncated, {} def _get_reward(self, action: np.ndarray) -> float: return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0 @@ -136,15 +138,17 @@ def __init__( self.ep_length = 10 self.current_step = 0 - def reset(self) -> np.ndarray: + def reset(self, seed: Optional[int] = None) -> Tuple[np.ndarray, Dict]: + if seed is not None: + super().reset(seed=seed) self.current_step = 0 - return self.observation_space.sample() + return self.observation_space.sample(), {} - def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: reward = 0.0 self.current_step += 1 - done = self.current_step >= self.ep_length - return self.observation_space.sample(), reward, done, {} + done = truncated = 
self.current_step >= self.ep_length + return self.observation_space.sample(), reward, done, truncated, {} def render(self, mode: str = "human") -> None: pass diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index 2e5f13f61..c7b5973da 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -1,9 +1,9 @@ -from typing import Dict, Union +from typing import Dict, Optional, Tuple, Union import gym import numpy as np -from stable_baselines3.common.type_aliases import GymStepReturn +from stable_baselines3.common.type_aliases import Gym26StepReturn class SimpleMultiObsEnv(gym.Env): @@ -120,7 +120,7 @@ def init_possible_transitions(self) -> None: self.right_possible = [0, 1, 2, 12, 13, 14] self.up_possible = [4, 8, 12, 7, 11, 15] - def step(self, action: Union[int, float, np.ndarray]) -> GymStepReturn: + def step(self, action: Union[int, float, np.ndarray]) -> Gym26StepReturn: """ Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling `reset()` @@ -152,11 +152,12 @@ def step(self, action: Union[int, float, np.ndarray]) -> GymStepReturn: got_to_end = self.state == self.max_state reward = 1 if got_to_end else reward - done = self.count > self.max_count or got_to_end + truncated = self.count > self.max_count + done = got_to_end or truncated self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}" - return self.get_state_mapping(), reward, done, {"got_to_end": got_to_end} + return self.get_state_mapping(), reward, done, truncated, {"got_to_end": got_to_end} def render(self, mode: str = "human") -> None: """ @@ -166,15 +167,18 @@ def render(self, mode: str = "human") -> None: """ print(self.log) - def reset(self) -> Dict[str, np.ndarray]: + def reset(self, seed: Optional[int] = None) -> Tuple[Dict[str, np.ndarray], Dict]: """ Resets the environment state and step count and returns reset observation. + :param seed: :return: observation dict {'vec': ..., 'img': ...} """ + if seed is not None: + super().reset(seed=seed) self.count = 0 if not self.random_start: self.state = 0 else: self.state = np.random.randint(0, self.max_state) - return self.state_mapping[self.state] + return self.state_mapping[self.state], {} diff --git a/stable_baselines3/common/logger.py b/stable_baselines3/common/logger.py index c1e8433cc..31e5655c4 100644 --- a/stable_baselines3/common/logger.py +++ b/stable_baselines3/common/logger.py @@ -193,30 +193,31 @@ def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None: if key.find("/") > 0: # Find tag and add it to the dict tag = key[: key.find("/") + 1] - key2str[self._truncate(tag)] = "" + key2str[(tag, self._truncate(tag))] = "" # Remove tag from key if tag is not None and tag in key: key = str(" " + key[len(tag) :]) truncated_key = self._truncate(key) - if truncated_key in key2str: + if (tag, truncated_key) in key2str: raise ValueError( f"Key '{key}' truncated to '{truncated_key}' that already exists. Consider increasing `max_length`." 
) - key2str[truncated_key] = self._truncate(value_str) + key2str[(tag, truncated_key)] = self._truncate(value_str) # Find max widths if len(key2str) == 0: warnings.warn("Tried to write empty key-value dict") return else: - key_width = max(map(len, key2str.keys())) + tagless_keys = map(lambda x: x[1], key2str.keys()) + key_width = max(map(len, tagless_keys)) val_width = max(map(len, key2str.values())) # Write out the data dashes = "-" * (key_width + val_width + 7) lines = [dashes] - for key, value in key2str.items(): + for (_, key), value in key2str.items(): key_space = " " * (key_width - len(key)) val_space = " " * (val_width - len(value)) lines.append(f"| {key}{key_space} | {value}{val_space} |") diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index 9a07b038a..499753553 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -11,7 +11,7 @@ import numpy as np import pandas -from stable_baselines3.common.type_aliases import GymObs, GymStepReturn +from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn class Monitor(gym.Wrapper): @@ -61,7 +61,7 @@ def __init__( self.total_steps = 0 self.current_reset_info = {} # extra info about the current episode, that was passed in during reset() - def reset(self, **kwargs) -> GymObs: + def reset(self, **kwargs) -> Gym26ResetReturn: """ Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True @@ -82,7 +82,7 @@ def reset(self, **kwargs) -> GymObs: self.current_reset_info[key] = value return self.env.reset(**kwargs) - def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: """ Step the environment with the given action @@ -91,9 +91,9 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: """ if self.needs_reset: raise RuntimeError("Tried to step environment that needs reset") - observation, reward, done, info = self.env.step(action) + observation, reward, done, truncated, info = self.env.step(action) self.rewards.append(reward) - if done: + if done or truncated: self.needs_reset = True ep_rew = sum(self.rewards) ep_len = len(self.rewards) @@ -108,7 +108,7 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: self.results_writer.write_row(ep_info) info["episode"] = ep_info self.total_steps += 1 - return observation, reward, done, info + return observation, reward, done, truncated, info def close(self) -> None: """ @@ -163,9 +163,10 @@ class ResultsWriter: """ A result writer that saves the data from the `Monitor` class - :param filename: the location to save a log file, can be None for no log + :param filename: the location to save a log file. When it does not end in + the string ``"monitor.csv"``, this suffix will be appended to it :param header: the header dictionary object of the saved csv - :param reset_keywords: the extra information to log, typically is composed of + :param extra_keys: the extra information to log, typically is composed of ``reset_keywords`` and ``info_keywords`` :param override_existing: appends to file if ``filename`` exists, otherwise override existing files (default) @@ -185,6 +186,9 @@ def __init__( filename = os.path.join(filename, Monitor.EXT) else: filename = filename + "." 
+ Monitor.EXT + filename = os.path.realpath(filename) + # Create (if any) missing filename directories + os.makedirs(os.path.dirname(filename), exist_ok=True) # Append mode when not overridding existing file mode = "w" if override_existing else "a" # Prevent newline issue on Windows, see GH issue #692 diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index d2574edb9..f53e07e9c 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -4,7 +4,7 @@ import time import warnings from copy import deepcopy -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -21,6 +21,8 @@ from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.her.her_replay_buffer import HerReplayBuffer +OffPolicyAlgorithmSelf = TypeVar("OffPolicyAlgorithmSelf", bound="OffPolicyAlgorithm") + class OffPolicyAlgorithm(BaseAlgorithm): """ @@ -319,7 +321,7 @@ def _setup_learn( ) def learn( - self, + self: OffPolicyAlgorithmSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, @@ -329,7 +331,7 @@ def learn( tb_log_name: str = "run", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> "OffPolicyAlgorithm": + ) -> OffPolicyAlgorithmSelf: total_timesteps, callback = self._setup_learn( total_timesteps, @@ -616,7 +618,6 @@ def collect_rollouts( # Log training infos if log_interval is not None and self._episode_num % log_interval == 0: self._dump_logs() - callback.on_rollout_end() return RolloutReturn(num_collected_steps * env.num_envs, num_collected_episodes, continue_training) diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index a9c0a4138..0589fe17b 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -1,6 +1,6 @@ import sys import time -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -14,6 +14,8 @@ from stable_baselines3.common.utils import obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv +OnPolicyAlgorithmSelf = TypeVar("OnPolicyAlgorithmSelf", bound="OnPolicyAlgorithm") + class OnPolicyAlgorithm(BaseAlgorithm): """ @@ -141,7 +143,7 @@ def collect_rollouts( :param callback: Callback that will be called at each step (and at the beginning and end of the rollout) :param rollout_buffer: Buffer to fill with rollouts - :param n_steps: Number of experiences to collect per environment + :param n_rollout_steps: Number of experiences to collect per environment :return: True if function returned with at least `n_rollout_steps` collected, False if callback terminated rollout prematurely. 
""" @@ -225,7 +227,7 @@ def train(self) -> None: raise NotImplementedError def learn( - self, + self: OnPolicyAlgorithmSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, @@ -235,7 +237,7 @@ def learn( tb_log_name: str = "OnPolicyAlgorithm", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> "OnPolicyAlgorithm": + ) -> OnPolicyAlgorithmSelf: iteration = 0 total_timesteps, callback = self._setup_learn( diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index cd372e14b..4632e48fb 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -5,7 +5,7 @@ import warnings from abc import ABC, abstractmethod from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -33,8 +33,10 @@ from stable_baselines3.common.type_aliases import Schedule from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor +BaseModelSelf = TypeVar("BaseModelSelf", bound="BaseModel") -class BaseModel(nn.Module, ABC): + +class BaseModel(nn.Module): """ The base model object: makes predictions in response to observations. @@ -158,7 +160,7 @@ def save(self, path: str) -> None: th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) @classmethod - def load(cls, path: str, device: Union[th.device, str] = "auto") -> "BaseModel": + def load(cls: Type[BaseModelSelf], path: str, device: Union[th.device, str] = "auto") -> BaseModelSelf: """ Load model from path. @@ -251,7 +253,7 @@ def obs_to_tensor(self, observation: Union[np.ndarray, Dict[str, np.ndarray]]) - return observation, vectorized_env -class BasePolicy(BaseModel): +class BasePolicy(BaseModel, ABC): """The base policy object. Parameters are mostly the same as `BaseModel`; additions are documented below. 
diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index f4c29ab27..509169dc7 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -11,7 +11,9 @@ GymEnv = Union[gym.Env, vec_env.VecEnv] GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] +Gym26ResetReturn = Tuple[GymObs, Dict] GymStepReturn = Tuple[GymObs, float, bool, Dict] +Gym26StepReturn = Tuple[GymObs, float, bool, bool, Dict] TensorDict = Dict[Union[str, int], th.Tensor] OptimizerStateDict = Dict[str, Any] MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback] diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 53c642cbd..f3c88f236 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -3,6 +3,7 @@ import platform import random from collections import deque +from inspect import signature from itertools import zip_longest from typing import Dict, Iterable, List, Optional, Tuple, Union @@ -19,7 +20,13 @@ SummaryWriter = None from stable_baselines3.common.logger import Logger, configure -from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit +from stable_baselines3.common.type_aliases import ( + GymEnv, + Schedule, + TensorDict, + TrainFreq, + TrainFrequencyUnit, +) def set_random_seed(seed: int, using_cuda: bool = False) -> None: @@ -519,3 +526,18 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: if print_info: print(env_info_str) return env_info, env_info_str + + +def compat_gym_seed(env: GymEnv, seed: int) -> None: + """ + Compatibility helper to seed Gym envs. + + :param env: The Gym environment. + :param seed: The seed for the pseudo random generator + """ + if isinstance(env, gym.Env) and "seed" in signature(env.unwrapped.reset).parameters: + # gym >= 0.23.1 + env.reset(seed=seed) + else: + # VecEnv and backward compatibility + env.seed(seed) diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py index 98706050c..8c09d3c8c 100644 --- a/stable_baselines3/common/vec_env/base_vec_env.py +++ b/stable_baselines3/common/vec_env/base_vec_env.py @@ -59,6 +59,7 @@ def __init__(self, num_envs: int, observation_space: gym.spaces.Space, action_sp self.num_envs = num_envs self.observation_space = observation_space self.action_space = action_space + self.reset_infos = [{} for _ in range(num_envs)] # store info returns by the reset method @abstractmethod def reset(self) -> VecEnvObs: diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index c0efc8caf..c663558a7 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -39,28 +39,36 @@ def step_async(self, actions: np.ndarray) -> None: self.actions = actions def step_wait(self) -> VecEnvStepReturn: + # Avoid circular imports for env_idx in range(self.num_envs): - obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step( + obs, self.buf_rews[env_idx], done, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step( self.actions[env_idx] ) + # convert to SB3 VecEnv api + self.buf_dones[env_idx] = done or truncated + self.buf_infos[env_idx]["TimeLimit.truncated"] = truncated + if self.buf_dones[env_idx]: # save final observation where user can get it, then reset 
self.buf_infos[env_idx]["terminal_observation"] = obs - obs = self.envs[env_idx].reset() + obs, self.reset_infos[env_idx] = self.envs[env_idx].reset() self._save_obs(env_idx, obs) return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: + # Avoid circular import + from stable_baselines3.common.utils import compat_gym_seed + if seed is None: seed = np.random.randint(0, 2**32 - 1) seeds = [] for idx, env in enumerate(self.envs): - seeds.append(env.seed(seed + idx)) + seeds.append(compat_gym_seed(env, seed=seed + idx)) return seeds def reset(self) -> VecEnvObs: for env_idx in range(self.num_envs): - obs = self.envs[env_idx].reset() + obs, self.reset_infos[env_idx] = self.envs[env_idx].reset() self._save_obs(env_idx, obs) return self._obs_from_buf() diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index 733b72833..88d725e66 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -199,7 +199,7 @@ def stack_observation_space(self, observation_space: spaces.Dict) -> spaces.Dict spaces_dict[key] = spaces.Box(low=low, high=high, dtype=subspace.dtype) return spaces.Dict(spaces=spaces_dict) - def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: # pytype: disable=signature-mismatch """ Resets the stacked observations, adds the reset observation to the stack, and returns the stack @@ -219,7 +219,7 @@ def update( observations: Dict[str, np.ndarray], dones: np.ndarray, infos: List[Dict[str, Any]], - ) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]: + ) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]: # pytype: disable=signature-mismatch """ Adds the observations to the stack and uses the dones to update the infos. 
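The `DummyVecEnv` hunks above are where the Gym 0.26 step/reset API (separate terminated/truncated flags, `reset()` returning an `(obs, info)` tuple) gets folded back into SB3's unchanged four-tuple `VecEnv` interface. A standalone sketch of that per-env conversion, using a hypothetical helper name rather than the actual implementation:

```python
import gym


def step_and_convert(env: gym.Env, action):
    """Hypothetical helper mirroring the conversion done per sub-env in step_wait()."""
    obs, reward, terminated, truncated, info = env.step(action)
    # SB3's VecEnv API keeps a single `done` flag; truncation is reported
    # through the info dict under the pre-existing "TimeLimit.truncated" key.
    info["TimeLimit.truncated"] = truncated
    done = terminated or truncated
    if done:
        # Keep the final observation for the user, then auto-reset
        # (gym >= 0.26 reset() returns an (obs, info) tuple; the vec envs
        # store that info dict in their new `reset_infos` attribute).
        info["terminal_observation"] = obs
        obs, _reset_info = env.reset()
    return obs, reward, done, info
```

Reusing the `TimeLimit.truncated` key means downstream SB3 code, such as the off-policy handling of timeouts, can stay as it is.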
diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index f723c71f7..367a87f13 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -14,29 +14,36 @@ ) -def _worker( - remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper +def _worker( # noqa: C901 + remote: mp.connection.Connection, + parent_remote: mp.connection.Connection, + env_fn_wrapper: CloudpickleWrapper, ) -> None: # Import here to avoid a circular import from stable_baselines3.common.env_util import is_wrapped + from stable_baselines3.common.utils import compat_gym_seed parent_remote.close() env = env_fn_wrapper.var() + reset_info = {} while True: try: cmd, data = remote.recv() if cmd == "step": - observation, reward, done, info = env.step(data) + observation, reward, done, truncated, info = env.step(data) + # convert to SB3 VecEnv api + done = done or truncated + info["TimeLimit.truncated"] = truncated if done: # save final observation where user can get it, then reset info["terminal_observation"] = observation - observation = env.reset() - remote.send((observation, reward, done, info)) + observation, reset_info = env.reset() + remote.send((observation, reward, done, info, reset_info)) elif cmd == "seed": - remote.send(env.seed(data)) + remote.send(compat_gym_seed(env, seed=data)) elif cmd == "reset": - observation = env.reset() - remote.send(observation) + observation, reset_info = env.reset() + remote.send((observation, reset_info)) elif cmd == "render": remote.send(env.render(data)) elif cmd == "close": @@ -119,7 +126,7 @@ def step_async(self, actions: np.ndarray) -> None: def step_wait(self) -> VecEnvStepReturn: results = [remote.recv() for remote in self.remotes] self.waiting = False - obs, rews, dones, infos = zip(*results) + obs, rews, dones, infos, self.reset_infos = zip(*results) return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: @@ -132,7 +139,8 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: def reset(self) -> VecEnvObs: for remote in self.remotes: remote.send(("reset", None)) - obs = [remote.recv() for remote in self.remotes] + results = [remote.recv() for remote in self.remotes] + obs, self.reset_infos = zip(*results) return _flatten_obs(obs, self.observation_space) def close(self) -> None: diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index e06d5125e..5fdb866f8 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -55,8 +55,7 @@ def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: """ Reset all environments """ - observation = self.venv.reset() # pytype:disable=annotation-type-mismatch - + observation = self.venv.reset() observation = self.stackedobs.reset(observation) return observation diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index a9244e737..531acd1ad 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -1,14 +1,15 @@ -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union import torch as th from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise 
-from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.td3.policies import TD3Policy from stable_baselines3.td3.td3 import TD3 +DDPGSelf = TypeVar("DDPGSelf", bound="DDPG") + class DDPG(TD3): """ @@ -116,7 +117,7 @@ def __init__( self._setup_model() def learn( - self, + self: DDPGSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, @@ -126,7 +127,7 @@ def learn( tb_log_name: str = "DDPG", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> OffPolicyAlgorithm: + ) -> DDPGSelf: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 80e024b3a..ea7d9f316 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -14,6 +14,8 @@ from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update from stable_baselines3.dqn.policies import CnnPolicy, DQNPolicy, MlpPolicy, MultiInputPolicy +DQNSelf = TypeVar("DQNSelf", bound="DQN") + class DQN(OffPolicyAlgorithm): """ @@ -255,7 +257,7 @@ def predict( return action, state def learn( - self, + self: DQNSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, @@ -265,7 +267,7 @@ def learn( tb_log_name: str = "DQN", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> OffPolicyAlgorithm: + ) -> DQNSelf: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/ppo/ppo.py b/stable_baselines3/ppo/ppo.py index 6bb9c23d0..cfcdfb12b 100644 --- a/stable_baselines3/ppo/ppo.py +++ b/stable_baselines3/ppo/ppo.py @@ -1,5 +1,5 @@ import warnings -from typing import Any, Dict, Optional, Type, Union +from typing import Any, Dict, Optional, Type, TypeVar, Union import numpy as np import torch as th @@ -11,6 +11,8 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance, get_schedule_fn +PPOSelf = TypeVar("PPOSelf", bound="PPO") + class PPO(OnPolicyAlgorithm): """ @@ -232,7 +234,7 @@ def train(self) -> None: # No clipping values_pred = values else: - # Clip the different between old and new value + # Clip the difference between old and new value # NOTE: this depends on the reward scaling values_pred = rollout_data.old_values + th.clamp( values - rollout_data.old_values, -clip_range_vf, clip_range_vf @@ -297,7 +299,7 @@ def train(self) -> None: self.logger.record("train/clip_range_vf", clip_range_vf) def learn( - self, + self: PPOSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 1, @@ -307,7 +309,7 @@ def learn( tb_log_name: str = "PPO", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> "PPO": + ) -> PPOSelf: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index de08b756a..8505e88b7 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy 
as np @@ -13,6 +13,8 @@ from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, SACPolicy +SACSelf = TypeVar("SACSelf", bound="SAC") + class SAC(OffPolicyAlgorithm): """ @@ -289,7 +291,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None: self.logger.record("train/ent_coef_loss", np.mean(ent_coef_losses)) def learn( - self, + self: SACSelf, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, @@ -299,7 +301,7 @@ def learn( tb_log_name: str = "SAC", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> OffPolicyAlgorithm: + ) -> SACSelf: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 51df755cb..ae7895dc8 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union import gym import numpy as np @@ -13,6 +13,8 @@ from stable_baselines3.common.utils import get_parameters_by_name, polyak_update from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy, TD3Policy +TD3Self = TypeVar("TD3Self", bound="TD3") + class TD3(OffPolicyAlgorithm): """ @@ -152,7 +154,6 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: self._update_learning_rate([self.actor.optimizer, self.critic.optimizer]) actor_losses, critic_losses = [], [] - for _ in range(gradient_steps): self._n_updates += 1 @@ -205,7 +206,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None: self.logger.record("train/critic_loss", np.mean(critic_losses)) def learn( - self, + self: TD3Self, total_timesteps: int, callback: MaybeCallback = None, log_interval: int = 4, @@ -215,7 +216,7 @@ def learn( tb_log_name: str = "TD3", eval_log_path: Optional[str] = None, reset_num_timesteps: bool = True, - ) -> OffPolicyAlgorithm: + ) -> TD3Self: return super().learn( total_timesteps=total_timesteps, diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 9c2a9af99..9c6d6293b 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.6.1a4 +1.6.1 diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 0e028e670..4bd2d2793 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -27,15 +27,15 @@ def __init__(self): def reset(self): self._t = 0 obs = self._observations[0] - return obs + return obs, {} def step(self, action): self._t += 1 index = self._t % len(self._observations) obs = self._observations[index] - done = self._t >= self._ep_length + done = truncated = self._t >= self._ep_length reward = self._rewards[index] - return obs, reward, done, {} + return obs, reward, done, truncated, {} class DummyDictEnv(gym.Env): @@ -55,15 +55,15 @@ def __init__(self): def reset(self): self._t = 0 obs = {key: self._observations[0] for key in self.observation_space.spaces.keys()} - return obs + return obs, {} def step(self, action): self._t += 1 index = self._t % len(self._observations) obs = {key: self._observations[index] for key in self.observation_space.spaces.keys()} - done = self._t >= self._ep_length + done = truncated = self._t >= self._ep_length reward = self._rewards[index] - return obs, reward, done, {} + return obs, reward, done, truncated, {} @pytest.mark.parametrize("replay_buffer_cls", 
[ReplayBuffer, DictReplayBuffer]) diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 2c7e0ba32..275992db9 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -102,7 +102,7 @@ def test_callbacks(tmp_path, model_class): def select_env(model_class) -> str: if model_class is DQN: - return "CartPole-v0" + return "CartPole-v1" else: return "Pendulum-v1" diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 03f089db9..48dce0821 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -35,7 +35,7 @@ def test_cnn(tmp_path, model_class): # FakeImageEnv is channel last by default and should be wrapped assert is_vecenv_wrapped(model.get_env(), VecTransposeImage) - obs = env.reset() + obs, _ = env.reset() # Test stochastic predict with channel last input if model_class == DQN: @@ -238,7 +238,7 @@ def test_channel_first_env(tmp_path): assert not is_vecenv_wrapped(model.get_env(), VecTransposeImage) - obs = env.reset() + obs, _ = env.reset() action, _ = model.predict(obs, deterministic=True) diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 93b13b40e..1f832bfeb 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -1,3 +1,5 @@ +from typing import Optional + import gym import numpy as np import pytest @@ -66,14 +68,16 @@ def seed(self, seed=None): def step(self, action): reward = 0.0 - done = False - return self.observation_space.sample(), reward, done, {} + done = truncated = False + return self.observation_space.sample(), reward, done, truncated, {} def compute_reward(self, achieved_goal, desired_goal, info): return np.zeros((len(achieved_goal),)) - def reset(self): - return self.observation_space.sample() + def reset(self, seed: Optional[int] = None): + if seed is not None: + self.observation_space.seed(seed) + return self.observation_space.sample(), {} def render(self, mode="human"): pass @@ -105,7 +109,7 @@ def test_consistency(model_class): dict_env = gym.wrappers.TimeLimit(dict_env, 100) env = gym.wrappers.FlattenObservation(dict_env) dict_env.seed(10) - obs = dict_env.reset() + obs, _ = dict_env.reset() kwargs = {} n_steps = 256 diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index 0b0a82d8f..2313defc9 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -14,11 +14,12 @@ def step(self, action): observation = np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype) reward = 1 done = True + truncated = False info = {} - return observation, reward, done, info + return observation, reward, done, truncated, info def reset(self): - return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype) + return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype), {} def render(self, mode="human"): pass diff --git a/tests/test_envs.py b/tests/test_envs.py index 8b8cb8cb8..5c00b9473 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -28,7 +28,7 @@ ] -@pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v1"]) +@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v1"]) def test_env(env_id): """ Check that environmnent integrated in Gym pass the test. 
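A quick usage note on the `compat_gym_seed` helper added in `common/utils.py` above and used by both `DummyVecEnv.seed()` and the `SubprocVecEnv` worker: it chooses between the new `reset(seed=...)` path and the legacy `seed()` method by inspecting the env's reset signature. A small sketch, assuming the Gym 0.26 setup targeted by this changeset:

```python
import gym

from stable_baselines3.common.utils import compat_gym_seed
from stable_baselines3.common.vec_env import DummyVecEnv

# Gym >= 0.23.1 env: reset() accepts a seed, so seeding goes through reset(seed=...)
env = gym.make("CartPole-v1")
compat_gym_seed(env, seed=0)

# VecEnv (or an old-style env whose reset() has no seed parameter):
# falls back to the seed() method instead.
vec_env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
compat_gym_seed(vec_env, seed=0)
```

This is also why the test envs below only need to expose either style of seeding, not both.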
@@ -75,6 +75,17 @@ def test_bit_flipping(kwargs): # No warnings for custom envs assert len(record) == 0 + # Remove a key, must throw an error + obs_space = env.observation_space.spaces["observation"] + del env.observation_space.spaces["observation"] + with pytest.raises(AssertionError): + check_env(env) + + # Rename a key, must throw an error + env.observation_space.spaces["obs"] = obs_space + with pytest.raises(AssertionError): + check_env(env) + def test_high_dimension_action_space(): """ @@ -87,7 +98,7 @@ def test_high_dimension_action_space(): # Patch to avoid error def patched_step(_action): - return env.observation_space.sample(), 0.0, False, {} + return env.observation_space.sample(), 0.0, False, False, {} env.step = patched_step check_env(env) @@ -116,10 +127,10 @@ def test_non_default_spaces(new_obs_space): env = FakeImageEnv() env.observation_space = new_obs_space # Patch methods to avoid errors - env.reset = new_obs_space.sample + env.reset = lambda: (new_obs_space.sample(), {}) def patched_step(_action): - return new_obs_space.sample(), 0.0, False, {} + return new_obs_space.sample(), 0.0, False, False, {} env.step = patched_step with pytest.warns(UserWarning): @@ -155,14 +166,20 @@ def test_non_default_action_spaces(new_action_space): # No warnings for custom envs assert len(record) == 0 + # Change the action space env.action_space = new_action_space + low, high = new_action_space.low[0], new_action_space.high[0] # Unbounded action space throws an error, # the rest only warning if not np.all(np.isfinite(env.action_space.low)): with pytest.raises(AssertionError), pytest.warns(UserWarning): check_env(env) + # numpy >= 1.21 raises a ValueError + elif int(np.__version__.split(".")[1]) >= 21 and (low > high): + with pytest.raises(ValueError), pytest.warns(UserWarning): + check_env(env) else: with pytest.warns(UserWarning): check_env(env) @@ -176,7 +193,7 @@ def check_reset_assert_error(env, new_reset_return): """ def wrong_reset(): - return new_reset_return + return new_reset_return, {} # Patch the reset method with a wrong one env.reset = wrong_reset @@ -194,6 +211,11 @@ def test_common_failures_reset(): # The observation is not a numpy array check_reset_assert_error(env, 1) + # Return only obs (gym < 0.26) + env.reset = env.observation_space.sample + with pytest.raises(AssertionError): + check_env(env) + # Return not only the observation check_reset_assert_error(env, (env.observation_space.sample(), False)) @@ -206,10 +228,10 @@ def test_common_failures_reset(): wrong_obs = {**env.observation_space.sample(), "extra_key": None} check_reset_assert_error(env, wrong_obs) - obs = env.reset() + obs, _ = env.reset() def wrong_reset(self): - return {"img": obs["img"], "vec": obs["img"]} + return {"img": obs["img"], "vec": obs["img"]}, {} env.reset = types.MethodType(wrong_reset, env) with pytest.raises(AssertionError) as excinfo: @@ -242,33 +264,38 @@ def test_common_failures_step(): env = IdentityEnvBox() # Wrong shape for the observation - check_step_assert_error(env, (np.ones((4,)), 1.0, False, {})) + check_step_assert_error(env, (np.ones((4,)), 1.0, False, False, {})) # Obs is not a numpy array - check_step_assert_error(env, (1, 1.0, False, {})) + check_step_assert_error(env, (1, 1.0, False, False, {})) # Return a wrong reward - check_step_assert_error(env, (env.observation_space.sample(), np.ones(1), False, {})) + check_step_assert_error(env, (env.observation_space.sample(), np.ones(1), False, False, {})) # Info dict is not returned - check_step_assert_error(env, 
(env.observation_space.sample(), 0.0, False)) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, False)) + + # Truncated is not returned (gym < 0.26) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, {})) # Done is not a boolean - check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, {})) - check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, {})) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, False, {})) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, False, {})) + # Truncated is not a boolean + check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, 1.0, {})) env = SimpleMultiObsEnv() # Observation keys and observation space keys must match wrong_obs = env.observation_space.sample() wrong_obs.pop("img") - check_step_assert_error(env, (wrong_obs, 0.0, False, {})) + check_step_assert_error(env, (wrong_obs, 0.0, False, False, {})) wrong_obs = {**env.observation_space.sample(), "extra_key": None} - check_step_assert_error(env, (wrong_obs, 0.0, False, {})) + check_step_assert_error(env, (wrong_obs, 0.0, False, False, {})) - obs = env.reset() + obs, _ = env.reset() def wrong_step(self, action): - return {"img": obs["vec"], "vec": obs["vec"]}, 0.0, False, {} + return {"img": obs["vec"], "vec": obs["vec"]}, 0.0, False, False, {} env.step = types.MethodType(wrong_step, env) with pytest.raises(AssertionError) as excinfo: diff --git a/tests/test_gae.py b/tests/test_gae.py index 8e461ed7a..35f01ac77 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -1,3 +1,5 @@ +from typing import Optional + import gym import numpy as np import pytest @@ -19,20 +21,26 @@ def __init__(self, max_steps=8): def seed(self, seed): self.observation_space.seed(seed) - def reset(self): + def reset(self, seed: Optional[int] = None): + if seed is not None: + self.observation_space.seed(seed) self.n_steps = 0 - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): self.n_steps += 1 - done = False + done = truncated = False reward = 0.0 if self.n_steps >= self.max_steps: reward = 1.0 done = True + # To simplify GAE computation checks, + # we do not consider truncation here. 
+ # Truncations are checked in InfiniteHorizonEnv + truncated = False - return self.observation_space.sample(), reward, done, {} + return self.observation_space.sample(), reward, done, truncated, {} class InfiniteHorizonEnv(gym.Env): @@ -43,13 +51,16 @@ def __init__(self, n_states=4): self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.current_state = 0 - def reset(self): + def reset(self, seed: Optional[int] = None): + if seed is not None: + super().reset(seed=seed) + self.current_state = 0 - return self.current_state + return self.current_state, {} def step(self, action): self.current_state = (self.current_state + 1) % self.n_states - return self.current_state, 1.0, False, {} + return self.current_state, 1.0, False, False, {} class CheckGAECallback(BaseCallback): diff --git a/tests/test_her.py b/tests/test_her.py index 888d36a6e..c1bc515ed 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -143,7 +143,7 @@ def test_save_load(tmp_path, model_class, use_sde, online_sampling): model.learn(total_timesteps=150) - obs = env.reset() + obs, _ = env.reset() observations = {key: [] for key in obs.keys()} for _ in range(10): @@ -237,7 +237,7 @@ def test_save_load_replay_buffer(tmp_path, recwarn, online_sampling, truncate_la train_freq=4, buffer_size=int(2e4), policy_kwargs=dict(net_arch=[64]), - seed=1, + seed=0, ) model.learn(200) if online_sampling: diff --git a/tests/test_identity.py b/tests/test_identity.py index f5bbc4946..66443b1b3 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -15,21 +15,17 @@ def test_discrete(model_class, env): env_ = DummyVecEnv([lambda: env]) kwargs = {} - n_steps = 3000 + n_steps = 2500 if model_class == DQN: kwargs = dict(learning_starts=0) - n_steps = 4000 # DQN only support discrete actions if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)): return - elif model_class == A2C: - # slightly higher budget - n_steps = 3500 - model = model_class("MlpPolicy", env_, gamma=0.4, seed=1, **kwargs).learn(n_steps) + model = model_class("MlpPolicy", env_, gamma=0.4, seed=3, **kwargs).learn(n_steps) evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False) - obs = env.reset() + obs, _ = env.reset() assert np.shape(model.predict(obs)[0]) == np.shape(obs) @@ -38,9 +34,10 @@ def test_discrete(model_class, env): def test_continuous(model_class): env = IdentityEnvBox(eps=0.5) - n_steps = {A2C: 3500, PPO: 3000, SAC: 700, TD3: 500, DDPG: 500}[model_class] + n_steps = {A2C: 2000, PPO: 2500, SAC: 700, TD3: 500, DDPG: 500}[model_class] kwargs = dict(policy_kwargs=dict(net_arch=[64, 64]), seed=0, gamma=0.95) + if model_class in [TD3]: n_actions = 1 action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) diff --git a/tests/test_logger.py b/tests/test_logger.py index 516a622df..9bf05213a 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -1,4 +1,5 @@ import os +import sys import time from typing import Sequence from unittest import mock @@ -353,12 +354,12 @@ def __init__(self, delay: float = 0.01): self.action_space = gym.spaces.Discrete(2) def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): time.sleep(self.delay) obs = self.observation_space.sample() - return obs, 0.0, True, {} + return obs, 0.0, True, False, {} class InMemoryLogger(Logger): @@ -409,3 +410,11 @@ def test_fps_no_div_zero(algo): with mock.patch("time.time_ns", lambda: 42.0): model = algo("MlpPolicy", 
"CartPole-v1") model.learn(total_timesteps=100) + + +def test_human_output_format_no_crash_on_same_keys_different_tags(): + o = HumanOutputFormat(sys.stdout, max_length=60) + o.write( + {"key1/foo": "value1", "key1/bar": "value2", "key2/bizz": "value3", "key2/foo": "value4"}, + {"key1/foo": None, "key2/bizz": None, "key1/bar": None, "key2/foo": None}, + ) diff --git a/tests/test_monitor.py b/tests/test_monitor.py index e5cb7f9bb..c580fcf49 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -13,7 +13,7 @@ def test_monitor(tmp_path): Test the monitor wrapper """ env = gym.make("CartPole-v1") - env.seed(0) + env.reset(seed=0) monitor_file = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env = Monitor(env, monitor_file) monitor_env.reset() @@ -22,10 +22,10 @@ def test_monitor(tmp_path): ep_lengths = [] ep_len, ep_reward = 0, 0 for _ in range(total_steps): - _, reward, done, _ = monitor_env.step(monitor_env.action_space.sample()) + _, reward, done, truncated, _ = monitor_env.step(monitor_env.action_space.sample()) ep_len += 1 ep_reward += reward - if done: + if done or truncated: ep_rewards.append(ep_reward) ep_lengths.append(ep_len) monitor_env.reset() @@ -48,6 +48,15 @@ def test_monitor(tmp_path): assert set(last_logline.keys()) == {"l", "t", "r"}, "Incorrect keys in monitor logline" os.remove(monitor_file) + # Check missing filename directories are created + monitor_dir = os.path.join(str(tmp_path), "missing-dir") + monitor_file = os.path.join(monitor_dir, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") + assert os.path.exists(monitor_dir) is False + _ = Monitor(env, monitor_file) + assert os.path.exists(monitor_dir) is True + os.remove(monitor_file) + os.rmdir(monitor_dir) + def test_monitor_load_results(tmp_path): """ @@ -55,7 +64,7 @@ def test_monitor_load_results(tmp_path): """ tmp_path = str(tmp_path) env1 = gym.make("CartPole-v1") - env1.seed(0) + env1.reset(seed=0) monitor_file1 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env1 = Monitor(env1, monitor_file1) @@ -66,8 +75,8 @@ def test_monitor_load_results(tmp_path): monitor_env1.reset() episode_count1 = 0 for _ in range(1000): - _, _, done, _ = monitor_env1.step(monitor_env1.action_space.sample()) - if done: + _, _, done, truncated, _ = monitor_env1.step(monitor_env1.action_space.sample()) + if done or truncated: episode_count1 += 1 monitor_env1.reset() @@ -75,7 +84,7 @@ def test_monitor_load_results(tmp_path): assert results_size1 == episode_count1 env2 = gym.make("CartPole-v1") - env2.seed(0) + env2.reset(seed=0) monitor_file2 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env2 = Monitor(env2, monitor_file2) monitor_files = get_monitor_files(tmp_path) @@ -89,8 +98,8 @@ def test_monitor_load_results(tmp_path): monitor_env2 = Monitor(env2, monitor_file2, override_existing=False) monitor_env2.reset() for _ in range(1000): - _, _, done, _ = monitor_env2.step(monitor_env2.action_space.sample()) - if done: + _, _, done, truncated, _ = monitor_env2.step(monitor_env2.action_space.sample()) + if done or truncated: episode_count2 += 1 monitor_env2.reset() diff --git a/tests/test_predict.py b/tests/test_predict.py index 89cdb0998..0a6855a02 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -29,10 +29,10 @@ def __init__(self): self.action_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32) def reset(self): - return self.observation_space.sample() + return 
self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, {} + return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, False, {} @pytest.mark.parametrize("model_class", MODEL_LIST) @@ -41,7 +41,7 @@ def test_auto_wrap(model_class): # Use different environment for DQN if model_class is DQN: - env_name = "CartPole-v0" + env_name = "CartPole-v1" else: env_name = "Pendulum-v1" env = gym.make(env_name) @@ -71,7 +71,7 @@ def test_predict(model_class, env_id, device): env = gym.make(env_id) vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)]) - obs = env.reset() + obs, _ = env.reset() action, _ = model.predict(obs) assert isinstance(action, np.ndarray) assert action.shape == env.action_space.shape @@ -97,7 +97,7 @@ def test_dqn_epsilon_greedy(): env = IdentityEnv(2) model = DQN("MlpPolicy", env) model.exploration_rate = 1.0 - obs = env.reset() + obs, _ = env.reset() # is vectorized should not crash with discrete obs action, _ = model.predict(obs, deterministic=False) assert env.action_space.contains(action) @@ -108,5 +108,5 @@ def test_subclassed_space_env(model_class): env = CustomSubClassedSpaceEnv() model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[32])) model.learn(300) - obs = env.reset() + obs, _ = env.reset() env.step(model.predict(obs)) diff --git a/tests/test_run.py b/tests/test_run.py index 655182da1..9dec724d7 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -26,6 +26,7 @@ def test_deterministic_pg(model_class, action_noise): verbose=1, create_eval_env=True, buffer_size=250, + gradient_steps=1, action_noise=action_noise, ) model.learn(total_timesteps=300, eval_freq=250) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index d7a74c5ee..f16d83a45 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -27,7 +27,7 @@ def select_env(model_class: BaseAlgorithm) -> gym.Env: if model_class == DQN: return IdentityEnv(10) else: - return IdentityEnvBox(10) + return IdentityEnvBox(-10, 10) @pytest.mark.parametrize("model_class", MODEL_LIST) @@ -174,6 +174,7 @@ def test_set_env(tmp_path, model_class): env = DummyVecEnv([lambda: select_env(model_class)]) env2 = DummyVecEnv([lambda: select_env(model_class)]) env3 = select_env(model_class) + env4 = DummyVecEnv([lambda: select_env(model_class) for _ in range(2)]) kwargs = {} if model_class in {DQN, DDPG, SAC, TD3}: @@ -199,6 +200,10 @@ def test_set_env(tmp_path, model_class): # learn again model.learn(total_timesteps=64) + # num_env must be the same + with pytest.raises(AssertionError): + model.set_env(env4) + # Keep the same env, disable reset model.set_env(model.get_env(), force_reset=False) assert model._last_obs is not None @@ -223,6 +228,11 @@ def test_set_env(tmp_path, model_class): model.learn(total_timesteps=64, reset_num_timesteps=False) assert model.num_timesteps == 3 * 64 + del model + # Load the model with a different number of environments + model = model_class.load(tmp_path / "test_save.zip", env=env4) + model.learn(total_timesteps=64) + # Clear saved file os.remove(tmp_path / "test_save.zip") diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 0696492d6..2f9c79606 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -1,3 +1,5 @@ +from typing import Optional + import gym import numpy as np import pytest @@ -13,11 +15,13 @@ def __init__(self, nvec): self.observation_space = gym.spaces.MultiDiscrete(nvec) self.action_space = gym.spaces.Box(low=-1, high=1, 
shape=(2,), dtype=np.float32) - def reset(self): - return self.observation_space.sample() + def reset(self, seed: Optional[int] = None): + if seed is not None: + super().reset(seed=seed) + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, False, {} + return self.observation_space.sample(), 0.0, False, False, {} class DummyMultiBinary(gym.Env): @@ -26,11 +30,13 @@ def __init__(self, n): self.observation_space = gym.spaces.MultiBinary(n) self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - def reset(self): - return self.observation_space.sample() + def reset(self, seed: Optional[int] = None): + if seed is not None: + super().reset(seed=seed) + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, False, {} + return self.observation_space.sample(), 0.0, False, False, {} class DummyMultidimensionalAction(gym.Env): @@ -40,10 +46,10 @@ def __init__(self): self.action_space = gym.spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32) def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, False, {} + return self.observation_space.sample(), 0.0, False, False, {} @pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 2a9eade81..35d5ae1e1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -191,17 +191,18 @@ def __init__(self, env): self.needs_reset = True def step(self, action): - obs, reward, done, info = self.env.step(action) - self.needs_reset = done + obs, reward, done, truncated, info = self.env.step(action) + self.needs_reset = done or truncated self.last_obs = obs - return obs, reward, True, info + return obs, reward, True, truncated, info def reset(self, **kwargs): + info = {} if self.needs_reset: - obs = self.env.reset(**kwargs) + obs, info = self.env.reset(**kwargs) self.last_obs = obs self.needs_reset = False - return self.last_obs + return self.last_obs, info @pytest.mark.parametrize("n_envs", [1, 2, 5, 7]) @@ -235,7 +236,7 @@ def test_evaluate_policy_monitors(vec_env_class): # Also test VecEnvs n_eval_episodes = 3 n_envs = 2 - env_id = "CartPole-v0" + env_id = "CartPole-v1" model = A2C("MlpPolicy", env_id, seed=0) def make_eval_env(with_monitor, wrapper_class=gym.Wrapper): diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py index 962355782..f09a7aae4 100644 --- a/tests/test_vec_check_nan.py +++ b/tests/test_vec_check_nan.py @@ -24,11 +24,11 @@ def step(action): obs = float("inf") else: obs = 0 - return [obs], 0.0, False, {} + return [obs], 0.0, False, False, {} @staticmethod def reset(): - return [0.0] + return [0.0], {} def render(self, mode="human", close=False): pass diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index 93ea348b1..be357655e 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -2,6 +2,7 @@ import functools import itertools import multiprocessing +from typing import Optional import gym import numpy as np @@ -25,17 +26,19 @@ def __init__(self, space): self.current_step = 0 self.ep_length = 4 - def reset(self): + def reset(self, seed: Optional[int] = None): + if seed is not None: + self.seed(seed) self.current_step = 0 self._choose_next_state() - return self.state + return self.state, {} def step(self, action): reward = float(np.random.rand()) self._choose_next_state() 
self.current_step += 1 - done = self.current_step >= self.ep_length - return self.state, reward, done, {} + done = truncated = self.current_step >= self.ep_length + return self.state, reward, done, truncated, {} def _choose_next_state(self): self.state = self.observation_space.sample() @@ -144,13 +147,13 @@ def __init__(self, max_steps): def reset(self): self.current_step = 0 - return np.array([self.current_step], dtype="int") + return np.array([self.current_step], dtype="int"), {} def step(self, action): prev_step = self.current_step self.current_step += 1 - done = self.current_step >= self.max_steps - return np.array([prev_step], dtype="int"), 0.0, done, {} + done = truncated = self.current_step >= self.max_steps + return np.array([prev_step], dtype="int"), 0.0, done, truncated, {} @pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) @@ -444,6 +447,23 @@ def make_monitored_env(): assert vec_env.env_is_wrapped(Monitor) == [False, True] +@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) +def test_backward_compat_seed(vec_env_class): + def make_env(): + env = CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + # Patch reset function to remove seed param + env.reset = lambda: (env.observation_space.sample(), {}) + env.seed = env.observation_space.seed + return env + + vec_env = vec_env_class([make_env for _ in range(N_ENVS)]) + vec_env.seed(3) + obs = vec_env.reset() + vec_env.seed(3) + new_obs = vec_env.reset() + assert np.allclose(new_obs, obs) + + @pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) def test_vec_seeding(vec_env_class): def make_env(): diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index 5ccc33e12..bbf5e8d21 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -2,6 +2,7 @@ import json import os import uuid +import warnings import gym import pandas @@ -132,15 +133,18 @@ def test_vec_monitor_ppo(recwarn): """ Test the `VecMonitor` with PPO """ - env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) - env.seed(0) + # Remove Gym Warnings + warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="gym") + + env = DummyVecEnv([lambda: gym.make("CartPole-v1", disable_env_checker=True)]) + env.seed(seed=0) monitor_env = VecMonitor(env) model = PPO("MlpPolicy", monitor_env, verbose=1, n_steps=64, device="cpu") model.learn(total_timesteps=250) # No warnings because using `VecMonitor` evaluate_policy(model, monitor_env) - assert len(recwarn) == 0 + assert len(recwarn) == 0, f"{[str(warning) for warning in recwarn]}" def test_vec_monitor_warn(): diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index a363e402d..a2c35f49b 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -1,5 +1,6 @@ import operator import warnings +from typing import Optional import gym import numpy as np @@ -34,14 +35,17 @@ def step(self, action): self.t += 1 index = (self.t + self.return_reward_idx) % len(self.returned_rewards) returned_value = self.returned_rewards[index] - return np.array([returned_value]), returned_value, self.t == len(self.returned_rewards), {} + done = truncated = self.t == len(self.returned_rewards) + return np.array([returned_value]), returned_value, done, truncated, {} - def reset(self): + def reset(self, seed: Optional[int] = None): + if seed is not None: + super().reset(seed=seed) self.t = 0 - return np.array([self.returned_rewards[self.return_reward_idx]]) + return np.array([self.returned_rewards[self.return_reward_idx]]), {} -class 
DummyDictEnv(gym.GoalEnv): +class DummyDictEnv(gym.Env): """ Dummy gym goal env for testing purposes """ @@ -57,14 +61,16 @@ def __init__(self): ) self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32) - def reset(self): - return self.observation_space.sample() + def reset(self, seed: Optional[int] = None): + if seed is not None: + super().reset(seed=seed) + return self.observation_space.sample(), {} def step(self, action): obs = self.observation_space.sample() reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], {}) done = np.random.rand() > 0.8 - return obs, reward, done, {} + return obs, reward, done, False, {} def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> np.float32: distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1) @@ -87,13 +93,15 @@ def __init__(self): ) self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32) - def reset(self): - return self.observation_space.sample() + def reset(self, seed: Optional[int] = None): + if seed is not None: + super().reset(seed=seed) + return self.observation_space.sample(), {} def step(self, action): obs = self.observation_space.sample() done = np.random.rand() > 0.8 - return obs, 0.0, done, {} + return obs, 0.0, done, False, {} def allclose(obs_1, obs_2):