diff --git a/CHANGELOG.md b/CHANGELOG.md
index 00b260c4..30f8e29b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,14 @@
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
+## [0.9.1] - 2023-01-17
+### Added
+- Utility for downloading models from Hugging Face Hub
+
+### Fixed
+- Initialization of agent components if they have not been defined
+- Manual trainer `train`/`eval` method default arguments
+
## [0.9.0] - 2023-01-13
### Added
- Support for Farama Gymnasium interface
diff --git a/README.md b/README.md
index 7313d5f1..5234fd1d 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,10 @@
-[![license](https://img.shields.io/pypi/l/skrl)](https://github.com/Toni-SM/skrl)
-[![docs](https://readthedocs.org/projects/skrl/badge/?version=latest)](https://skrl.readthedocs.io/en/latest/?badge=latest)
[![pypi](https://img.shields.io/pypi/v/skrl)](https://pypi.org/project/skrl)
-
+[huggingface](https://huggingface.co/skrl)
+![discussions](https://img.shields.io/github/discussions/Toni-SM/skrl)
+
+[![license](https://img.shields.io/github/license/Toni-SM/skrl)](https://github.com/Toni-SM/skrl)
+
+[![docs](https://readthedocs.org/projects/skrl/badge/?version=latest)](https://skrl.readthedocs.io/en/latest/?badge=latest)
[![pytest](https://github.com/Toni-SM/skrl/actions/workflows/python-test.yml/badge.svg)](https://github.com/Toni-SM/skrl/actions/workflows/python-test.yml)
[![pre-commit](https://github.com/Toni-SM/skrl/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/Toni-SM/skrl/actions/workflows/pre-commit.yml)
diff --git a/docs/source/index.rst b/docs/source/index.rst
index b948cdca..7493c920 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,6 +1,32 @@
SKRL - Reinforcement Learning library (|version|)
=================================================
**skrl** is an open-source modular library for Reinforcement Learning written in Python (using `PyTorch <https://pytorch.org>`_) and designed with a focus on readability, simplicity, and transparency of algorithm implementation. In addition to supporting the OpenAI `Gym <https://www.gymlibrary.dev>`_ / Farama `Gymnasium <https://gymnasium.farama.org>`_, `DeepMind <https://github.com/deepmind/dm_env>`_ and other environment interfaces, it allows loading and configuring `NVIDIA Isaac Gym <https://developer.nvidia.com/isaac-gym>`_ and `NVIDIA Omniverse Isaac Gym <https://docs.omniverse.nvidia.com/isaacsim>`_ environments, enabling simultaneous training of agents by scopes (subsets of environments among all available environments), which may or may not share resources, in the same run.
**Main features:**
@@ -196,8 +222,9 @@ Utils
Definition of helper functions and classes
  * :doc:`Utilities <modules/skrl.utils.utilities>`, e.g. setting the random seed
-  * :doc:`Model instantiators <modules/skrl.utils.model_instantiators>`
  * Memory and Tensorboard :doc:`file post-processing <modules/skrl.utils.postprocessing>`
+  * :doc:`Model instantiators <modules/skrl.utils.model_instantiators>`
+  * :doc:`Hugging Face integration <modules/skrl.utils.huggingface>`
  * :doc:`Isaac Gym utils <modules/skrl.utils.isaacgym_utils>`
  * :doc:`Omniverse Isaac Gym utils <modules/skrl.utils.omniverse_isaacgym_utils>`
@@ -207,7 +234,8 @@ Utils
:hidden:
modules/skrl.utils.utilities
- modules/skrl.utils.model_instantiators
modules/skrl.utils.postprocessing
+ modules/skrl.utils.model_instantiators
+ modules/skrl.utils.huggingface
modules/skrl.utils.isaacgym_utils
modules/skrl.utils.omniverse_isaacgym_utils
diff --git a/docs/source/intro/data.rst b/docs/source/intro/data.rst
index 1c740916..de53ff02 100644
--- a/docs/source/intro/data.rst
+++ b/docs/source/intro/data.rst
@@ -272,6 +272,30 @@ The following code snippets show how to load the checkpoints through the instant
# Load the checkpoint
policy.load("./runs/22-09-29_22-48-49-816281_DDPG/checkpoints/2500_policy.pt")
+In addition, the library utilities make it possible to load trained agent checkpoints from the Hugging Face Hub (`huggingface.co/skrl <https://huggingface.co/skrl>`_). See the :doc:`Hugging Face integration <../modules/skrl.utils.huggingface>` for more information.
+
+.. tabs::
+
+    .. tab:: Agent (from Hugging Face Hub)
+
+        .. code-block:: python
+            :emphasize-lines: 2, 13-14
+
+            from skrl.agents.torch.ppo import PPO
+            from skrl.utils.huggingface import download_model_from_huggingface
+
+            # Instantiate the agent
+            agent = PPO(models=models,  # models dict
+                        memory=memory,  # memory instance, or None if not required
+                        cfg=agent_cfg,  # configuration dict (preprocessors, learning rate schedulers, etc.)
+                        observation_space=env.observation_space,
+                        action_space=env.action_space,
+                        device=env.device)
+
+            # Load the checkpoint from Hugging Face Hub
+            path = download_model_from_huggingface("skrl/OmniIsaacGymEnvs-Cartpole-PPO")
+            agent.load(path)
+
Migrating external checkpoints
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/docs/source/modules/skrl.utils.huggingface.rst b/docs/source/modules/skrl.utils.huggingface.rst
new file mode 100644
index 00000000..37064d6b
--- /dev/null
+++ b/docs/source/modules/skrl.utils.huggingface.rst
@@ -0,0 +1,20 @@
+Hugging Face integration
+========================
+
+
+Download model from Hugging Face Hub
+------------------------------------
+
+Several skrl-trained models (agent checkpoints) for different environments/tasks of Gym/Gymnasium, Isaac Gym, Omniverse Isaac Gym, etc. are available on the Hugging Face Hub.
+
+These models can be used as comparison benchmarks, for collecting environment transitions in memory (e.g., for offline reinforcement learning), or for pre-initializing agents that will perform similar tasks, among other uses.
+
+Visit the `skrl organization on the Hugging Face Hub <https://huggingface.co/skrl>`_ to access publicly available models!
+
+API
+"""
+
+.. autofunction:: skrl.utils.huggingface.download_model_from_huggingface
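The use cases listed above (benchmarking, transition collection, pre-initialization) all reduce to downloading a checkpoint and loading it into a compatible agent. A minimal sketch, assuming an already-wrapped `env` and an `agent` whose models match the checkpoint (both placeholders, not defined in this diff):

```python
from skrl.trainers.torch import SequentialTrainer
from skrl.utils.huggingface import download_model_from_huggingface

# fetch the checkpoint (cached locally by huggingface-hub) and load it
path = download_model_from_huggingface("skrl/OmniIsaacGymEnvs-Cartpole-PPO")
agent.load(path)

# evaluate the pre-trained agent, e.g. as a comparison benchmark
trainer = SequentialTrainer(env=env, agents=agent, cfg={"timesteps": 1000})
trainer.eval()
```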
diff --git a/docs/source/modules/skrl.utils.utilities.rst b/docs/source/modules/skrl.utils.utilities.rst
index 6a095267..5db53aa5 100644
--- a/docs/source/modules/skrl.utils.utilities.rst
+++ b/docs/source/modules/skrl.utils.utilities.rst
@@ -2,7 +2,7 @@ Utilities
=========
.. contents:: Table of Contents
-   :depth: 2
+   :depth: 1
   :local:
   :backlinks: none
diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py
index 2c413dff..1e14aea5 100644
--- a/skrl/agents/torch/amp/amp.py
+++ b/skrl/agents/torch/amp/amp.py
@@ -239,12 +239,13 @@ def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
                                  "log_prob", "values", "returns", "advantages", "amp_states", "next_values"]

        # create tensors for motion dataset and reply buffer
-        self.motion_dataset.create_tensor(name="states", size=self.amp_observation_space, dtype=torch.float32)
-        self.reply_buffer.create_tensor(name="states", size=self.amp_observation_space, dtype=torch.float32)
+        if self.motion_dataset is not None:
+            self.motion_dataset.create_tensor(name="states", size=self.amp_observation_space, dtype=torch.float32)
+            self.reply_buffer.create_tensor(name="states", size=self.amp_observation_space, dtype=torch.float32)

-        # initialize motion dataset
-        for _ in range(math.ceil(self.motion_dataset.memory_size / self._amp_batch_size)):
-            self.motion_dataset.add_samples(states=self.collect_reference_motions(self._amp_batch_size))
+            # initialize motion dataset
+            for _ in range(math.ceil(self.motion_dataset.memory_size / self._amp_batch_size)):
+                self.motion_dataset.add_samples(states=self.collect_reference_motions(self._amp_batch_size))

        # create temporary variables needed for storage and computation
        self._current_log_prob = None
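The same pattern — skip setup for optional components the user did not supply — appears in the DDPG and TD3 hunks below, and is what the changelog's "Initialization of agent components if they have not been defined" refers to. A minimal sketch of what the guard enables (`policy_model` and the cfg values are illustrative placeholders, not from this diff):

```python
# construct AMP without memory, motion dataset or reply buffer,
# as the new tests/test_agents.py does
from skrl.agents.torch.amp import AMP

agent = AMP(models={"policy": policy_model},  # placeholder skrl Model instance
            cfg={"experiment": {"write_interval": 0}})
agent.init()  # motion dataset setup is now guarded by `is not None`
```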
diff --git a/skrl/agents/torch/ddpg/ddpg.py b/skrl/agents/torch/ddpg/ddpg.py
index adf3fed0..63ee766e 100644
--- a/skrl/agents/torch/ddpg/ddpg.py
+++ b/skrl/agents/torch/ddpg/ddpg.py
@@ -193,8 +193,9 @@ def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
                    self._rnn_initial_states["policy"].append(torch.zeros(size, dtype=torch.float32, device=self.device))

        # clip noise bounds
-        self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device)
-        self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device)
+        if self.action_space is not None:
+            self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device)
+            self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device)

        # backward compatibility: torch < 1.9 clamp method does not support tensors
        self._backward_compatibility = tuple(map(int, (torch.__version__.split(".")[:2]))) < (1, 9)
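The `_backward_compatibility` flag recorded here drives action clipping later in the agent: `torch.clamp` only accepts tensor bounds from PyTorch 1.9 onward. A minimal sketch of the pattern (an illustrative reconstruction, not code from this diff):

```python
import torch

def clamp_actions(actions: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor:
    backward_compatibility = tuple(map(int, torch.__version__.split(".")[:2])) < (1, 9)
    if backward_compatibility:
        # torch < 1.9: emulate tensor-bound clamping with element-wise min/max
        return torch.max(torch.min(actions, high), low)
    return torch.clamp(actions, min=low, max=high)
```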
diff --git a/skrl/agents/torch/dqn/ddqn.py b/skrl/agents/torch/dqn/ddqn.py
index 90c144ce..b5d7c714 100644
--- a/skrl/agents/torch/dqn/ddqn.py
+++ b/skrl/agents/torch/dqn/ddqn.py
@@ -139,7 +139,7 @@ def __init__(self,
        if self._learning_rate_scheduler is not None:
            self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
-            self.checkpoint_modules["optimizer"] = self.optimizer
+        self.checkpoint_modules["optimizer"] = self.optimizer

        # set up preprocessors
        if self._state_preprocessor:
diff --git a/skrl/agents/torch/td3/td3.py b/skrl/agents/torch/td3/td3.py
index 57441eda..0be4f6ec 100644
--- a/skrl/agents/torch/td3/td3.py
+++ b/skrl/agents/torch/td3/td3.py
@@ -211,8 +211,9 @@ def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
                    self._rnn_initial_states["policy"].append(torch.zeros(size, dtype=torch.float32, device=self.device))

        # clip noise bounds
-        self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device)
-        self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device)
+        if self.action_space is not None:
+            self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device)
+            self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device)

        # backward compatibility: torch < 1.9 clamp method does not support tensors
        self._backward_compatibility = tuple(map(int, (torch.__version__.split(".")[:2]))) < (1, 9)
diff --git a/skrl/trainers/torch/manual.py b/skrl/trainers/torch/manual.py
index 8b3af490..fa361998 100644
--- a/skrl/trainers/torch/manual.py
+++ b/skrl/trainers/torch/manual.py
@@ -50,11 +50,12 @@ def __init__(self,
        else:
            self.agents.init(trainer_cfg=self.cfg)

+        self._timestep = 0
        self._progress = None

        self.states = None

-    def train(self, timestep: int, timesteps: Optional[int] = None) -> None:
+    def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> None:
        """Execute a training iteration

        This method executes the following steps once:
@@ -68,11 +69,15 @@ def train(self, timestep: int, timesteps: Optional[int] = None) -> None:
        - Reset environments

        :param timestep: Current timestep
-        :type timestep: int
+        :type timestep: int, optional (default: None).
+                        If None, the current timestep will be carried by an internal variable
        :param timesteps: Total number of timesteps (default: None).
                          If None, the total number of timesteps is obtained from the trainer's config
        :type timesteps: int, optional
        """
+        if timestep is None:
+            self._timestep += 1
+            timestep = self._timestep
        timesteps = self.timesteps if timesteps is None else timesteps

        if self._progress is None:
@@ -157,7 +162,7 @@ def train(self, timestep: int, timesteps: Optional[int] = None) -> None:
            self.states.copy_(next_states)

-    def eval(self, timestep: int, timesteps: Optional[int] = None) -> None:
+    def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> None:
        """Evaluate the agents sequentially

        This method executes the following steps in loop:
@@ -168,11 +173,15 @@ def eval(self, timestep: int, timesteps: Optional[int] = None) -> None:
        - Reset environments

        :param timestep: Current timestep
-        :type timestep: int
+        :type timestep: int, optional (default: None).
+                        If None, the current timestep will be carried by an internal variable
        :param timesteps: Total number of timesteps (default: None).
                          If None, the total number of timesteps is obtained from the trainer's config
        :type timesteps: int, optional
        """
+        if timestep is None:
+            self._timestep += 1
+            timestep = self._timestep
        timesteps = self.timesteps if timesteps is None else timesteps

        if self._progress is None:
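With these defaults, a manual loop no longer has to thread the current timestep through every call. A minimal usage sketch, assuming an already-wrapped `env` and a configured `agent` (both placeholders):

```python
from skrl.trainers.torch import ManualTrainer

trainer = ManualTrainer(env=env, agents=agent, cfg={"timesteps": 1000})

# before 0.9.1: for t in range(1000): trainer.train(timestep=t)
# now the trainer advances an internal counter when `timestep` is omitted
for _ in range(1000):
    trainer.train()  # or trainer.eval() for evaluation
```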
diff --git a/skrl/utils/huggingface.py b/skrl/utils/huggingface.py
new file mode 100644
index 00000000..39a64b08
--- /dev/null
+++ b/skrl/utils/huggingface.py
@@ -0,0 +1,45 @@
+from skrl import logger, __version__
+
+
+def download_model_from_huggingface(repo_id: str, filename: str = "agent.pt") -> str:
+ """Download a model from Hugging Face Hub
+
+ :param repo_id: Hugging Face user or organization name and a repo name separated by a ``/``
+ :type repo_id: str
+ :param filename: The name of the model file in the repo (default: ``"agent.pt"``)
+ :type filename: str, optional
+
+ :raises ImportError: The Hugging Face Hub package (huggingface-hub) is not installed
+ :raises huggingface_hub.utils._errors.HfHubHTTPError: Any HTTP error raised in Hugging Face Hub
+
+ :return: Local path of file or if networking is off, last version of file cached on disk
+ :rtype: str
+
+ Example::
+
+ # download trained agent from the skrl organization (https://huggingface.co/skrl)
+ >>> from skrl.utils.huggingface import download_model_from_huggingface
+ >>> download_model_from_huggingface("skrl/OmniIsaacGymEnvs-Cartpole-PPO")
+ '/home/user/.cache/huggingface/hub/models--skrl--OmniIsaacGymEnvs-Cartpole-PPO/snapshots/892e629903de6bf3ef102ae760406a5dd0f6f873/agent.pt'
+
+ # download model (e.g. "policy.pth") from another user/organization (e.g. "org/ddpg-Pendulum-v1")
+ >>> from skrl.utils.huggingface import download_model_from_huggingface
+ >>> download_model_from_huggingface("org/ddpg-Pendulum-v1", "policy.pth")
+ '/home/user/.cache/huggingface/hub/models--org--ddpg-Pendulum-v1/snapshots/b44ee96f93ff2e296156b002a2ca4646e197ba32/policy.pth'
+ """
+ try:
+ import huggingface_hub
+ except ImportError:
+ logger.error("Hugging Face Hub package is not installed. Use 'pip install huggingface-hub' to install it")
+ huggingface_hub = None
+
+ if huggingface_hub is None:
+ raise ImportError("Hugging Face Hub package is not installed. Use 'pip install huggingface-hub' to install it")
+
+ # download and cache the model from Hugging Face Hub
+ downloaded_model_file = huggingface_hub.hf_hub_download(repo_id=repo_id,
+ filename=filename,
+ library_name="skrl",
+ library_version=__version__)
+
+ return downloaded_model_file
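Because `huggingface-hub` is an optional dependency, callers can guard the download and fall back to a local checkpoint. A small sketch (the fallback path is a made-up placeholder):

```python
from skrl.utils.huggingface import download_model_from_huggingface

try:
    path = download_model_from_huggingface("skrl/OmniIsaacGymEnvs-Cartpole-PPO")
except ImportError:
    # huggingface-hub is not installed ('pip install huggingface-hub')
    path = "./checkpoints/agent.pt"  # placeholder local checkpoint
```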
diff --git a/skrl/version.txt b/skrl/version.txt
index ac39a106..f374f666 100644
--- a/skrl/version.txt
+++ b/skrl/version.txt
@@ -1 +1 @@
-0.9.0
+0.9.1
diff --git a/tests/test_agents.py b/tests/test_agents.py
new file mode 100644
index 00000000..635a2222
--- /dev/null
+++ b/tests/test_agents.py
@@ -0,0 +1,59 @@
+import pytest
+import warnings
+import hypothesis
+import hypothesis.strategies as st
+
+import torch
+
+from skrl.agents.torch import Agent
+
+from skrl.agents.torch.a2c import A2C
+from skrl.agents.torch.amp import AMP
+from skrl.agents.torch.cem import CEM
+from skrl.agents.torch.ddpg import DDPG
+from skrl.agents.torch.dqn import DQN, DDQN
+from skrl.agents.torch.ppo import PPO
+from skrl.agents.torch.q_learning import Q_LEARNING
+from skrl.agents.torch.sac import SAC
+from skrl.agents.torch.sarsa import SARSA
+from skrl.agents.torch.td3 import TD3
+from skrl.agents.torch.trpo import TRPO
+
+from .utils import DummyModel
+
+
+@pytest.fixture
+def classes_and_kwargs():
+ return [(A2C, {"models": {"policy": DummyModel()}}),
+ (AMP, {"models": {"policy": DummyModel()}}),
+ (CEM, {"models": {"policy": DummyModel()}}),
+ (DDPG, {"models": {"policy": DummyModel()}}),
+ (DQN, {"models": {"policy": DummyModel()}}),
+ (DDQN, {"models": {"policy": DummyModel()}}),
+ (PPO, {"models": {"policy": DummyModel()}}),
+ (Q_LEARNING, {"models": {"policy": DummyModel()}}),
+ (SAC, {"models": {"policy": DummyModel()}}),
+ (SARSA, {"models": {"policy": DummyModel()}}),
+ (TD3, {"models": {"policy": DummyModel()}}),
+ (TRPO, {"models": {"policy": DummyModel()}})]
+
+
+def test_agent(capsys, classes_and_kwargs):
+ for klass, kwargs in classes_and_kwargs:
+ cfg = {"learning_starts": 1,
+ "experiment": {"write_interval": 0}}
+ agent: Agent = klass(cfg=cfg, **kwargs)
+
+ agent.init()
+ agent.pre_interaction(timestep=0, timesteps=1)
+ # agent.act(None, timestep=0, timestesps=1)
+ agent.record_transition(states=torch.tensor([]),
+ actions=torch.tensor([]),
+ rewards=torch.tensor([]),
+ next_states=torch.tensor([]),
+ terminated=torch.tensor([]),
+ truncated=torch.tensor([]),
+ infos={},
+ timestep=0,
+ timesteps=1)
+ agent.post_interaction(timestep=0, timesteps=1)
diff --git a/tests/test_envs.py b/tests/test_envs.py
new file mode 100644
index 00000000..e552f31a
--- /dev/null
+++ b/tests/test_envs.py
@@ -0,0 +1,36 @@
+import pytest
+import warnings
+import hypothesis
+import hypothesis.strategies as st
+
+import torch
+
+from skrl.envs.torch import Wrapper
+
+from skrl.envs.torch import wrap_env
+
+from .utils import DummyEnv
+
+
+@pytest.fixture
+def classes_and_kwargs():
+ return []
+
+
+@pytest.mark.parametrize("wrapper", ["gym", "gymnasium", "dm", "robosuite", \
+ "isaacgym-preview2", "isaacgym-preview3", "isaacgym-preview4", "omniverse-isaacgym"])
+def test_wrap_env(capsys, classes_and_kwargs, wrapper):
+ env = DummyEnv(num_envs=1)
+
+ try:
+ env: Wrapper = wrap_env(env=env, wrapper=wrapper)
+ except ValueError as e:
+ warnings.warn(f"{e}. This test will be skipped for '{wrapper}'")
+ except ModuleNotFoundError as e:
+ warnings.warn(f"{e}. The '{wrapper}' wrapper module is not found. This test will be skipped")
+
+ env.observation_space
+ env.action_space
+ env.state_space
+ env.num_envs
+ env.device
diff --git a/tests/test_memories.py b/tests/test_memories.py
index 99204cf8..f4edee80 100644
--- a/tests/test_memories.py
+++ b/tests/test_memories.py
@@ -8,6 +8,7 @@
import torch
from skrl.memories.torch import Memory
+
from skrl.memories.torch import RandomMemory
diff --git a/tests/test_model_instantiators.py b/tests/test_model_instantiators.py
index 923e09d2..ea12ed85 100644
--- a/tests/test_model_instantiators.py
+++ b/tests/test_model_instantiators.py
@@ -16,9 +16,12 @@
@pytest.fixture
def classes_and_kwargs():
-    return []
+    return [(categorical_model, {}),
+            (deterministic_model, {}),
+            (gaussian_model, {}),
+            (multivariate_gaussian_model, {})]


-@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
-def test_device(capsys, classes_and_kwargs, device):
-    _device = torch.device(device) if device is not None else torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+def test_models(capsys, classes_and_kwargs):
+    for klass, kwargs in classes_and_kwargs:
+        model: Model = klass(observation_space=1, action_space=1, device="cpu", **kwargs)
diff --git a/tests/test_resources_noises.py b/tests/test_resources_noises.py
index f9dcedba..549e6171 100644
--- a/tests/test_resources_noises.py
+++ b/tests/test_resources_noises.py
@@ -6,6 +6,7 @@
import torch
from skrl.resources.noises.torch import Noise
+
from skrl.resources.noises.torch import GaussianNoise
from skrl.resources.noises.torch import OrnsteinUhlenbeckNoise
diff --git a/tests/test_trainers.py b/tests/test_trainers.py
new file mode 100644
index 00000000..74cd2110
--- /dev/null
+++ b/tests/test_trainers.py
@@ -0,0 +1,40 @@
+import pytest
+import warnings
+import hypothesis
+import hypothesis.strategies as st
+
+import torch
+
+from skrl.trainers.torch import Trainer
+
+from skrl.trainers.torch import ManualTrainer
+from skrl.trainers.torch import ParallelTrainer
+from skrl.trainers.torch import SequentialTrainer
+
+from .utils import DummyEnv, DummyAgent
+
+
+@pytest.fixture
+def classes_and_kwargs():
+ return [(ManualTrainer, {"cfg": {"timesteps": 100}}),
+ (ParallelTrainer, {"cfg": {"timesteps": 100}}),
+ (SequentialTrainer, {"cfg": {"timesteps": 100}})]
+
+
+def test_train(capsys, classes_and_kwargs):
+ env = DummyEnv(num_envs=1)
+ agent = DummyAgent()
+
+ for klass, kwargs in classes_and_kwargs:
+ trainer: Trainer = klass(env, agents=agent, **kwargs)
+
+ trainer.train()
+
+def test_eval(capsys, classes_and_kwargs):
+ env = DummyEnv(num_envs=1)
+ agent = DummyAgent()
+
+ for klass, kwargs in classes_and_kwargs:
+ trainer: Trainer = klass(env, agents=agent, **kwargs)
+
+ trainer.eval()
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 00000000..dd43ec62
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,95 @@
+import random
+import gymnasium as gym
+
+import torch
+
+
+class DummyEnv(gym.Env):
+ def __init__(self, num_envs, device = "cpu"):
+ self.num_envs = num_envs
+ self.device = torch.device(device)
+ self.action_space = gym.spaces.Discrete(2)
+ self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,))
+
+ def __getattr__(self, key):
+ if key in ["_spec_to_space", "observation_spec"]:
+ return lambda *args, **kwargs: None
+ return None
+
+ def step(self, action):
+ observation = self.observation_space.sample()
+ reward = random.random()
+ terminated = random.random() > 0.95
+ truncated = random.random() > 0.95
+
+ observation = torch.tensor(observation, dtype=torch.float32).view(self.num_envs, -1)
+ reward = torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1)
+ terminated = torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1)
+ truncated = torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1)
+
+ return observation, reward, terminated, truncated, {}
+
+ def reset(self):
+ observation = self.observation_space.sample()
+ observation = torch.tensor(observation, dtype=torch.float32).view(self.num_envs, -1)
+ return observation, {}
+
+ def render(self, *args, **kwargs):
+ pass
+
+ def close(self, *args, **kwargs):
+ pass
+
+
+class _DummyBaseAgent:
+ def __init__(self):
+ pass
+
+ def record_transition(self, states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps):
+ pass
+
+ def pre_interaction(self, timestep, timesteps):
+ pass
+
+ def post_interaction(self, timestep, timesteps):
+ pass
+
+ def set_running_mode(self, mode):
+ pass
+
+
+class DummyAgent(_DummyBaseAgent):
+ def __init__(self):
+ super().__init__()
+
+ def init(self, trainer_cfg=None):
+ pass
+
+ def act(self, states, timestep, timesteps):
+ return torch.tensor([]), None, {}
+
+ def record_transition(self, states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps):
+ pass
+
+ def pre_interaction(self, timestep, timesteps):
+ pass
+
+ def post_interaction(self, timestep, timesteps):
+ pass
+
+
+class DummyModel(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ self.device = torch.device("cpu")
+ self.layer = torch.nn.Linear(1, 1)
+
+ def set_mode(self, *args, **kwargs):
+ pass
+
+ def get_specification(self, *args, **kwargs):
+ return {}
+
+ def act(self, *args, **kwargs):
+ return torch.tensor([]), None, {}