Skip to content

Commit

Permalink
Fix replay buffer device at load time (#20)
Browse files Browse the repository at this point in the history
* Fix replay buffer device at load time

* Fix imports

* Update version

* Reformat and add test

* Fix test

* Fix for mypy
  • Loading branch information
araffin authored Dec 13, 2023
1 parent 9bd4bca commit ba597ca
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
13 changes: 13 additions & 0 deletions sbx/common/off_policy_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import io
import pathlib
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import jax
Expand All @@ -9,6 +11,7 @@
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, Schedule
from stable_baselines3.common.utils import get_device


class OffPolicyAlgorithmJax(OffPolicyAlgorithm):
Expand Down Expand Up @@ -116,3 +119,13 @@ def _setup_model(self) -> None:
)
# Convert train freq parameter to TrainFreq object
self._convert_train_freq()

def load_replay_buffer(
    self,
    path: Union[str, pathlib.Path, io.BufferedIOBase],
    truncate_last_traj: bool = True,
) -> None:
    """
    Load a replay buffer through the parent class, then pin its device to CPU.

    :param path: Location (or opened binary stream) to load the buffer from;
        forwarded unchanged to the parent implementation.
    :param truncate_last_traj: Forwarded unchanged to the parent implementation.
    """
    super().load_replay_buffer(path, truncate_last_traj)
    loaded_buffer = self.replay_buffer
    # The parent call is expected to have populated the buffer; this also
    # narrows the Optional type for static type checkers.
    assert loaded_buffer is not None
    # Always keep the buffer on cpu, regardless of the device stored with it,
    # so that its content can be converted to numpy
    loaded_buffer.device = get_device("cpu")
2 changes: 1 addition & 1 deletion sbx/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.9.0
0.9.1
14 changes: 14 additions & 0 deletions tests/test_buffers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import pytest
import torch as th

from sbx import SAC


def test_force_cpu_device(tmp_path):
    """
    Check that the replay buffer device is CPU both before and after
    a save/load round trip.

    The test is skipped when no CUDA device is available
    (presumably because the device mismatch can only show up on a CUDA machine).
    """
    cuda_present = th.cuda.is_available()
    if not cuda_present:
        pytest.skip("No CUDA device")

    expected_device = th.device("cpu")
    buffer_file = tmp_path / "replay"

    agent = SAC("MlpPolicy", "Pendulum-v1", buffer_size=200)
    # Freshly created buffer must already live on cpu
    assert agent.replay_buffer.device == expected_device

    # Round trip through disk: the device must still be cpu afterwards
    agent.save_replay_buffer(buffer_file)
    agent.load_replay_buffer(buffer_file)
    assert agent.replay_buffer.device == expected_device

0 comments on commit ba597ca

Please sign in to comment.