From 3d9a975af220a5362d6ee8bf297869ddaaaf325e Mon Sep 17 00:00:00 2001
From: Joe Ksiazek
Date: Wed, 2 Oct 2024 04:22:57 -0400
Subject: [PATCH] Fix QRDQN loading `target_update_interval` (#259)

* Fix QRDQN loading target_update_interval

* Update changelog

* Update version

---------

Co-authored-by: Antonin RAFFIN
---
 CONTRIBUTING.md            |  4 ++--
 docs/misc/changelog.rst    |  3 ++-
 sb3_contrib/qrdqn/qrdqn.py |  9 ++++-----
 sb3_contrib/version.txt    |  2 +-
 tests/test_save_load.py    | 12 ++++++++++++
 5 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 9013df7e..8edbb305 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -152,13 +152,13 @@ To run tests with `pytest`:
 make pytest
 ```
 
-Type checking with `pytype` and `mypy`:
+Type checking with `mypy`:
 
 ```
 make type
 ```
 
-Codestyle check with `black`, `isort` and `flake8`:
+Codestyle check with `black` and `ruff`:
 
 ```
 make check-codestyle
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index d63060c7..0aec7598 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -4,7 +4,7 @@
 Changelog
 ==========
 
-Release 2.4.0a8 (WIP)
+Release 2.4.0a9 (WIP)
 --------------------------
 
 Breaking Changes:
@@ -19,6 +19,7 @@ Bug Fixes:
 - Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
 - Updated QR-DQN paper link in docs (@corentinlger)
 - Fixed a warning with PyTorch 2.4 when loading a `RecurrentPPO` model (You are using torch.load with weights_only=False)
+- Fixed `target_update_interval` being changed when loading a `QRDQN` model (@jak3122)
 
 Deprecations:
 ^^^^^^^^^^^^^
diff --git a/sb3_contrib/qrdqn/qrdqn.py b/sb3_contrib/qrdqn/qrdqn.py
index 5271ac8c..129cef79 100644
--- a/sb3_contrib/qrdqn/qrdqn.py
+++ b/sb3_contrib/qrdqn/qrdqn.py
@@ -153,8 +153,7 @@ def _setup_model(self) -> None:
         self.exploration_schedule = get_linear_fn(
             self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
         )
-        # Account for multiple environments
-        # each call to step() corresponds to n_envs transitions
+
         if self.n_envs > 1:
             if self.n_envs > self.target_update_interval:
                 warnings.warn(
@@ -164,8 +163,6 @@ def _setup_model(self) -> None:
                     f"which corresponds to {self.n_envs} steps."
                 )
 
-        self.target_update_interval = max(self.target_update_interval // self.n_envs, 1)
-
     def _create_aliases(self) -> None:
         self.quantile_net = self.policy.quantile_net
         self.quantile_net_target = self.policy.quantile_net_target
@@ -177,7 +174,9 @@ def _on_step(self) -> None:
 
         This method is called in ``collect_rollouts()`` after each step in the environment.
""" self._n_calls += 1 - if self._n_calls % self.target_update_interval == 0: + # Account for multiple environments + # each call to step() corresponds to n_envs transitions + if self._n_calls % max(self.target_update_interval // self.n_envs, 1) == 0: polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau) # Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996 polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0) diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index ee717ba1..636c433a 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -2.4.0a8 +2.4.0a9 diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 502d2394..3b8f038c 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -8,6 +8,7 @@ import pytest import torch as th from stable_baselines3.common.base_class import BaseAlgorithm +from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv @@ -481,3 +482,14 @@ def test_save_load_pytorch_var(tmp_path): assert model.log_ent_coef is None # Check that the entropy coefficient is still the same assert th.allclose(ent_coef_before, ent_coef_after) + + +def test_dqn_target_update_interval(tmp_path): + # `target_update_interval` should not change when reloading the model. See GH Issue #258. + env = make_vec_env(env_id="CartPole-v1", n_envs=2) + model = QRDQN("MlpPolicy", env, verbose=1, target_update_interval=100) + model.save(tmp_path / "dqn_cartpole") + del model + model = QRDQN.load(tmp_path / "dqn_cartpole") + os.remove(tmp_path / "dqn_cartpole.zip") + assert model.target_update_interval == 100