Skip to content

Commit 42595a5

Browse files
authored
Fix warning when loading a RecurrentPPO model (#255)
* Reformat configs
* Fix warning when loading RecurrentPPO agent
1 parent 5c81398 commit 42595a5

File tree

6 files changed

+54
-50
lines changed

6 files changed

+54
-50
lines changed

.github/workflows/ci.yml

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ name: CI
55

66
on:
77
push:
8-
branches: [ master ]
8+
branches: [master]
99
pull_request:
10-
branches: [ master ]
10+
branches: [master]
1111

1212
jobs:
1313
build:
@@ -22,42 +22,42 @@ jobs:
2222
python-version: ["3.8", "3.9", "3.10", "3.11"]
2323

2424
steps:
25-
- uses: actions/checkout@v3
26-
- name: Set up Python ${{ matrix.python-version }}
27-
uses: actions/setup-python@v4
28-
with:
29-
python-version: ${{ matrix.python-version }}
30-
- name: Install dependencies
31-
run: |
32-
python -m pip install --upgrade pip
33-
# cpu version of pytorch
34-
pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cpu
25+
- uses: actions/checkout@v3
26+
- name: Set up Python ${{ matrix.python-version }}
27+
uses: actions/setup-python@v4
28+
with:
29+
python-version: ${{ matrix.python-version }}
30+
- name: Install dependencies
31+
run: |
32+
python -m pip install --upgrade pip
33+
# cpu version of pytorch
34+
pip install torch==2.1.1 --index-url https://download.pytorch.org/whl/cpu
3535
36-
# Install Atari Roms
37-
pip install autorom
38-
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
39-
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
40-
AutoROM --accept-license --source-file Roms.tar.gz
36+
# Install Atari Roms
37+
pip install autorom
38+
wget https://gist.githubusercontent.com/jjshoots/61b22aefce4456920ba99f2c36906eda/raw/00046ac3403768bfe45857610a3d333b8e35e026/Roms.tar.gz.b64
39+
base64 Roms.tar.gz.b64 --decode &> Roms.tar.gz
40+
AutoROM --accept-license --source-file Roms.tar.gz
4141
42-
# Install master version
43-
# and dependencies for docs and tests
44-
pip install "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
45-
pip install .
46-
# Use headless version
47-
pip install opencv-python-headless
42+
# Install master version
43+
# and dependencies for docs and tests
44+
pip install "stable_baselines3[extra_no_roms,tests,docs] @ git+https://github.com/DLR-RM/stable-baselines3"
45+
pip install .
46+
# Use headless version
47+
pip install opencv-python-headless
4848
49-
- name: Lint with ruff
50-
run: |
51-
make lint
52-
- name: Check codestyle
53-
run: |
54-
make check-codestyle
55-
- name: Build the doc
56-
run: |
57-
make doc
58-
- name: Type check
59-
run: |
60-
make type
61-
- name: Test with pytest
62-
run: |
63-
make pytest
49+
- name: Lint with ruff
50+
run: |
51+
make lint
52+
- name: Check codestyle
53+
run: |
54+
make check-codestyle
55+
- name: Build the doc
56+
run: |
57+
make doc
58+
- name: Type check
59+
run: |
60+
make type
61+
- name: Test with pytest
62+
run: |
63+
make pytest

docs/misc/changelog.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Changelog
44
==========
55

66

7-
Release 2.4.0a4 (WIP)
7+
Release 2.4.0a8 (WIP)
88
--------------------------
99

1010
Breaking Changes:
@@ -18,6 +18,7 @@ Bug Fixes:
1818
^^^^^^^^^^
1919
- Updated QR-DQN optimizer input to only include quantile_net parameters (@corentinlger)
2020
- Updated QR-DQN paper link in docs (@corentinlger)
21+
- Fixed a warning with PyTorch 2.4 when loading a `RecurrentPPO` model ("You are using torch.load with weights_only=False")
2122

2223
Deprecations:
2324
^^^^^^^^^^^^^

pyproject.toml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ ignore = ["B028", "RUF013"]
1212

1313
[tool.ruff.lint.per-file-ignores]
1414
# ClassVar, implicit optional check not needed for tests
15-
"./tests/*.py"= ["RUF012", "RUF013"]
15+
"./tests/*.py" = ["RUF012", "RUF013"]
1616

1717
[tool.ruff.lint.mccabe]
1818
# Unlike Flake8, ruff defaults to a complexity level of 10.
@@ -35,22 +35,22 @@ exclude = """(?x)(
3535

3636
[tool.pytest.ini_options]
3737
# Deterministic ordering for tests; useful for pytest-xdist.
38-
env = [
39-
"PYTHONHASHSEED=0"
40-
]
38+
env = ["PYTHONHASHSEED=0"]
4139

4240
filterwarnings = [
4341
# Tensorboard warnings
4442
"ignore::DeprecationWarning:tensorboard",
4543
]
46-
markers = [
47-
"slow: marks tests as slow (deselect with '-m \"not slow\"')"
48-
]
44+
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
4945

5046
[tool.coverage.run]
5147
disable_warnings = ["couldnt-parse"]
5248
branch = false
5349
omit = ["tests/*", "setup.py"]
5450

5551
[tool.coverage.report]
56-
exclude_lines = [ "pragma: no cover", "raise NotImplementedError()", "if typing.TYPE_CHECKING:"]
52+
exclude_lines = [
53+
"pragma: no cover",
54+
"raise NotImplementedError()",
55+
"if typing.TYPE_CHECKING:",
56+
]

sb3_contrib/common/maskable/policies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ def predict(
304304
with th.no_grad():
305305
actions = self._predict(obs_tensor, deterministic=deterministic, action_masks=action_masks)
306306
# Convert to numpy
307-
actions = actions.cpu().numpy()
307+
actions = actions.cpu().numpy() # type: ignore[assignment]
308308

309309
if isinstance(self.action_space, spaces.Box):
310310
if self.squash_output:

sb3_contrib/ppo_recurrent/ppo_recurrent.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from copy import deepcopy
2-
from typing import Any, ClassVar, Dict, Optional, Type, TypeVar, Union
2+
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union
33

44
import numpy as np
55
import torch as th
@@ -455,3 +455,6 @@ def learn(
455455
reset_num_timesteps=reset_num_timesteps,
456456
progress_bar=progress_bar,
457457
)
458+
459+
def _excluded_save_params(self) -> List[str]:
460+
return super()._excluded_save_params() + ["_last_lstm_states"] # noqa: RUF005

sb3_contrib/version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.4.0a4
1+
2.4.0a8

0 commit comments

Comments
 (0)