From 2cb64c8685b53be17a0fa3125c48fc0c190231a1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Mon, 16 Jan 2023 17:25:32 +0100
Subject: [PATCH 01/14] Add test utils

---
 tests/utils.py | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 73 insertions(+)
 create mode 100644 tests/utils.py

diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 00000000..9f902964
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,73 @@
+import random
+import gymnasium as gym
+
+import torch
+
+
+class DummyEnv(gym.Env):
+    def __init__(self, num_envs, device="cpu"):
+        self.num_envs = num_envs
+        self.device = torch.device(device)
+        self.action_space = gym.spaces.Discrete(2)
+        self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,))
+
+    def step(self, action):
+        observation = self.observation_space.sample()
+        reward = random.random()
+        terminated = random.random() > 0.95
+        truncated = random.random() > 0.95
+
+        observation = torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1)
+        reward = torch.tensor(reward, device=self.device, dtype=torch.float32).view(self.num_envs, -1)
+        terminated = torch.tensor(terminated, device=self.device, dtype=torch.bool).view(self.num_envs, -1)
+        truncated = torch.tensor(truncated, device=self.device, dtype=torch.bool).view(self.num_envs, -1)
+
+        return observation, reward, terminated, truncated, {}
+
+    def reset(self):
+        observation = self.observation_space.sample()
+        observation = torch.tensor(observation, device=self.device, dtype=torch.float32).view(self.num_envs, -1)
+        return observation, {}
+
+    def render(self, *args, **kwargs):
+        pass
+
+    def close(self, *args, **kwargs):
+        pass
+
+
+class _DummyBaseAgent:
+    def __init__(self):
+        pass
+
+    def record_transition(self, states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps):
+        pass
+
+    def pre_interaction(self, timestep, timesteps):
+        pass
+
+    def post_interaction(self, timestep, timesteps):
+        pass
+
+    def set_running_mode(self, mode):
+        pass
+
+
+class DummyAgent(_DummyBaseAgent):
+    def __init__(self):
+        super().__init__()
+
+    def init(self, trainer_cfg=None):
+        pass
+
+    def act(self, states, timestep, timesteps):
+        return torch.tensor([]), None, {}
+
+    def record_transition(self, states, actions, rewards, next_states, terminated, truncated, infos, timestep, timesteps):
+        pass
+
+    def pre_interaction(self, timestep, timesteps):
+        pass
+
+    def post_interaction(self, timestep, timesteps):
+        pass

From 6b587d0f04647c4805624f260c4695e864daac92 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Mon, 16 Jan 2023 17:27:16 +0100
Subject: [PATCH 02/14] Add trainers tests

---
 tests/test_trainers.py | 39 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 39 insertions(+)
 create mode 100644 tests/test_trainers.py

diff --git a/tests/test_trainers.py b/tests/test_trainers.py
new file mode 100644
index 00000000..f2e61e7f
--- /dev/null
+++ b/tests/test_trainers.py
@@ -0,0 +1,39 @@
+import pytest
+import warnings
+import hypothesis
+import hypothesis.strategies as st
+
+import torch
+
+from skrl.trainers.torch import Trainer
+from skrl.trainers.torch import ManualTrainer
+from skrl.trainers.torch import ParallelTrainer
+from skrl.trainers.torch import SequentialTrainer
+
+from .utils import DummyEnv, DummyAgent
+
+
+@pytest.fixture
+def classes_and_kwargs():
+    return [(ManualTrainer, {"cfg": {"timesteps": 100}}),
+            (ParallelTrainer, {"cfg": {"timesteps": 100}}),
+            (SequentialTrainer, {"cfg": {"timesteps": 100}})]
+
+
+def test_train(capsys, classes_and_kwargs):
+    env = DummyEnv(num_envs=1)
+    agent = DummyAgent()
+
+    for klass, kwargs in classes_and_kwargs:
+        trainer: Trainer = klass(env, agents=agent, **kwargs)
+
+        trainer.train()
+
+
+def test_eval(capsys, classes_and_kwargs):
+    env = DummyEnv(num_envs=1)
+    agent = DummyAgent()
+
+    for klass, kwargs in classes_and_kwargs:
+        trainer: Trainer = klass(env, agents=agent, **kwargs)
+
+        trainer.eval()

From 8391cc9c9404b8751cbd1a20adaf6b94acfa96e4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Mon, 16 Jan 2023 17:31:41 +0100
Subject: [PATCH 03/14] Set the manual trainer train/eval method's timestep
 argument as optional
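
With this change, the current timestep no longer has to be passed explicitly
on every call: when the timestep argument is omitted (None), an internal
counter is incremented and used instead. A minimal usage sketch (environment
and agent created beforehand, as in the tests):

    from skrl.trainers.torch import ManualTrainer

    trainer = ManualTrainer(env, agents=agent, cfg={"timesteps": 100})

    for _ in range(100):
        trainer.train()  # the timestep is carried by the internal counter

    # passing the timestep explicitly still behaves as before:
    # trainer.train(timestep=timestep)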
---
 skrl/trainers/torch/manual.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/skrl/trainers/torch/manual.py b/skrl/trainers/torch/manual.py
index 8b3af490..fa361998 100644
--- a/skrl/trainers/torch/manual.py
+++ b/skrl/trainers/torch/manual.py
@@ -50,11 +50,12 @@ def __init__(self,
         else:
             self.agents.init(trainer_cfg=self.cfg)
 
+        self._timestep = 0
         self._progress = None
 
         self.states = None
 
-    def train(self, timestep: int, timesteps: Optional[int] = None) -> None:
+    def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> None:
         """Execute a training iteration
 
         This method executes the following steps once:
@@ -68,11 +69,15 @@ def train(self, timestep: int, timesteps: Optional[int] = None) -> None:
         - Pre-interaction
         - Compute actions
         - Interact with the environments
         - Render scene
         - Record transitions
         - Post-interaction
         - Reset environments
 
         :param timestep: Current timestep
-        :type timestep: int
+        :type timestep: int, optional (default: None).
+                        If None, the current timestep will be carried by an internal variable
         :param timesteps: Total number of timesteps (default: None).
                           If None, the total number of timesteps is obtained from the trainer's config
         :type timesteps: int, optional
         """
+        if timestep is None:
+            self._timestep += 1
+            timestep = self._timestep
         timesteps = self.timesteps if timesteps is None else timesteps
 
         if self._progress is None:
@@ -157,7 +162,7 @@ def train(self, timestep: int, timesteps: Optional[int] = None) -> None:
 
         self.states.copy_(next_states)
 
-    def eval(self, timestep: int, timesteps: Optional[int] = None) -> None:
+    def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> None:
         """Evaluate the agents sequentially
 
         This method executes the following steps in loop:
@@ -168,11 +173,15 @@ def eval(self, timestep: int, timesteps: Optional[int] = None) -> None:
         - Compute actions
         - Interact with the environments
         - Render scene
         - Record transitions
         - Reset environments
 
         :param timestep: Current timestep
-        :type timestep: int
+        :type timestep: int, optional (default: None).
+                        If None, the current timestep will be carried by an internal variable
         :param timesteps: Total number of timesteps (default: None).
                           If None, the total number of timesteps is obtained from the trainer's config
         :type timesteps: int, optional
         """
+        if timestep is None:
+            self._timestep += 1
+            timestep = self._timestep
         timesteps = self.timesteps if timesteps is None else timesteps
 
         if self._progress is None:

From 0dcc2b69d11d0da094351c0832f29e16b1af72ae Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Mon, 16 Jan 2023 18:29:26 +0100
Subject: [PATCH 04/14] Add badges to docs and README

---
 README.md             |  9 ++++++---
 docs/source/index.rst | 26 ++++++++++++++++++++++++++
 2 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index 7313d5f1..5234fd1d 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,10 @@
-[![license](https://img.shields.io/pypi/l/skrl)](https://github.com/Toni-SM/skrl)
-[![docs](https://readthedocs.org/projects/skrl/badge/?version=latest)](https://skrl.readthedocs.io/en/latest/?badge=latest)
 [![pypi](https://img.shields.io/pypi/v/skrl)](https://pypi.org/project/skrl)
-&nbsp;&nbsp;
+[huggingface](https://huggingface.co/skrl)
+![discussions](https://img.shields.io/github/discussions/Toni-SM/skrl)
+
+[![license](https://img.shields.io/github/license/Toni-SM/skrl)](https://github.com/Toni-SM/skrl)
+&nbsp;&nbsp;
+[![docs](https://readthedocs.org/projects/skrl/badge/?version=latest)](https://skrl.readthedocs.io/en/latest/?badge=latest)
 [![pytest](https://github.com/Toni-SM/skrl/actions/workflows/python-test.yml/badge.svg)](https://github.com/Toni-SM/skrl/actions/workflows/python-test.yml)
 [![pre-commit](https://github.com/Toni-SM/skrl/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/Toni-SM/skrl/actions/workflows/pre-commit.yml)
 
diff --git a/docs/source/index.rst b/docs/source/index.rst
index b948cdca..e732c164 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,6 +1,32 @@
 SKRL - Reinforcement Learning library (|version|)
 =================================================
 
+.. raw:: html
+
+    <a href="https://pypi.org/project/skrl">
+    <img alt="pypi" src="https://img.shields.io/pypi/v/skrl">
+    </a>
+    <a href="https://huggingface.co/skrl">
+    huggingface
+    </a>
+    <a href="https://github.com/Toni-SM/skrl/discussions">
+    <img alt="discussions" src="https://img.shields.io/github/discussions/Toni-SM/skrl">
+    </a>
+    <br>
+    <a href="https://github.com/Toni-SM/skrl">
+    <img alt="license" src="https://img.shields.io/github/license/Toni-SM/skrl">
+    </a>
+    &nbsp;&nbsp;
+    <a href="https://skrl.readthedocs.io/en/latest/?badge=latest">
+    <img alt="docs" src="https://readthedocs.org/projects/skrl/badge/?version=latest">
+    </a>
+    <a href="https://github.com/Toni-SM/skrl/actions/workflows/python-test.yml">
+    <img alt="pytest" src="https://github.com/Toni-SM/skrl/actions/workflows/python-test.yml/badge.svg">
+    </a>
+    <a href="https://github.com/Toni-SM/skrl/actions/workflows/pre-commit.yml">
+    <img alt="pre-commit" src="https://github.com/Toni-SM/skrl/actions/workflows/pre-commit.yml/badge.svg">
+    </a>
+
 **skrl** is an open-source modular library for Reinforcement Learning written in Python (using `PyTorch `_) and designed with a focus on readability, simplicity, and transparency of algorithm implementation. In addition to supporting the OpenAI `Gym `_ / Farama `Gymnasium `_, `DeepMind `_ and other environment interfaces, it allows loading and configuring `NVIDIA Isaac Gym `_ and `NVIDIA Omniverse Isaac Gym `_ environments, enabling agents' simultaneous training by scopes (subsets of environments among all available environments), which may or may not share resources, in the same run
 
 **Main features:**

From 266933631962539ff020a30436671c12afda3656 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Mon, 16 Jan 2023 22:45:42 +0100
Subject: [PATCH 05/14] Add testing files for other library components

---
 tests/test_agents.py           | 59 ++++++++++++++++++++++++++++++++++
 tests/test_envs.py             | 34 ++++++++++++++++++++
 tests/test_memories.py         |  1 +
 tests/test_resources_noises.py |  1 +
 tests/test_trainers.py         |  1 +
 tests/utils.py                 | 22 +++++++++++++
 6 files changed, 118 insertions(+)
 create mode 100644 tests/test_agents.py
 create mode 100644 tests/test_envs.py

diff --git a/tests/test_agents.py b/tests/test_agents.py
new file mode 100644
index 00000000..635a2222
--- /dev/null
+++ b/tests/test_agents.py
@@ -0,0 +1,59 @@
+import pytest
+import warnings
+import hypothesis
+import hypothesis.strategies as st
+
+import torch
+
+from skrl.agents.torch import Agent
+
+from skrl.agents.torch.a2c import A2C
+from skrl.agents.torch.amp import AMP
+from skrl.agents.torch.cem import CEM
+from skrl.agents.torch.ddpg import DDPG
+from skrl.agents.torch.dqn import DQN, DDQN
+from skrl.agents.torch.ppo import PPO
+from skrl.agents.torch.q_learning import Q_LEARNING
+from skrl.agents.torch.sac import SAC
+from skrl.agents.torch.sarsa import SARSA
+from skrl.agents.torch.td3 import TD3
+from skrl.agents.torch.trpo import TRPO
+
+from .utils import DummyModel
+
+
+@pytest.fixture
+def classes_and_kwargs():
+    return [(A2C, {"models": {"policy": DummyModel()}}),
+            (AMP, {"models": {"policy": DummyModel()}}),
+            (CEM, {"models": {"policy": DummyModel()}}),
+            (DDPG, {"models": {"policy": DummyModel()}}),
+            (DQN, {"models": {"policy": DummyModel()}}),
+            (DDQN, {"models": {"policy": DummyModel()}}),
+            (PPO, {"models": {"policy": DummyModel()}}),
+            (Q_LEARNING, {"models": {"policy": DummyModel()}}),
+            (SAC, {"models": {"policy": DummyModel()}}),
+            (SARSA, {"models": {"policy": DummyModel()}}),
+            (TD3, {"models": {"policy": DummyModel()}}),
+            (TRPO, {"models": {"policy": DummyModel()}})]
+
+
+def test_agent(capsys, classes_and_kwargs):
+    for klass, kwargs in classes_and_kwargs:
+        cfg = {"learning_starts": 1,
+               "experiment": {"write_interval": 0}}
+        agent: Agent = klass(cfg=cfg, **kwargs)
+
+        agent.init()
+        agent.pre_interaction(timestep=0, timesteps=1)
+        # agent.act(None, timestep=0, timesteps=1)
+        agent.record_transition(states=torch.tensor([]),
+                                actions=torch.tensor([]),
+                                rewards=torch.tensor([]),
+                                next_states=torch.tensor([]),
+                                terminated=torch.tensor([]),
+                                truncated=torch.tensor([]),
+                                infos={},
+                                timestep=0,
+                                timesteps=1)
+        agent.post_interaction(timestep=0, timesteps=1)

diff --git a/tests/test_envs.py b/tests/test_envs.py
new file mode 100644
index 00000000..242e0a61
--- /dev/null
+++ b/tests/test_envs.py
@@ -0,0 +1,34 @@
+import pytest
+import warnings
+import hypothesis
+import hypothesis.strategies as st
+
+import torch
+
+from skrl.envs.torch import Wrapper
+
+from skrl.envs.torch import wrap_env
+
+from .utils import DummyEnv
+
+
+@pytest.fixture
+def classes_and_kwargs():
+    return []
+
+
+@pytest.mark.parametrize("wrapper", ["gym", "gymnasium", "dm", "robosuite", \
+                                     "isaacgym-preview2", "isaacgym-preview3", "isaacgym-preview4", "omniverse-isaacgym"])
+def test_wrap_env(capsys, classes_and_kwargs, wrapper):
+    env = DummyEnv(num_envs=1)
+
+    try:
+        env: Wrapper = wrap_env(env=env, wrapper=wrapper)
+    except ValueError as e:
+        warnings.warn(f"{e}. This test will be skipped for '{wrapper}'")
+
+    env.observation_space
+    env.action_space
+    env.state_space
+    env.num_envs
+    env.device

diff --git a/tests/test_memories.py b/tests/test_memories.py
index 99204cf8..f4edee80 100644
--- a/tests/test_memories.py
+++ b/tests/test_memories.py
@@ -8,6 +8,7 @@
 import torch
 
 from skrl.memories.torch import Memory
+
 from skrl.memories.torch import RandomMemory
 
 
diff --git a/tests/test_resources_noises.py b/tests/test_resources_noises.py
index f9dcedba..549e6171 100644
--- a/tests/test_resources_noises.py
+++ b/tests/test_resources_noises.py
@@ -6,6 +6,7 @@
 import torch
 
 from skrl.resources.noises.torch import Noise
+
 from skrl.resources.noises.torch import GaussianNoise
 from skrl.resources.noises.torch import OrnsteinUhlenbeckNoise
 
diff --git a/tests/test_trainers.py b/tests/test_trainers.py
index f2e61e7f..74cd2110 100644
--- a/tests/test_trainers.py
+++ b/tests/test_trainers.py
@@ -6,6 +6,7 @@
 import torch
 
 from skrl.trainers.torch import Trainer
+
 from skrl.trainers.torch import ManualTrainer
 from skrl.trainers.torch import ParallelTrainer
 from skrl.trainers.torch import SequentialTrainer
 
diff --git a/tests/utils.py b/tests/utils.py
index 9f902964..dd43ec62 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -11,6 +11,11 @@ def __init__(self, num_envs, device="cpu"):
         self.action_space = gym.spaces.Discrete(2)
         self.observation_space = gym.spaces.Box(low=-1, high=1, shape=(2,))
 
+    def __getattr__(self, key):
+        if key in ["_spec_to_space", "observation_spec"]:
+            return lambda *args, **kwargs: None
+        return None
+
     def step(self, action):
         observation = self.observation_space.sample()
         reward = random.random()
@@ -71,3 +76,20 @@ def pre_interaction(self, timestep, timesteps):
 
     def post_interaction(self, timestep, timesteps):
         pass
+
+
+class DummyModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.device = torch.device("cpu")
+        self.layer = torch.nn.Linear(1, 1)
+
+    def set_mode(self, *args, **kwargs):
+        pass
+
+    def get_specification(self, *args, **kwargs):
+        return {}
+
+    def act(self, *args, **kwargs):
+        return torch.tensor([]), None, {}

From 91bb6f134afa073adda3a56dc44508ac475ddfb5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Mon, 16 Jan 2023 22:48:17 +0100
Subject: [PATCH 06/14] Initialize motion dataset if not None

---
 skrl/agents/torch/amp/amp.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py
index 2c413dff..1e14aea5 100644
--- a/skrl/agents/torch/amp/amp.py
+++ b/skrl/agents/torch/amp/amp.py
@@ -239,12 +239,13 @@ def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
                                  "log_prob", "values", "returns", "advantages", "amp_states", "next_values"]
 
         # create tensors for motion dataset and reply buffer
-        self.motion_dataset.create_tensor(name="states", size=self.amp_observation_space, dtype=torch.float32)
-        self.reply_buffer.create_tensor(name="states", size=self.amp_observation_space, dtype=torch.float32)
+        if self.motion_dataset is not None:
+            self.motion_dataset.create_tensor(name="states", size=self.amp_observation_space, dtype=torch.float32)
+            self.reply_buffer.create_tensor(name="states", size=self.amp_observation_space, dtype=torch.float32)
 
-        # initialize motion dataset
-        for _ in range(math.ceil(self.motion_dataset.memory_size / self._amp_batch_size)):
-            self.motion_dataset.add_samples(states=self.collect_reference_motions(self._amp_batch_size))
+            # initialize motion dataset
+            for _ in range(math.ceil(self.motion_dataset.memory_size / self._amp_batch_size)):
+                self.motion_dataset.add_samples(states=self.collect_reference_motions(self._amp_batch_size))
 
         # create temporary variables needed for storage and computation
         self._current_log_prob = None

From 0ddf758ba0a7c467716006d0d4064de5953b3a87 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Mon, 16 Jan 2023 22:50:05 +0100
Subject: [PATCH 07/14] Get action space limits if not None

---
 skrl/agents/torch/ddpg/ddpg.py | 5 +++--
 skrl/agents/torch/td3/td3.py   | 5 +++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/skrl/agents/torch/ddpg/ddpg.py b/skrl/agents/torch/ddpg/ddpg.py
index adf3fed0..63ee766e 100644
--- a/skrl/agents/torch/ddpg/ddpg.py
+++ b/skrl/agents/torch/ddpg/ddpg.py
@@ -193,8 +193,9 @@ def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
                     self._rnn_initial_states["policy"].append(torch.zeros(size, dtype=torch.float32, device=self.device))
 
         # clip noise bounds
-        self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device)
-        self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device)
+        if self.action_space is not None:
+            self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device)
+            self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device)
 
         # backward compatibility: torch < 1.9 clamp method does not support tensors
         self._backward_compatibility = tuple(map(int, (torch.__version__.split(".")[:2]))) < (1, 9)
diff --git a/skrl/agents/torch/td3/td3.py b/skrl/agents/torch/td3/td3.py
index 57441eda..0be4f6ec 100644
--- a/skrl/agents/torch/td3/td3.py
+++ b/skrl/agents/torch/td3/td3.py
@@ -211,8 +211,9 @@ def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
                     self._rnn_initial_states["policy"].append(torch.zeros(size, dtype=torch.float32, device=self.device))
 
         # clip noise bounds
-        self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device)
-        self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device)
+        if self.action_space is not None:
+            self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device)
+            self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device)
 
         # backward compatibility: torch < 1.9 clamp method does not support tensors
         self._backward_compatibility = tuple(map(int, (torch.__version__.split(".")[:2]))) < (1, 9)

From 30e2b8d11f9be4e8fa5a3e67bb7750d63aaf89b3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Mon, 16 Jan 2023 22:51:27 +0100
Subject: [PATCH 08/14] Store optimizer for checkpointing if not None

---
 skrl/agents/torch/dqn/ddqn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/skrl/agents/torch/dqn/ddqn.py b/skrl/agents/torch/dqn/ddqn.py
index 90c144ce..b5d7c714 100644
--- a/skrl/agents/torch/dqn/ddqn.py
+++ b/skrl/agents/torch/dqn/ddqn.py
@@ -139,7 +139,7 @@ def __init__(self,
             if self._learning_rate_scheduler is not None:
                self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"])
 
-        self.checkpoint_modules["optimizer"] = self.optimizer
+            self.checkpoint_modules["optimizer"] = self.optimizer
 
         # set up preprocessors
         if self._state_preprocessor:

From 5e0ad993741932b5c36807633a10afe36e0d2def Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Mon, 16 Jan 2023 22:54:54 +0100
Subject: [PATCH 09/14] Add utility for downloading models from Hugging Face

---
 skrl/utils/huggingface.py | 30 ++++++++++++++++++++++++++++++
 1 file changed, 30 insertions(+)
 create mode 100644 skrl/utils/huggingface.py

diff --git a/skrl/utils/huggingface.py b/skrl/utils/huggingface.py
new file mode 100644
index 00000000..98b280c5
--- /dev/null
+++ b/skrl/utils/huggingface.py
@@ -0,0 +1,30 @@
+from skrl import logger, __version__
+
+
+def download_model_from_huggingface(repo_id: str, filename: str = "agent.pt") -> str:
+    """Download a model from Hugging Face Hub
+
+    :param repo_id: Hugging Face user or organization name and a repo name separated by a ``/``
+    :type repo_id: str
+    :param filename: The name of the model file in the repo (default: ``"agent.pt"``)
+    :type filename: str, optional
+
+    :return: Local path of file or if networking is off, last version of file cached on disk
+    :rtype: str
+    """
+    try:
+        import huggingface_hub
+    except ImportError:
+        logger.error("Hugging Face Hub package is not installed. Use 'pip install huggingface-hub' to install it")
+        huggingface_hub = None
+
+    if huggingface_hub is None:
+        raise ImportError("Hugging Face Hub package is not installed. Use 'pip install huggingface-hub' to install it")
+
+    # download and cache the model from Hugging Face Hub
+    downloaded_model_file = huggingface_hub.hf_hub_download(repo_id=repo_id,
+                                                            filename=filename,
+                                                            library_name="skrl",
+                                                            library_version=__version__)
+
+    return downloaded_model_file

From 22c7515ce44b7c29fd4f5af5aa057b2eed3279d6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Tue, 17 Jan 2023 09:58:01 +0100
Subject: [PATCH 10/14] Update CHANGELOG

---
 CHANGELOG.md | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 00b260c4..30f8e29b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,14 @@
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
+## [0.9.1] - 2023-01-17
+### Added
+- Utility for downloading models from Hugging Face Hub
+
+### Fixed
+- Initialization of agent components if they have not been defined
+- Manual trainer `train`/`eval` method default arguments
+
 ## [0.9.0] - 2023-01-13
 ### Added
 - Support for Farama Gymnasium interface

From ec3d6ebfc1cc31a1c23737ba89655673e4401179 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Tue, 17 Jan 2023 09:58:20 +0100
Subject: [PATCH 11/14] Increase PATCH version

---
 skrl/version.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/skrl/version.txt b/skrl/version.txt
index ac39a106..f374f666 100644
--- a/skrl/version.txt
+++ b/skrl/version.txt
@@ -1 +1 @@
-0.9.0
+0.9.1

From 3231a8de8d1958dbcd985896ec69fe8debec1e35 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Tue, 17 Jan 2023 10:08:00 +0100
Subject: [PATCH 12/14] Update tests

---
 tests/test_envs.py                |  2 ++
 tests/test_model_instantiators.py | 11 +++++++----
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/tests/test_envs.py b/tests/test_envs.py
index 242e0a61..e552f31a 100644
--- a/tests/test_envs.py
+++ b/tests/test_envs.py
@@ -26,6 +26,8 @@ def test_wrap_env(capsys, classes_and_kwargs, wrapper):
         env: Wrapper = wrap_env(env=env, wrapper=wrapper)
     except ValueError as e:
         warnings.warn(f"{e}. This test will be skipped for '{wrapper}'")
+    except ModuleNotFoundError as e:
+        warnings.warn(f"{e}. The '{wrapper}' wrapper module is not found. This test will be skipped")
 
     env.observation_space
     env.action_space
diff --git a/tests/test_model_instantiators.py b/tests/test_model_instantiators.py
index 923e09d2..ea12ed85 100644
--- a/tests/test_model_instantiators.py
+++ b/tests/test_model_instantiators.py
@@ -16,9 +16,12 @@
 
 @pytest.fixture
 def classes_and_kwargs():
-    return []
+    return [(categorical_model, {}),
+            (deterministic_model, {}),
+            (gaussian_model, {}),
+            (multivariate_gaussian_model, {})]
 
 
-@pytest.mark.parametrize("device", [None, "cpu", "cuda:0"])
-def test_device(capsys, classes_and_kwargs, device):
-    _device = torch.device(device) if device is not None else torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+def test_models(capsys, classes_and_kwargs):
+    for klass, kwargs in classes_and_kwargs:
+        model: Model = klass(observation_space=1, action_space=1, device="cpu", **kwargs)

From 21971613a09a066498d5d89ec9b5357c63333ce1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Tue, 17 Jan 2023 12:35:39 +0100
Subject: [PATCH 13/14] Add Hugging Face integration to docs

---
 docs/source/index.rst                         |  6 +++--
 docs/source/intro/data.rst                    | 24 +++++++++++++++++++
 .../source/modules/skrl.utils.huggingface.rst | 20 ++++++++++++++++
 docs/source/modules/skrl.utils.utilities.rst  |  2 +-
 4 files changed, 49 insertions(+), 3 deletions(-)
 create mode 100644 docs/source/modules/skrl.utils.huggingface.rst

diff --git a/docs/source/index.rst b/docs/source/index.rst
index e732c164..7493c920 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -222,8 +222,9 @@ Utils
 Definition of helper functions and classes
 
    * :doc:`Utilities <modules/skrl.utils.utilities>`, e.g. setting the random seed
-   * :doc:`Model instantiators <modules/skrl.utils.model_instantiators>`
    * Memory and Tensorboard :doc:`file post-processing <modules/skrl.utils.postprocessing>`
+   * :doc:`Model instantiators <modules/skrl.utils.model_instantiators>`
+   * :doc:`Hugging Face integration <modules/skrl.utils.huggingface>`
    * :doc:`Isaac Gym utils <modules/skrl.utils.isaacgym_utils>`
    * :doc:`Omniverse Isaac Gym utils <modules/skrl.utils.omniverse_isaacgym_utils>`
 
@@ -233,7 +234,8 @@
    :hidden:
 
    modules/skrl.utils.utilities
-   modules/skrl.utils.model_instantiators
    modules/skrl.utils.postprocessing
+   modules/skrl.utils.model_instantiators
+   modules/skrl.utils.huggingface
    modules/skrl.utils.isaacgym_utils
    modules/skrl.utils.omniverse_isaacgym_utils
diff --git a/docs/source/intro/data.rst b/docs/source/intro/data.rst
index 1c740916..de53ff02 100644
--- a/docs/source/intro/data.rst
+++ b/docs/source/intro/data.rst
@@ -272,6 +272,30 @@ The following code snippets show how to load the checkpoints through the instant
             # Load the checkpoint
             policy.load("./runs/22-09-29_22-48-49-816281_DDPG/checkpoints/2500_policy.pt")
 
+In addition, it is possible to load, through the library utilities, trained agent checkpoints from the Hugging Face Hub (`huggingface.co/skrl <https://huggingface.co/skrl>`_). See the :doc:`Hugging Face integration <../modules/skrl.utils.huggingface>` for more information.
+
+.. tabs::
+
+    .. tab:: Agent (from Hugging Face Hub)
+
+        .. code-block:: python
+            :emphasize-lines: 2, 13-14
+
+            from skrl.agents.torch.ppo import PPO
+            from skrl.utils.huggingface import download_model_from_huggingface
+
+            # Instantiate the agent
+            agent = PPO(models=models,  # models dict
+                        memory=memory,  # memory instance, or None if not required
+                        cfg=agent_cfg,  # configuration dict (preprocessors, learning rate schedulers, etc.)
+                        observation_space=env.observation_space,
+                        action_space=env.action_space,
+                        device=env.device)
+
+            # Load the checkpoint from Hugging Face Hub
+            path = download_model_from_huggingface("skrl/OmniIsaacGymEnvs-Cartpole-PPO")
+            agent.load(path)
+
 Migrating external checkpoints
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
diff --git a/docs/source/modules/skrl.utils.huggingface.rst b/docs/source/modules/skrl.utils.huggingface.rst
new file mode 100644
index 00000000..37064d6b
--- /dev/null
+++ b/docs/source/modules/skrl.utils.huggingface.rst
@@ -0,0 +1,20 @@
+Hugging Face integration
+========================
+
+.. raw:: html
+
+    <br>
+
+Download model from Hugging Face Hub
+------------------------------------
+
+Several skrl-trained models (agent checkpoints) for different environments/tasks of Gym/Gymnasium, Isaac Gym, Omniverse Isaac Gym, etc. are available in the Hugging Face Hub.
+
+These models can be used as comparison benchmarks, for collecting environment transitions in memory (e.g. for offline reinforcement learning), or for pre-initializing agents that will perform similar tasks, among other uses.
+
+Visit the `skrl organization on the Hugging Face Hub <https://huggingface.co/skrl>`_ to access publicly available models!
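+
+A minimal usage sketch is shown below (it assumes an ``agent`` instance has already been created, as in the checkpoint loading examples; the repo id corresponds to one of the published skrl models):
+
+.. code-block:: python
+
+    from skrl.utils.huggingface import download_model_from_huggingface
+
+    # download the checkpoint (the file is cached locally and its path is returned)
+    path = download_model_from_huggingface("skrl/OmniIsaacGymEnvs-Cartpole-PPO")
+
+    # load the checkpoint into the instantiated agent
+    agent.load(path)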
+
+API
+"""
+
+.. autofunction:: skrl.utils.huggingface.download_model_from_huggingface
diff --git a/docs/source/modules/skrl.utils.utilities.rst b/docs/source/modules/skrl.utils.utilities.rst
index 6a095267..5db53aa5 100644
--- a/docs/source/modules/skrl.utils.utilities.rst
+++ b/docs/source/modules/skrl.utils.utilities.rst
@@ -2,7 +2,7 @@ Utilities
 =========
 
 .. contents:: Table of Contents
-   :depth: 2
+   :depth: 1
    :local:
    :backlinks: none
 

From 1dad7041a0feb095b5a0f1a33ede8fd38f378045 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?=
Date: Tue, 17 Jan 2023 12:36:14 +0100
Subject: [PATCH 14/14] Improve function docstring

---
 skrl/utils/huggingface.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/skrl/utils/huggingface.py b/skrl/utils/huggingface.py
index 98b280c5..39a64b08 100644
--- a/skrl/utils/huggingface.py
+++ b/skrl/utils/huggingface.py
@@ -9,8 +9,23 @@ def download_model_from_huggingface(repo_id: str, filename: str = "agent.pt") -> str:
     :param filename: The name of the model file in the repo (default: ``"agent.pt"``)
     :type filename: str, optional
 
+    :raises ImportError: The Hugging Face Hub package (huggingface-hub) is not installed
+    :raises huggingface_hub.utils._errors.HfHubHTTPError: Any HTTP error raised in Hugging Face Hub
+
     :return: Local path of file or if networking is off, last version of file cached on disk
     :rtype: str
+
+    Example::
+
+        # download trained agent from the skrl organization (https://huggingface.co/skrl)
+        >>> from skrl.utils.huggingface import download_model_from_huggingface
+        >>> download_model_from_huggingface("skrl/OmniIsaacGymEnvs-Cartpole-PPO")
+        '/home/user/.cache/huggingface/hub/models--skrl--OmniIsaacGymEnvs-Cartpole-PPO/snapshots/892e629903de6bf3ef102ae760406a5dd0f6f873/agent.pt'
+
+        # download model (e.g. "policy.pth") from another user/organization (e.g. "org/ddpg-Pendulum-v1")
+        >>> from skrl.utils.huggingface import download_model_from_huggingface
+        >>> download_model_from_huggingface("org/ddpg-Pendulum-v1", "policy.pth")
+        '/home/user/.cache/huggingface/hub/models--org--ddpg-Pendulum-v1/snapshots/b44ee96f93ff2e296156b002a2ca4646e197ba32/policy.pth'
     """
     try:
         import huggingface_hub