Merge pull request #112 from Toni-SM/develop
Develop
Toni-SM committed Aug 16, 2023
2 parents 57e7286 + 9a56ceb commit 8b875f8
Showing 23 changed files with 321 additions and 55 deletions.
38 changes: 31 additions & 7 deletions CHANGELOG.md
@@ -2,14 +2,34 @@

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [1.0.0] - 2023-08-16

Transition from pre-release versions (`1.0.0-rc.1` and `1.0.0-rc.2`) to a stable version.

This release also announces the publication of the **skrl** paper in the Journal of Machine Learning Research (JMLR): https://www.jmlr.org/papers/v24/23-0112.html

Summary of the most relevant features:
- JAX support
- New documentation theme and structure
- Multi-agent Reinforcement Learning (MARL)

## [1.0.0-rc.2] - 2023-08-11
### Added
- Get truncation from `time_outs` info in Isaac Gym, Isaac Orbit and Omniverse Isaac Gym environments
- Time-limit (truncation) bootstrapping in on-policy actor-critic agents (see the sketch after this list)
- Model instantiators `initial_log_std` parameter to set the log standard deviation's initial value
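
Time-limit bootstrapping compensates for episodes that end because of a time limit (truncation) rather than a real terminal state: for truncated steps, the discounted value of the next state is added back into the reward target so the critic does not treat the cut-off as a true end of the episode. A minimal, library-agnostic sketch of the idea (illustrative only — `value_fn`, `gamma` and the batch layout are assumptions, not the agents' actual implementation):

```python
import torch

def bootstrap_truncated_rewards(rewards, truncated, next_states, value_fn, gamma=0.99):
    """Add the discounted next-state value to the reward for time-limit (truncated) steps.

    rewards, truncated: tensors of shape (num_envs, 1)
    next_states: tensor of shape (num_envs, obs_dim)
    value_fn: callable mapping states to value estimates of shape (num_envs, 1)
    """
    with torch.no_grad():
        next_values = value_fn(next_states)
    # only steps cut off by the time limit receive the bootstrap term
    return rewards + gamma * next_values * truncated.float()
```

In the updated examples further below, this behavior is toggled per agent with `cfg["time_limit_bootstrap"] = True`.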

### Changed (breaking changes)
- Structure environment loaders and wrappers file hierarchy coherently
Import statements now follow this convention (a brief usage sketch appears after the list):
- Wrappers (e.g.):
- `from skrl.envs.wrappers.torch import wrap_env`
- `from skrl.envs.wrappers.jax import wrap_env`
- Loaders (e.g.):
- `from skrl.envs.loaders.torch import load_omniverse_isaacgym_env`
- `from skrl.envs.loaders.jax import load_omniverse_isaacgym_env`
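
For example, loading and wrapping an Omniverse Isaac Gym environment with the new layout looks roughly like this (a minimal sketch for the PyTorch backend; the task name is an illustrative assumption — swap `torch` for `jax` to use the JAX backend):

```python
from skrl.envs.loaders.torch import load_omniverse_isaacgym_env
from skrl.envs.wrappers.torch import wrap_env

# load the environment (the task name shown here is illustrative)
env = load_omniverse_isaacgym_env(task_name="Cartpole")

# wrap it to expose skrl's common environment API
env = wrap_env(env)
```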

### Changed
- Structure environment loaders and wrappers file hierarchy coherently [**breaking change**]
- Drop support for versions prior to PyTorch 1.9 (1.8.0 and 1.8.1)

## [1.0.0-rc.1] - 2023-07-25
@@ -66,7 +86,7 @@ and Omniverse Isaac Gym environments when they are loaded
### Added
- Support for Farama Gymnasium interface
- Wrapper for robosuite environments
- Weights & Biases integration (by @juhannc)
- Weights & Biases integration
- Set the running mode (training or evaluation) of the agents
- Allow clipping the gradient norm for DDPG, TD3 and SAC agents
- Initialize model biases
@@ -75,9 +95,11 @@ and Omniverse Isaac Gym environments when they are loaded
- Farama Shimmy and robosuite examples
- KUKA LBR iiwa real-world example

### Changed (breaking changes)
- Forward model inputs as a Python dictionary
- Return a Python dictionary with extra output values in model calls (a call sketch appears below)

### Changed
- Forward model inputs as a Python dictionary [**breaking change**]
- Returns a Python dictionary with extra output values in model calls [**breaking change**]
- Adopt the implementation of `terminated` and `truncated` over `done` for all environments
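
In practice, the dictionary-based model call convention looks roughly like this (a sketch only — it assumes `policy` is an already-instantiated skrl Gaussian policy model such as the `Policy` classes defined in the examples further below):

```python
import torch

# a dummy batch of observations (policy.num_observations is the flattened observation size)
states = torch.rand(64, policy.num_observations)

# model inputs are forwarded as a Python dictionary ...
actions, log_prob, outputs = policy.act({"states": states}, role="policy")

# ... and extra output values are returned as a Python dictionary as well
mean_actions = outputs.get("mean_actions")  # e.g. the distribution mean, if provided
```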

### Fixed
@@ -98,7 +120,7 @@ to allow storing samples in memories during evaluation
- Gaussian model mixin
- Support for creating shared models
- Parameter `role` to model methods (a condensed sketch of the shared-model / `role` pattern appears below)
- Wrapper compatibility with the new OpenAI Gym environment API (by @juhannc)
- Wrapper compatibility with the new OpenAI Gym environment API
- Internal library colored logger
- Migrate checkpoints/models from other RL libraries to skrl models/agents
- Configuration parameter `store_separately` to agent configuration dict
@@ -107,11 +129,13 @@ to allow storing samples in memories during evaluation
- Benchmark results for Isaac Gym and Omniverse Isaac Gym on the GitHub discussion page
- Franka Emika real-world example
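
A condensed sketch of the shared-model / `role` pattern these entries refer to (PyTorch backend; simplified from the documented recipe, with layer sizes chosen arbitrarily):

```python
import torch
import torch.nn as nn

from skrl.models.torch import DeterministicMixin, GaussianMixin, Model

# one network body shared between the stochastic policy ("policy" role)
# and the value function ("value" role)
class SharedModel(GaussianMixin, DeterministicMixin, Model):
    def __init__(self, observation_space, action_space, device=None):
        Model.__init__(self, observation_space, action_space, device)
        GaussianMixin.__init__(self, clip_actions=False, role="policy")
        DeterministicMixin.__init__(self, clip_actions=False, role="value")

        self.net = nn.Sequential(nn.Linear(self.num_observations, 64), nn.ELU())
        self.mean_layer = nn.Linear(64, self.num_actions)
        self.value_layer = nn.Linear(64, 1)
        self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))

    def act(self, inputs, role):
        # dispatch to the mixin that matches the requested role
        if role == "policy":
            return GaussianMixin.act(self, inputs, role)
        elif role == "value":
            return DeterministicMixin.act(self, inputs, role)

    def compute(self, inputs, role):
        features = self.net(inputs["states"])
        if role == "policy":
            return self.mean_layer(features), self.log_std_parameter, {}
        elif role == "value":
            return self.value_layer(features), {}
```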

### Changed (breaking changes)
- Models implementation as Python mixin

### Changed
- Models implementation as Python mixin [**breaking change**]
- Multivariate Gaussian model (`GaussianModel` until 0.7.0) to `MultivariateGaussianMixin`
- Trainer's `cfg` parameter position and default values
- Show training/evaluation display progress using `tqdm` (by @juhannc)
- Show training/evaluation display progress using `tqdm`
- Update Isaac Gym and Omniverse Isaac Gym examples

### Fixed
18 changes: 12 additions & 6 deletions README.md
@@ -10,7 +10,9 @@

<br>
<p align="center">
<a href="https://skrl.readthedocs.io">
<img width="300rem" src="https://raw.githubusercontent.com/Toni-SM/skrl/main/docs/source/_static/data/logo-light-mode.png">
</a>
</p>
<h2 align="center" style="border-bottom: 0 !important;">SKRL - Reinforcement Learning library</h2>
<br>
@@ -21,7 +23,7 @@

### Please, visit the documentation for usage details and examples

https://skrl.readthedocs.io/en/latest/
<strong>https://skrl.readthedocs.io</strong>

<br>

@@ -34,10 +36,14 @@ https://skrl.readthedocs.io/en/latest/
To cite this library in publications, please use the following reference:

```bibtex
@article{serrano2022skrl,
title={skrl: Modular and Flexible Library for Reinforcement Learning},
author={Serrano-Mu{\~n}oz, Antonio and Arana-Arexolaleiba, Nestor and Chrysostomou, Dimitrios and B{\o}gh, Simon},
journal={arXiv preprint arXiv:2202.03825},
year={2022}
@article{serrano2023skrl,
author = {Antonio Serrano-Muñoz and Dimitrios Chrysostomou and Simon Bøgh and Nestor Arana-Arexolaleiba},
title = {skrl: Modular and Flexible Library for Reinforcement Learning},
journal = {Journal of Machine Learning Research},
year = {2023},
volume = {24},
number = {254},
pages = {1--9},
url = {http://jmlr.org/papers/v24/23-0112.html}
}
```
3 changes: 3 additions & 0 deletions docs/source/api/envs.rst
@@ -69,3 +69,6 @@ In addition, you will be able to :doc:`wrap single-agent <envs/wrapping>` and :d
* - robosuite
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Shimmy
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\blacksquare`
19 changes: 19 additions & 0 deletions docs/source/api/envs/wrapping.rst
@@ -10,6 +10,7 @@ Wrapping (single-agent)
This library works with a common API to interact with the following RL environments:

* OpenAI `Gym <https://www.gymlibrary.dev>`_ / Farama `Gymnasium <https://gymnasium.farama.org/>`_ (single and vectorized environments)
* `Farama Shimmy <https://shimmy.farama.org/>`_
* `DeepMind <https://github.com/deepmind/dm_env>`_
* `robosuite <https://robosuite.ai/>`_
* `NVIDIA Isaac Gym <https://developer.nvidia.com/isaac-gym>`_ (preview 2, 3 and 4)
@@ -265,6 +266,24 @@ Usage
:start-after: [jax-start-gymnasium-vectorized]
:end-before: [jax-end-gymnasium-vectorized]

.. tab:: Shimmy

.. tabs::

.. group-tab:: |_4| |pytorch| |_4|

.. literalinclude:: ../../snippets/wrapping.py
:language: python
:start-after: [pytorch-start-shimmy]
:end-before: [pytorch-end-shimmy]

.. group-tab:: |_4| |jax| |_4|

.. literalinclude:: ../../snippets/wrapping.py
:language: python
:start-after: [jax-start-shimmy]
:end-before: [jax-end-shimmy]

.. tab:: DeepMind

.. tabs::
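The Shimmy tab added above pulls its code from `docs/source/snippets/wrapping.py`. As a rough idea of what wrapping a Shimmy-provided environment looks like (a sketch only — the environment id is an illustrative assumption and Shimmy must be installed):

```python
import gymnasium as gym

from skrl.envs.wrappers.torch import wrap_env

# Shimmy exposes external environments (dm_control, Atari, ...) through the Gymnasium API
env = gym.make("dm_control/cartpole-balance-v0")

# wrap the environment to expose skrl's common API
env = wrap_env(env)
```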
2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -16,7 +16,7 @@
if skrl.__version__ != "unknown":
release = version = skrl.__version__
else:
release = version = "1.0.0-rc.2"
release = version = "1.0.0"

master_doc = "index"

17 changes: 16 additions & 1 deletion docs/source/examples/isaacorbit/jax_ant_ppo.py
@@ -115,7 +115,8 @@ def __call__(self, inputs, role):
cfg["entropy_loss_scale"] = 0.0
cfg["value_loss_scale"] = 1.0
cfg["kl_threshold"] = 0
cfg["rewards_shaper"] = lambda rewards, timestep, timesteps: rewards * 0.01
cfg["rewards_shaper"] = lambda rewards, *args, **kwargs: rewards * 0.1
cfg["time_limit_bootstrap"] = True
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
@@ -139,3 +140,17 @@ def __call__(self, inputs, role):

# start training
trainer.train()


# # ---------------------------------------------------------
# # comment the code above: `trainer.train()`, and...
# # uncomment the following lines to evaluate a trained agent
# # ---------------------------------------------------------
# from skrl.utils.huggingface import download_model_from_huggingface

# # download the trained agent's checkpoint from Hugging Face Hub and load it
# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Ant-v0-PPO", filename="agent.pickle")
# agent.load(path)

# # start evaluation
# trainer.eval()
17 changes: 16 additions & 1 deletion docs/source/examples/isaacorbit/jax_cartpole_ppo.py
@@ -51,7 +51,7 @@ def __call__(self, inputs, role):
x = nn.elu(nn.Dense(32)(inputs["states"]))
x = nn.elu(nn.Dense(32)(x))
x = nn.Dense(self.num_actions)(x)
log_std = self.param("log_std", lambda _: jnp.zeros(self.num_actions))
log_std = self.param("log_std", lambda _: jnp.ones(self.num_actions))
return x, log_std, {}

class Value(DeterministicMixin, Model):
@@ -114,6 +114,7 @@ def __call__(self, inputs, role):
cfg["value_loss_scale"] = 2.0
cfg["kl_threshold"] = 0
cfg["rewards_shaper"] = None
cfg["time_limit_bootstrap"] = True
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
@@ -137,3 +138,17 @@ def __call__(self, inputs, role):

# start training
trainer.train()


# # ---------------------------------------------------------
# # comment the code above: `trainer.train()`, and...
# # uncomment the following lines to evaluate a trained agent
# # ---------------------------------------------------------
# from skrl.utils.huggingface import download_model_from_huggingface

# # download the trained agent's checkpoint from Hugging Face Hub and load it
# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Cartpole-v0-PPO", filename="agent.pickle")
# agent.load(path)

# # start evaluation
# trainer.eval()
25 changes: 20 additions & 5 deletions docs/source/examples/isaacorbit/jax_humanoid_ppo.py
@@ -39,7 +39,7 @@
# define models (stochastic and deterministic models) using mixins
class Policy(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device=None, clip_actions=False,
clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
clip_log_std=True, min_log_std=-5, max_log_std=2, reduction="sum", **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)

@@ -53,7 +53,7 @@ def __call__(self, inputs, role):
x = nn.elu(nn.Dense(100)(x))
x = nn.Dense(self.num_actions)(x)
log_std = self.param("log_std", lambda _: jnp.zeros(self.num_actions))
return x, log_std, {}
return nn.tanh(x), log_std, {}

class Value(DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs):
@@ -105,17 +105,18 @@ def __call__(self, inputs, role):
cfg["lambda"] = 0.95
cfg["learning_rate"] = 3e-4
cfg["learning_rate_scheduler"] = KLAdaptiveRL
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008}
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}
cfg["random_timesteps"] = 0
cfg["learning_starts"] = 0
cfg["grad_norm_clip"] = 1.0
cfg["ratio_clip"] = 0.2
cfg["value_clip"] = 0.2
cfg["clip_predicted_values"] = True
cfg["entropy_loss_scale"] = 0.0
cfg["value_loss_scale"] = 1.0
cfg["value_loss_scale"] = 4.0
cfg["kl_threshold"] = 0
cfg["rewards_shaper"] = lambda rewards, timestep, timesteps: rewards * 0.01
cfg["rewards_shaper"] = lambda rewards, *args, **kwargs: rewards * 0.01
cfg["time_limit_bootstrap"] = False
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
@@ -139,3 +140,17 @@ def __call__(self, inputs, role):

# start training
trainer.train()


# # ---------------------------------------------------------
# # comment the code above: `trainer.train()`, and...
# # uncomment the following lines to evaluate a trained agent
# # ---------------------------------------------------------
# from skrl.utils.huggingface import download_model_from_huggingface

# # download the trained agent's checkpoint from Hugging Face Hub and load it
# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Humanoid-v0-PPO", filename="agent.pickle")
# agent.load(path)

# # start evaluation
# trainer.eval()
18 changes: 16 additions & 2 deletions docs/source/examples/isaacorbit/jax_lift_franka_ppo.py
@@ -122,8 +122,8 @@ def __call__(self, inputs, role):
cfg["value_preprocessor"] = RunningStandardScaler
cfg["value_preprocessor_kwargs"] = {"size": 1, "device": device}
# logging to TensorBoard and write checkpoints (in timesteps)
cfg["experiment"]["write_interval"] = 800
cfg["experiment"]["checkpoint_interval"] = 8000
cfg["experiment"]["write_interval"] = 336
cfg["experiment"]["checkpoint_interval"] = 3360
cfg["experiment"]["directory"] = "runs/jax/Isaac-Lift-Franka-v0"

agent = PPO(models=models,
@@ -140,3 +141,17 @@ def __call__(self, inputs, role):

# start training
trainer.train()


# # ---------------------------------------------------------
# # comment the code above: `trainer.train()`, and...
# # uncomment the following lines to evaluate a trained agent
# # ---------------------------------------------------------
# from skrl.utils.huggingface import download_model_from_huggingface

# # download the trained agent's checkpoint from Hugging Face Hub and load it
# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Lift-Franka-v0-PPO", filename="agent.pickle")
# agent.load(path)

# # start evaluation
# trainer.eval()
21 changes: 18 additions & 3 deletions docs/source/examples/isaacorbit/jax_reach_franka_ppo.py
@@ -52,8 +52,8 @@ def __call__(self, inputs, role):
x = nn.elu(nn.Dense(128)(x))
x = nn.elu(nn.Dense(64)(x))
x = nn.Dense(self.num_actions)(x)
log_std = self.param("log_std", lambda _: jnp.zeros(self.num_actions))
return x, log_std, {}
log_std = self.param("log_std", lambda _: 0.5 * jnp.ones(self.num_actions))
return nn.tanh(x), log_std, {}

class Value(DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs):
@@ -105,7 +105,7 @@ def __call__(self, inputs, role):
cfg["lambda"] = 0.95
cfg["learning_rate"] = 3e-4
cfg["learning_rate_scheduler"] = KLAdaptiveRL
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008}
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}
cfg["random_timesteps"] = 0
cfg["learning_starts"] = 0
cfg["grad_norm_clip"] = 1.0
@@ -116,6 +116,7 @@ def __call__(self, inputs, role):
cfg["value_loss_scale"] = 2.0
cfg["kl_threshold"] = 0
cfg["rewards_shaper"] = None
cfg["time_limit_bootstrap"] = False
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
@@ -139,3 +140,17 @@ def __call__(self, inputs, role):

# start training
trainer.train()


# # ---------------------------------------------------------
# # comment the code above: `trainer.train()`, and...
# # uncomment the following lines to evaluate a trained agent
# # ---------------------------------------------------------
# from skrl.utils.huggingface import download_model_from_huggingface

# # download the trained agent's checkpoint from Hugging Face Hub and load it
# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Reach-Franka-v0-PPO", filename="agent.pickle")
# agent.load(path)

# # start evaluation
# trainer.eval()
3 changes: 2 additions & 1 deletion docs/source/examples/isaacorbit/jax_velocity_anymal_c_ppo.py
@@ -112,10 +112,11 @@ def __call__(self, inputs, role):
cfg["ratio_clip"] = 0.2
cfg["value_clip"] = 0.2
cfg["clip_predicted_values"] = True
cfg["entropy_loss_scale"] = 0.01
cfg["entropy_loss_scale"] = 0.0
cfg["value_loss_scale"] = 1.0
cfg["kl_threshold"] = 0
cfg["rewards_shaper"] = None
cfg["time_limit_bootstrap"] = False
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
3 changes: 2 additions & 1 deletion docs/source/examples/isaacorbit/torch_ant_ppo.py
@@ -89,7 +89,8 @@ def compute(self, inputs, role):
cfg["entropy_loss_scale"] = 0.0
cfg["value_loss_scale"] = 1.0
cfg["kl_threshold"] = 0
cfg["rewards_shaper"] = lambda rewards, timestep, timesteps: rewards * 0.01
cfg["rewards_shaper"] = lambda rewards, *args, **kwargs: rewards * 0.1
cfg["time_limit_bootstrap"] = True
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
(remaining changed files not shown)
