Merge pull request #49 from Toni-SM/develop
Develop
Toni-SM authored Jan 17, 2023
2 parents 41bf409 + 1dad704 commit fbb19d9
Showing 20 changed files with 400 additions and 25 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,14 @@

 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

+## [0.9.1] - 2023-01-17
+### Added
+- Utility for downloading models from Hugging Face Hub
+
+### Fixed
+- Initialization of agent components if they have not been defined
+- Manual trainer `train`/`eval` method default arguments
+
 ## [0.9.0] - 2023-01-13
 ### Added
 - Support for Farama Gymnasium interface
9 changes: 6 additions & 3 deletions README.md
@@ -1,7 +1,10 @@
[![license](https://img.shields.io/pypi/l/skrl)](https://github.com/Toni-SM/skrl)
[![docs](https://readthedocs.org/projects/skrl/badge/?version=latest)](https://skrl.readthedocs.io/en/latest/?badge=latest)
[![pypi](https://img.shields.io/pypi/v/skrl)](https://pypi.org/project/skrl)
<span>&nbsp;&nbsp;</span>
[<img src="https://img.shields.io/badge/%F0%9F%A4%97%20Models-Huggingface-F8D521">](https://huggingface.co/skrl)
![discussions](https://img.shields.io/github/discussions/Toni-SM/skrl)
<br>
[![license](https://img.shields.io/github/license/Toni-SM/skrl)](https://github.com/Toni-SM/skrl)
<span>&nbsp;&nbsp;&nbsp;&nbsp;</span>
[![docs](https://readthedocs.org/projects/skrl/badge/?version=latest)](https://skrl.readthedocs.io/en/latest/?badge=latest)
[![pytest](https://github.com/Toni-SM/skrl/actions/workflows/python-test.yml/badge.svg)](https://github.com/Toni-SM/skrl/actions/workflows/python-test.yml)
[![pre-commit](https://github.com/Toni-SM/skrl/actions/workflows/pre-commit.yml/badge.svg)](https://github.com/Toni-SM/skrl/actions/workflows/pre-commit.yml)

32 changes: 30 additions & 2 deletions docs/source/index.rst
@@ -1,6 +1,32 @@
 SKRL - Reinforcement Learning library (|version|)
 =================================================

+.. raw:: html
+
+    <a href="https://pypi.org/project/skrl">
+    <img alt="pypi" src="https://img.shields.io/pypi/v/skrl">
+    </a>
+    <a href="https://huggingface.co/skrl">
+    <img alt="huggingface" src="https://img.shields.io/badge/%F0%9F%A4%97%20Models-Huggingface-F8D521">
+    </a>
+    <a href="https://github.com/Toni-SM/skrl/discussions">
+    <img alt="discussions" src="https://img.shields.io/github/discussions/Toni-SM/skrl">
+    </a>
+    <br>
+    <a href="https://github.com/Toni-SM/skrl/blob/main/LICENSE">
+    <img alt="license" src="https://img.shields.io/github/license/Toni-SM/skrl">
+    </a>
+    &nbsp;&nbsp;&nbsp;&nbsp;
+    <a href="https://skrl.readthedocs.io">
+    <img alt="docs" src="https://readthedocs.org/projects/skrl/badge/?version=latest">
+    </a>
+    <a href="https://github.com/Toni-SM/skrl/actions/workflows/python-test.yml">
+    <img alt="pytest" src="https://github.com/Toni-SM/skrl/actions/workflows/python-test.yml/badge.svg">
+    </a>
+    <a href="https://github.com/Toni-SM/skrl/actions/workflows/pre-commit.yml">
+    <img alt="pre-commit" src="https://github.com/Toni-SM/skrl/actions/workflows/pre-commit.yml/badge.svg">
+    </a>
+
**skrl** is an open-source modular library for Reinforcement Learning written in Python (using `PyTorch <https://pytorch.org/>`_) and designed with a focus on readability, simplicity, and transparency of algorithm implementation. In addition to supporting the OpenAI `Gym <https://www.gymlibrary.dev>`_ / Farama `Gymnasium <https://gymnasium.farama.org/>`_, `DeepMind <https://github.com/deepmind/dm_env>`_ and other environment interfaces, it allows loading and configuring `NVIDIA Isaac Gym <https://developer.nvidia.com/isaac-gym>`_ and `NVIDIA Omniverse Isaac Gym <https://docs.omniverse.nvidia.com/app_isaacsim/app_isaacsim/tutorial_gym_isaac_gym.html>`_ environments, enabling agents' simultaneous training by scopes (subsets of environments among all available environments), which may or may not share resources, in the same run

**Main features:**
@@ -196,8 +222,9 @@ Utils
 Definition of helper functions and classes

 * :doc:`Utilities <modules/skrl.utils.utilities>`, e.g. setting the random seed
-* :doc:`Model instantiators <modules/skrl.utils.model_instantiators>`
 * Memory and Tensorboard :doc:`file post-processing <modules/skrl.utils.postprocessing>`
+* :doc:`Model instantiators <modules/skrl.utils.model_instantiators>`
+* :doc:`Hugging Face integration <modules/skrl.utils.huggingface>`
 * :doc:`Isaac Gym utils <modules/skrl.utils.isaacgym_utils>`
 * :doc:`Omniverse Isaac Gym utils <modules/skrl.utils.omniverse_isaacgym_utils>`

@@ -207,7 +234,8 @@ Utils
 :hidden:

 modules/skrl.utils.utilities
-modules/skrl.utils.model_instantiators
 modules/skrl.utils.postprocessing
+modules/skrl.utils.model_instantiators
+modules/skrl.utils.huggingface
 modules/skrl.utils.isaacgym_utils
 modules/skrl.utils.omniverse_isaacgym_utils
24 changes: 24 additions & 0 deletions docs/source/intro/data.rst
@@ -272,6 +272,30 @@ The following code snippets show how to load the checkpoints through the instant
             # Load the checkpoint
             policy.load("./runs/22-09-29_22-48-49-816281_DDPG/checkpoints/2500_policy.pt")

+In addition, it is possible to load, through the library utilities, trained agent checkpoints from the Hugging Face Hub (`huggingface.co/skrl <https://huggingface.co/skrl>`_). See the :doc:`Hugging Face integration <../modules/skrl.utils.huggingface>` for more information.
+
+.. tabs::
+
+    .. tab:: Agent (from Hugging Face Hub)
+
+        .. code-block:: python
+            :emphasize-lines: 2, 13-14
+
+            from skrl.agents.torch.ppo import PPO
+            from skrl.utils.huggingface import download_model_from_huggingface
+
+            # Instantiate the agent
+            agent = PPO(models=models, # models dict
+                        memory=memory, # memory instance, or None if not required
+                        cfg=agent_cfg, # configuration dict (preprocessors, learning rate schedulers, etc.)
+                        observation_space=env.observation_space,
+                        action_space=env.action_space,
+                        device=env.device)
+
+            # Load the checkpoint from Hugging Face Hub
+            path = download_model_from_huggingface("skrl/OmniIsaacGymEnvs-Cartpole-PPO")
+            agent.load(path)
+
 Migrating external checkpoints
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

20 changes: 20 additions & 0 deletions docs/source/modules/skrl.utils.huggingface.rst
@@ -0,0 +1,20 @@
+Hugging Face integration
+========================
+
+.. raw:: html
+
+    <hr>
+
+Download model from Hugging Face Hub
+------------------------------------
+
+Several skrl-trained models (agent checkpoints) for different environments/tasks of Gym/Gymnasium, Isaac Gym, Omniverse Isaac Gym, etc. are available in the Hugging Face Hub
+
+These models can be used as comparison benchmarks, for collecting environment transitions in memory (for offline reinforcement learning, e.g.) or for pre-initialization of agents for performing similar tasks, among others
+
+Visit the `skrl organization on the Hugging Face Hub <https://huggingface.co/skrl>`_ to access publicly available models!
+
+API
+"""
+
+.. autofunction:: skrl.utils.huggingface.download_model_from_huggingface
2 changes: 1 addition & 1 deletion docs/source/modules/skrl.utils.utilities.rst
@@ -2,7 +2,7 @@ Utilities
 =========

 .. contents:: Table of Contents
-    :depth: 2
+    :depth: 1
     :local:
     :backlinks: none

11 changes: 6 additions & 5 deletions skrl/agents/torch/amp/amp.py
@@ -239,12 +239,13 @@ def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
"log_prob", "values", "returns", "advantages", "amp_states", "next_values"]

         # create tensors for motion dataset and reply buffer
-        self.motion_dataset.create_tensor(name="states", size=self.amp_observation_space, dtype=torch.float32)
-        self.reply_buffer.create_tensor(name="states", size=self.amp_observation_space, dtype=torch.float32)
+        if self.motion_dataset is not None:
+            self.motion_dataset.create_tensor(name="states", size=self.amp_observation_space, dtype=torch.float32)
+            self.reply_buffer.create_tensor(name="states", size=self.amp_observation_space, dtype=torch.float32)

-        # initialize motion dataset
-        for _ in range(math.ceil(self.motion_dataset.memory_size / self._amp_batch_size)):
-            self.motion_dataset.add_samples(states=self.collect_reference_motions(self._amp_batch_size))
+            # initialize motion dataset
+            for _ in range(math.ceil(self.motion_dataset.memory_size / self._amp_batch_size)):
+                self.motion_dataset.add_samples(states=self.collect_reference_motions(self._amp_batch_size))

         # create temporary variables needed for storage and computation
         self._current_log_prob = None
5 changes: 3 additions & 2 deletions skrl/agents/torch/ddpg/ddpg.py
@@ -193,8 +193,9 @@ def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
self._rnn_initial_states["policy"].append(torch.zeros(size, dtype=torch.float32, device=self.device))

         # clip noise bounds
-        self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device)
-        self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device)
+        if self.action_space is not None:
+            self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device)
+            self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device)

         # backward compatibility: torch < 1.9 clamp method does not support tensors
         self._backward_compatibility = tuple(map(int, (torch.__version__.split(".")[:2]))) < (1, 9)
2 changes: 1 addition & 1 deletion skrl/agents/torch/dqn/ddqn.py
@@ -139,7 +139,7 @@ def __init__(self,
             if self._learning_rate_scheduler is not None:
                 self.scheduler = self._learning_rate_scheduler(self.optimizer, **self.cfg["learning_rate_scheduler_kwargs"])

-        self.checkpoint_modules["optimizer"] = self.optimizer
+            self.checkpoint_modules["optimizer"] = self.optimizer

         # set up preprocessors
         if self._state_preprocessor:
5 changes: 3 additions & 2 deletions skrl/agents/torch/td3/td3.py
@@ -211,8 +211,9 @@ def init(self, trainer_cfg: Optional[Dict[str, Any]] = None) -> None:
self._rnn_initial_states["policy"].append(torch.zeros(size, dtype=torch.float32, device=self.device))

         # clip noise bounds
-        self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device)
-        self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device)
+        if self.action_space is not None:
+            self.clip_actions_min = torch.tensor(self.action_space.low, device=self.device)
+            self.clip_actions_max = torch.tensor(self.action_space.high, device=self.device)

         # backward compatibility: torch < 1.9 clamp method does not support tensors
         self._backward_compatibility = tuple(map(int, (torch.__version__.split(".")[:2]))) < (1, 9)
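The AMP, DDPG and TD3 hunks share one guard: a component is only set up in `init` when the object it depends on was actually provided. The following is a minimal, self-contained sketch of that pattern, not the skrl implementation. The `Component` class and the `SimpleNamespace` stand-in for an action space are hypothetical, and pre-setting the bounds to `None` is an assumption of this sketch that the diff itself does not show.

```python
from types import SimpleNamespace
from typing import Any, List, Optional


class Component:
    """Hypothetical stand-in for an agent whose action space may be undefined."""

    def __init__(self, action_space: Optional[Any] = None) -> None:
        self.action_space = action_space
        # assumption of this sketch: bounds default to None until init() runs
        self.clip_actions_min: Optional[List[float]] = None
        self.clip_actions_max: Optional[List[float]] = None

    def init(self) -> None:
        # only derive the clipping bounds when an action space was provided
        if self.action_space is not None:
            self.clip_actions_min = list(self.action_space.low)
            self.clip_actions_max = list(self.action_space.high)


# no action space: init() completes without touching the bounds
undefined = Component()
undefined.init()

# box-like action space (hypothetical object exposing `low` and `high`)
bounded = Component(SimpleNamespace(low=[-1.0, -1.0], high=[1.0, 1.0]))
bounded.init()
```

Without the guard, the first case would fail with an `AttributeError` on `None.low`, which appears to be the situation the "Initialization of agent components if they have not been defined" changelog entry addresses.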
17 changes: 13 additions & 4 deletions skrl/trainers/torch/manual.py
@@ -50,11 +50,12 @@ def __init__(self,
         else:
             self.agents.init(trainer_cfg=self.cfg)

+        self._timestep = 0
         self._progress = None

         self.states = None

-    def train(self, timestep: int, timesteps: Optional[int] = None) -> None:
+    def train(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> None:
         """Execute a training iteration

         This method executes the following steps once:
@@ -68,11 +69,15 @@ def train(self, timestep: int, timesteps: Optional[int] = None) -> None:
             - Reset environments

         :param timestep: Current timestep
-        :type timestep: int
+        :type timestep: int, optional (default: None).
+                         If None, the current timestep will be carried by an internal variable
         :param timesteps: Total number of timesteps (default: None).
                           If None, the total number of timesteps is obtained from the trainer's config
         :type timesteps: int, optional
         """
+        if timestep is None:
+            self._timestep += 1
+            timestep = self._timestep
         timesteps = self.timesteps if timesteps is None else timesteps

         if self._progress is None:
@@ -157,7 +162,7 @@ def train(self, timestep: int, timesteps: Optional[int] = None) -> None:
                 self.states.copy_(next_states)


-    def eval(self, timestep: int, timesteps: Optional[int] = None) -> None:
+    def eval(self, timestep: Optional[int] = None, timesteps: Optional[int] = None) -> None:
         """Evaluate the agents sequentially

         This method executes the following steps in loop:
@@ -168,11 +173,15 @@ def eval(self, timestep: int, timesteps: Optional[int] = None) -> None:
             - Reset environments

         :param timestep: Current timestep
-        :type timestep: int
+        :type timestep: int, optional (default: None).
+                         If None, the current timestep will be carried by an internal variable
         :param timesteps: Total number of timesteps (default: None).
                           If None, the total number of timesteps is obtained from the trainer's config
         :type timesteps: int, optional
         """
+        if timestep is None:
+            self._timestep += 1
+            timestep = self._timestep
         timesteps = self.timesteps if timesteps is None else timesteps

         if self._progress is None:
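The `train`/`eval` change in `manual.py` makes `timestep` optional: when it is omitted, an internal counter is incremented and used instead. The sketch below isolates that bookkeeping so its behavior is easy to check. `StepCounter` and `resolve` are hypothetical names introduced for illustration; only the three-line `if timestep is None` branch mirrors the diff.

```python
from typing import Optional


class StepCounter:
    """Hypothetical stand-in for the trainer's internal timestep bookkeeping."""

    def __init__(self) -> None:
        self._timestep = 0

    def resolve(self, timestep: Optional[int] = None) -> int:
        # same branch as in the diff: fall back to the internal counter
        if timestep is None:
            self._timestep += 1
            timestep = self._timestep
        return timestep


counter = StepCounter()
implicit = [counter.resolve() for _ in range(3)]  # counter is advanced on each call
explicit = counter.resolve(timestep=10)           # explicit value is used as given
after_explicit = counter.resolve()                # counter continues from its own state
```

Two consequences of this design are worth noting: the counter is incremented before use, so the first implicit timestep is 1 rather than 0, and an explicitly passed `timestep` does not update the internal counter. In the trainer, `train` and `eval` also share the same `_timestep` attribute.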
45 changes: 45 additions & 0 deletions skrl/utils/huggingface.py
@@ -0,0 +1,45 @@
+from skrl import logger, __version__
+
+
+def download_model_from_huggingface(repo_id: str, filename: str = "agent.pt") -> str:
+    """Download a model from Hugging Face Hub
+
+    :param repo_id: Hugging Face user or organization name and a repo name separated by a ``/``
+    :type repo_id: str
+    :param filename: The name of the model file in the repo (default: ``"agent.pt"``)
+    :type filename: str, optional
+
+    :raises ImportError: The Hugging Face Hub package (huggingface-hub) is not installed
+    :raises huggingface_hub.utils._errors.HfHubHTTPError: Any HTTP error raised in Hugging Face Hub
+
+    :return: Local path of file or if networking is off, last version of file cached on disk
+    :rtype: str
+
+    Example::
+
+        # download trained agent from the skrl organization (https://huggingface.co/skrl)
+        >>> from skrl.utils.huggingface import download_model_from_huggingface
+        >>> download_model_from_huggingface("skrl/OmniIsaacGymEnvs-Cartpole-PPO")
+        '/home/user/.cache/huggingface/hub/models--skrl--OmniIsaacGymEnvs-Cartpole-PPO/snapshots/892e629903de6bf3ef102ae760406a5dd0f6f873/agent.pt'
+
+        # download model (e.g. "policy.pth") from another user/organization (e.g. "org/ddpg-Pendulum-v1")
+        >>> from skrl.utils.huggingface import download_model_from_huggingface
+        >>> download_model_from_huggingface("org/ddpg-Pendulum-v1", "policy.pth")
+        '/home/user/.cache/huggingface/hub/models--org--ddpg-Pendulum-v1/snapshots/b44ee96f93ff2e296156b002a2ca4646e197ba32/policy.pth'
+    """
+    try:
+        import huggingface_hub
+    except ImportError:
+        logger.error("Hugging Face Hub package is not installed. Use 'pip install huggingface-hub' to install it")
+        huggingface_hub = None
+
+    if huggingface_hub is None:
+        raise ImportError("Hugging Face Hub package is not installed. Use 'pip install huggingface-hub' to install it")
+
+    # download and cache the model from Hugging Face Hub
+    downloaded_model_file = huggingface_hub.hf_hub_download(repo_id=repo_id,
+                                                            filename=filename,
+                                                            library_name="skrl",
+                                                            library_version=__version__)
+
+    return downloaded_model_file
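`download_model_from_huggingface` imports `huggingface_hub` inside the function, so skrl only needs that package when a download is actually requested. The same optional-dependency idea can be sketched generically as shown here. `require_optional` is a hypothetical helper, not part of skrl or `huggingface_hub`; it relies only on the standard library's `importlib.import_module`, and the module name used to trigger the failure is assumed not to be installed.

```python
import importlib
from types import ModuleType


def require_optional(module_name: str, pip_name: str) -> ModuleType:
    """Import an optional dependency or raise an ImportError with an install hint."""
    try:
        return importlib.import_module(module_name)
    except ImportError as error:
        # re-raise with an actionable message, keeping the original cause
        raise ImportError(
            f"'{module_name}' is not installed. Use 'pip install {pip_name}' to install it"
        ) from error


# available module: the imported module object is returned
json_module = require_optional("json", "json")

# missing module (name assumed not to exist): the hint is part of the error message
try:
    require_optional("surely_not_an_installed_module_0", "some-package")
except ImportError as error:
    message = str(error)
```

Compared with the committed code, this variant raises directly from the `except` branch instead of setting the module to `None` and checking it afterwards; it does not log the error.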
2 changes: 1 addition & 1 deletion skrl/version.txt
@@ -1 +1 @@
-0.9.0
+0.9.1
59 changes: 59 additions & 0 deletions tests/test_agents.py
@@ -0,0 +1,59 @@
+import pytest
+import warnings
+import hypothesis
+import hypothesis.strategies as st
+
+import torch
+
+from skrl.agents.torch import Agent
+
+from skrl.agents.torch.a2c import A2C
+from skrl.agents.torch.amp import AMP
+from skrl.agents.torch.cem import CEM
+from skrl.agents.torch.ddpg import DDPG
+from skrl.agents.torch.dqn import DQN, DDQN
+from skrl.agents.torch.ppo import PPO
+from skrl.agents.torch.q_learning import Q_LEARNING
+from skrl.agents.torch.sac import SAC
+from skrl.agents.torch.sarsa import SARSA
+from skrl.agents.torch.td3 import TD3
+from skrl.agents.torch.trpo import TRPO
+
+from .utils import DummyModel
+
+
+@pytest.fixture
+def classes_and_kwargs():
+    return [(A2C, {"models": {"policy": DummyModel()}}),
+            (AMP, {"models": {"policy": DummyModel()}}),
+            (CEM, {"models": {"policy": DummyModel()}}),
+            (DDPG, {"models": {"policy": DummyModel()}}),
+            (DQN, {"models": {"policy": DummyModel()}}),
+            (DDQN, {"models": {"policy": DummyModel()}}),
+            (PPO, {"models": {"policy": DummyModel()}}),
+            (Q_LEARNING, {"models": {"policy": DummyModel()}}),
+            (SAC, {"models": {"policy": DummyModel()}}),
+            (SARSA, {"models": {"policy": DummyModel()}}),
+            (TD3, {"models": {"policy": DummyModel()}}),
+            (TRPO, {"models": {"policy": DummyModel()}})]
+
+
+def test_agent(capsys, classes_and_kwargs):
+    for klass, kwargs in classes_and_kwargs:
+        cfg = {"learning_starts": 1,
+               "experiment": {"write_interval": 0}}
+        agent: Agent = klass(cfg=cfg, **kwargs)
+
+        agent.init()
+        agent.pre_interaction(timestep=0, timesteps=1)
+        # agent.act(None, timestep=0, timestesps=1)
+        agent.record_transition(states=torch.tensor([]),
+                                actions=torch.tensor([]),
+                                rewards=torch.tensor([]),
+                                next_states=torch.tensor([]),
+                                terminated=torch.tensor([]),
+                                truncated=torch.tensor([]),
+                                infos={},
+                                timestep=0,
+                                timesteps=1)
+        agent.post_interaction(timestep=0, timesteps=1)
36 changes: 36 additions & 0 deletions tests/test_envs.py
@@ -0,0 +1,36 @@
+import pytest
+import warnings
+import hypothesis
+import hypothesis.strategies as st
+
+import torch
+
+from skrl.envs.torch import Wrapper
+
+from skrl.envs.torch import wrap_env
+
+from .utils import DummyEnv
+
+
+@pytest.fixture
+def classes_and_kwargs():
+    return []
+
+
+@pytest.mark.parametrize("wrapper", ["gym", "gymnasium", "dm", "robosuite", \
+                                     "isaacgym-preview2", "isaacgym-preview3", "isaacgym-preview4", "omniverse-isaacgym"])
+def test_wrap_env(capsys, classes_and_kwargs, wrapper):
+    env = DummyEnv(num_envs=1)
+
+    try:
+        env: Wrapper = wrap_env(env=env, wrapper=wrapper)
+    except ValueError as e:
+        warnings.warn(f"{e}. This test will be skipped for '{wrapper}'")
+    except ModuleNotFoundError as e:
+        warnings.warn(f"{e}. The '{wrapper}' wrapper module is not found. This test will be skipped")
+
+    env.observation_space
+    env.action_space
+    env.state_space
+    env.num_envs
+    env.device