diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1d77eb10..745e2b14 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,14 +2,34 @@
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
+## [1.0.0] - 2023-08-16
+
+Transition from pre-release versions (`1.0.0-rc.1` and`1.0.0-rc.2`) to a stable version.
+
+This release also announces the publication of the **skrl** paper in the Journal of Machine Learning Research (JMLR): https://www.jmlr.org/papers/v24/23-0112.html
+
+Summary of the most relevant features:
+- JAX support
+- New documentation theme and structure
+- Multi-agent Reinforcement Learning (MARL)
+
## [1.0.0-rc.2] - 2023-08-11
### Added
- Get truncation from `time_outs` info in Isaac Gym, Isaac Orbit and Omniverse Isaac Gym environments
- Time-limit (truncation) boostrapping in on-policy actor-critic agents
- Model instantiators `initial_log_std` parameter to set the log standard deviation's initial value
+### Changed (breaking changes)
+- Structure environment loaders and wrappers file hierarchy coherently
+ Import statements now follow the next convention:
+ - Wrappers (e.g.):
+ - `from skrl.envs.wrappers.torch import wrap_env`
+ - `from skrl.envs.wrappers.jax import wrap_env`
+ - Loaders (e.g.):
+ - `from skrl.envs.loaders.torch import load_omniverse_isaacgym_env`
+ - `from skrl.envs.loaders.jax import load_omniverse_isaacgym_env`
+
### Changed
-- Structure environment loaders and wrappers file hierarchy coherently [**breaking change**]
- Drop support for versions prior to PyTorch 1.9 (1.8.0 and 1.8.1)
## [1.0.0-rc.1] - 2023-07-25
@@ -66,7 +86,7 @@ and Omniverse Isaac Gym environments when they are loaded
### Added
- Support for Farama Gymnasium interface
- Wrapper for robosuite environments
-- Weights & Biases integration (by @juhannc)
+- Weights & Biases integration
- Set the running mode (training or evaluation) of the agents
- Allow clipping the gradient norm for DDPG, TD3 and SAC agents
- Initialize model biases
@@ -75,9 +95,11 @@ and Omniverse Isaac Gym environments when they are loaded
- Farama Shimmy and robosuite examples
- KUKA LBR iiwa real-world example
+### Changed (breaking changes)
+- Forward model inputs as a Python dictionary
+- Returns a Python dictionary with extra output values in model calls
+
### Changed
-- Forward model inputs as a Python dictionary [**breaking change**]
-- Returns a Python dictionary with extra output values in model calls [**breaking change**]
- Adopt the implementation of `terminated` and `truncated` over `done` for all environments
### Fixed
@@ -98,7 +120,7 @@ to allow storing samples in memories during evaluation
- Gaussian model mixin
- Support for creating shared models
- Parameter `role` to model methods
-- Wrapper compatibility with the new OpenAI Gym environment API (by @juhannc)
+- Wrapper compatibility with the new OpenAI Gym environment API
- Internal library colored logger
- Migrate checkpoints/models from other RL libraries to skrl models/agents
- Configuration parameter `store_separately` to agent configuration dict
@@ -107,11 +129,13 @@ to allow storing samples in memories during evaluation
- Benchmark results for Isaac Gym and Omniverse Isaac Gym on the GitHub discussion page
- Franka Emika real-world example
+### Changed (breaking changes)
+- Models implementation as Python mixin
+
### Changed
-- Models implementation as Python mixin [**breaking change**]
- Multivariate Gaussian model (`GaussianModel` until 0.7.0) to `MultivariateGaussianMixin`
- Trainer's `cfg` parameter position and default values
-- Show training/evaluation display progress using `tqdm` (by @juhannc)
+- Show training/evaluation display progress using `tqdm`
- Update Isaac Gym and Omniverse Isaac Gym examples
### Fixed
diff --git a/README.md b/README.md
index 49e63854..65b7b391 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,9 @@
+
+
SKRL - Reinforcement Learning library
@@ -21,7 +23,7 @@
### Please, visit the documentation for usage details and examples
-https://skrl.readthedocs.io/en/latest/
+https://skrl.readthedocs.io
@@ -34,10 +36,14 @@ https://skrl.readthedocs.io/en/latest/
To cite this library in publications, please use the following reference:
```bibtex
-@article{serrano2022skrl,
- title={skrl: Modular and Flexible Library for Reinforcement Learning},
- author={Serrano-Mu{\~n}oz, Antonio and Arana-Arexolaleiba, Nestor and Chrysostomou, Dimitrios and B{\o}gh, Simon},
- journal={arXiv preprint arXiv:2202.03825},
- year={2022}
+@article{serrano2023skrl,
+ author = {Antonio Serrano-Muñoz and Dimitrios Chrysostomou and Simon Bøgh and Nestor Arana-Arexolaleiba},
+ title = {skrl: Modular and Flexible Library for Reinforcement Learning},
+ journal = {Journal of Machine Learning Research},
+ year = {2023},
+ volume = {24},
+ number = {254},
+ pages = {1--9},
+ url = {http://jmlr.org/papers/v24/23-0112.html}
}
```
diff --git a/docs/source/api/envs.rst b/docs/source/api/envs.rst
index 55e23c37..3b3091da 100644
--- a/docs/source/api/envs.rst
+++ b/docs/source/api/envs.rst
@@ -69,3 +69,6 @@ In addition, you will be able to :doc:`wrap single-agent ` and :d
* - robosuite
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Shimmy
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\blacksquare`
diff --git a/docs/source/api/envs/wrapping.rst b/docs/source/api/envs/wrapping.rst
index a4454cd3..a725a7e9 100644
--- a/docs/source/api/envs/wrapping.rst
+++ b/docs/source/api/envs/wrapping.rst
@@ -10,6 +10,7 @@ Wrapping (single-agent)
This library works with a common API to interact with the following RL environments:
* OpenAI `Gym `_ / Farama `Gymnasium `_ (single and vectorized environments)
+* `Farama Shimmy `_
* `DeepMind `_
* `robosuite `_
* `NVIDIA Isaac Gym `_ (preview 2, 3 and 4)
@@ -265,6 +266,24 @@ Usage
:start-after: [jax-start-gymnasium-vectorized]
:end-before: [jax-end-gymnasium-vectorized]
+ .. tab:: Shimmy
+
+ .. tabs::
+
+ .. group-tab:: |_4| |pytorch| |_4|
+
+ .. literalinclude:: ../../snippets/wrapping.py
+ :language: python
+ :start-after: [pytorch-start-shimmy]
+ :end-before: [pytorch-end-shimmy]
+
+ .. group-tab:: |_4| |jax| |_4|
+
+ .. literalinclude:: ../../snippets/wrapping.py
+ :language: python
+ :start-after: [jax-start-shimmy]
+ :end-before: [jax-end-shimmy]
+
.. tab:: DeepMind
.. tabs::
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 8882fe6d..42f37f8a 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -16,7 +16,7 @@
if skrl.__version__ != "unknown":
release = version = skrl.__version__
else:
- release = version = "1.0.0-rc.2"
+ release = version = "1.0.0"
master_doc = "index"
diff --git a/docs/source/examples/isaacorbit/jax_ant_ppo.py b/docs/source/examples/isaacorbit/jax_ant_ppo.py
index 5ad99973..ba0cc9f3 100644
--- a/docs/source/examples/isaacorbit/jax_ant_ppo.py
+++ b/docs/source/examples/isaacorbit/jax_ant_ppo.py
@@ -115,7 +115,8 @@ def __call__(self, inputs, role):
cfg["entropy_loss_scale"] = 0.0
cfg["value_loss_scale"] = 1.0
cfg["kl_threshold"] = 0
-cfg["rewards_shaper"] = lambda rewards, timestep, timesteps: rewards * 0.01
+cfg["rewards_shaper"] = lambda rewards, *args, **kwargs: rewards * 0.1
+cfg["time_limit_bootstrap"] = True
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
@@ -139,3 +140,17 @@ def __call__(self, inputs, role):
# start training
trainer.train()
+
+
+# # ---------------------------------------------------------
+# # comment the code above: `trainer.train()`, and...
+# # uncomment the following lines to evaluate a trained agent
+# # ---------------------------------------------------------
+# from skrl.utils.huggingface import download_model_from_huggingface
+
+# # download the trained agent's checkpoint from Hugging Face Hub and load it
+# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Ant-v0-PPO", filename="agent.pickle")
+# agent.load(path)
+
+# # start evaluation
+# trainer.eval()
diff --git a/docs/source/examples/isaacorbit/jax_cartpole_ppo.py b/docs/source/examples/isaacorbit/jax_cartpole_ppo.py
index fedfb602..d6d5ef1b 100644
--- a/docs/source/examples/isaacorbit/jax_cartpole_ppo.py
+++ b/docs/source/examples/isaacorbit/jax_cartpole_ppo.py
@@ -51,7 +51,7 @@ def __call__(self, inputs, role):
x = nn.elu(nn.Dense(32)(inputs["states"]))
x = nn.elu(nn.Dense(32)(x))
x = nn.Dense(self.num_actions)(x)
- log_std = self.param("log_std", lambda _: jnp.zeros(self.num_actions))
+ log_std = self.param("log_std", lambda _: jnp.ones(self.num_actions))
return x, log_std, {}
class Value(DeterministicMixin, Model):
@@ -114,6 +114,7 @@ def __call__(self, inputs, role):
cfg["value_loss_scale"] = 2.0
cfg["kl_threshold"] = 0
cfg["rewards_shaper"] = None
+cfg["time_limit_bootstrap"] = True
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
@@ -137,3 +138,17 @@ def __call__(self, inputs, role):
# start training
trainer.train()
+
+
+# # ---------------------------------------------------------
+# # comment the code above: `trainer.train()`, and...
+# # uncomment the following lines to evaluate a trained agent
+# # ---------------------------------------------------------
+# from skrl.utils.huggingface import download_model_from_huggingface
+
+# # download the trained agent's checkpoint from Hugging Face Hub and load it
+# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Cartpole-v0-PPO", filename="agent.pickle")
+# agent.load(path)
+
+# # start evaluation
+# trainer.eval()
diff --git a/docs/source/examples/isaacorbit/jax_humanoid_ppo.py b/docs/source/examples/isaacorbit/jax_humanoid_ppo.py
index 6405b970..dafdb04b 100644
--- a/docs/source/examples/isaacorbit/jax_humanoid_ppo.py
+++ b/docs/source/examples/isaacorbit/jax_humanoid_ppo.py
@@ -39,7 +39,7 @@
# define models (stochastic and deterministic models) using mixins
class Policy(GaussianMixin, Model):
def __init__(self, observation_space, action_space, device=None, clip_actions=False,
- clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum", **kwargs):
+ clip_log_std=True, min_log_std=-5, max_log_std=2, reduction="sum", **kwargs):
Model.__init__(self, observation_space, action_space, device, **kwargs)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
@@ -53,7 +53,7 @@ def __call__(self, inputs, role):
x = nn.elu(nn.Dense(100)(x))
x = nn.Dense(self.num_actions)(x)
log_std = self.param("log_std", lambda _: jnp.zeros(self.num_actions))
- return x, log_std, {}
+ return nn.tanh(x), log_std, {}
class Value(DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs):
@@ -105,7 +105,7 @@ def __call__(self, inputs, role):
cfg["lambda"] = 0.95
cfg["learning_rate"] = 3e-4
cfg["learning_rate_scheduler"] = KLAdaptiveRL
-cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008}
+cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}
cfg["random_timesteps"] = 0
cfg["learning_starts"] = 0
cfg["grad_norm_clip"] = 1.0
@@ -113,9 +113,10 @@ def __call__(self, inputs, role):
cfg["value_clip"] = 0.2
cfg["clip_predicted_values"] = True
cfg["entropy_loss_scale"] = 0.0
-cfg["value_loss_scale"] = 1.0
+cfg["value_loss_scale"] = 4.0
cfg["kl_threshold"] = 0
-cfg["rewards_shaper"] = lambda rewards, timestep, timesteps: rewards * 0.01
+cfg["rewards_shaper"] = lambda rewards, *args, **kwargs: rewards * 0.01
+cfg["time_limit_bootstrap"] = False
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
@@ -139,3 +140,17 @@ def __call__(self, inputs, role):
# start training
trainer.train()
+
+
+# # ---------------------------------------------------------
+# # comment the code above: `trainer.train()`, and...
+# # uncomment the following lines to evaluate a trained agent
+# # ---------------------------------------------------------
+# from skrl.utils.huggingface import download_model_from_huggingface
+
+# # download the trained agent's checkpoint from Hugging Face Hub and load it
+# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Humanoid-v0-PPO", filename="agent.pickle")
+# agent.load(path)
+
+# # start evaluation
+# trainer.eval()
diff --git a/docs/source/examples/isaacorbit/jax_lift_franka_ppo.py b/docs/source/examples/isaacorbit/jax_lift_franka_ppo.py
index ca4945df..b2d24b69 100644
--- a/docs/source/examples/isaacorbit/jax_lift_franka_ppo.py
+++ b/docs/source/examples/isaacorbit/jax_lift_franka_ppo.py
@@ -122,8 +122,8 @@ def __call__(self, inputs, role):
cfg["value_preprocessor"] = RunningStandardScaler
cfg["value_preprocessor_kwargs"] = {"size": 1, "device": device}
# logging to TensorBoard and write checkpoints (in timesteps)
-cfg["experiment"]["write_interval"] = 800
-cfg["experiment"]["checkpoint_interval"] = 8000
+cfg["experiment"]["write_interval"] = 336
+cfg["experiment"]["checkpoint_interval"] = 3360
cfg["experiment"]["directory"] = "runs/jax/Isaac-Lift-Franka-v0"
agent = PPO(models=models,
@@ -140,3 +140,17 @@ def __call__(self, inputs, role):
# start training
trainer.train()
+
+
+# # ---------------------------------------------------------
+# # comment the code above: `trainer.train()`, and...
+# # uncomment the following lines to evaluate a trained agent
+# # ---------------------------------------------------------
+# from skrl.utils.huggingface import download_model_from_huggingface
+
+# # download the trained agent's checkpoint from Hugging Face Hub and load it
+# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Lift-Franka-v0-PPO", filename="agent.pickle")
+# agent.load(path)
+
+# # start evaluation
+# trainer.eval()
diff --git a/docs/source/examples/isaacorbit/jax_reach_franka_ppo.py b/docs/source/examples/isaacorbit/jax_reach_franka_ppo.py
index 210c4ab9..2b01af30 100644
--- a/docs/source/examples/isaacorbit/jax_reach_franka_ppo.py
+++ b/docs/source/examples/isaacorbit/jax_reach_franka_ppo.py
@@ -52,8 +52,8 @@ def __call__(self, inputs, role):
x = nn.elu(nn.Dense(128)(x))
x = nn.elu(nn.Dense(64)(x))
x = nn.Dense(self.num_actions)(x)
- log_std = self.param("log_std", lambda _: jnp.zeros(self.num_actions))
- return x, log_std, {}
+ log_std = self.param("log_std", lambda _: 0.5 * jnp.ones(self.num_actions))
+ return nn.tanh(x), log_std, {}
class Value(DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs):
@@ -105,7 +105,7 @@ def __call__(self, inputs, role):
cfg["lambda"] = 0.95
cfg["learning_rate"] = 3e-4
cfg["learning_rate_scheduler"] = KLAdaptiveRL
-cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008}
+cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}
cfg["random_timesteps"] = 0
cfg["learning_starts"] = 0
cfg["grad_norm_clip"] = 1.0
@@ -116,6 +116,7 @@ def __call__(self, inputs, role):
cfg["value_loss_scale"] = 2.0
cfg["kl_threshold"] = 0
cfg["rewards_shaper"] = None
+cfg["time_limit_bootstrap"] = False
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
@@ -139,3 +140,17 @@ def __call__(self, inputs, role):
# start training
trainer.train()
+
+
+# # ---------------------------------------------------------
+# # comment the code above: `trainer.train()`, and...
+# # uncomment the following lines to evaluate a trained agent
+# # ---------------------------------------------------------
+# from skrl.utils.huggingface import download_model_from_huggingface
+
+# # download the trained agent's checkpoint from Hugging Face Hub and load it
+# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Reach-Franka-v0-PPO", filename="agent.pickle")
+# agent.load(path)
+
+# # start evaluation
+# trainer.eval()
diff --git a/docs/source/examples/isaacorbit/jax_velocity_anymal_c_ppo.py b/docs/source/examples/isaacorbit/jax_velocity_anymal_c_ppo.py
index 6711c59b..0238a6ec 100644
--- a/docs/source/examples/isaacorbit/jax_velocity_anymal_c_ppo.py
+++ b/docs/source/examples/isaacorbit/jax_velocity_anymal_c_ppo.py
@@ -112,10 +112,11 @@ def __call__(self, inputs, role):
cfg["ratio_clip"] = 0.2
cfg["value_clip"] = 0.2
cfg["clip_predicted_values"] = True
-cfg["entropy_loss_scale"] = 0.01
+cfg["entropy_loss_scale"] = 0.0
cfg["value_loss_scale"] = 1.0
cfg["kl_threshold"] = 0
cfg["rewards_shaper"] = None
+cfg["time_limit_bootstrap"] = False
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
diff --git a/docs/source/examples/isaacorbit/torch_ant_ppo.py b/docs/source/examples/isaacorbit/torch_ant_ppo.py
index 92e986a6..54f337a8 100644
--- a/docs/source/examples/isaacorbit/torch_ant_ppo.py
+++ b/docs/source/examples/isaacorbit/torch_ant_ppo.py
@@ -89,7 +89,8 @@ def compute(self, inputs, role):
cfg["entropy_loss_scale"] = 0.0
cfg["value_loss_scale"] = 1.0
cfg["kl_threshold"] = 0
-cfg["rewards_shaper"] = lambda rewards, timestep, timesteps: rewards * 0.01
+cfg["rewards_shaper"] = lambda rewards, *args, **kwargs: rewards * 0.1
+cfg["time_limit_bootstrap"] = True
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
diff --git a/docs/source/examples/isaacorbit/torch_cartpole_ppo.py b/docs/source/examples/isaacorbit/torch_cartpole_ppo.py
index 51da382e..cd0d3f78 100644
--- a/docs/source/examples/isaacorbit/torch_cartpole_ppo.py
+++ b/docs/source/examples/isaacorbit/torch_cartpole_ppo.py
@@ -31,7 +31,7 @@ def __init__(self, observation_space, action_space, device, clip_actions=False,
nn.ELU())
self.mean_layer = nn.Linear(32, self.num_actions)
- self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
+ self.log_std_parameter = nn.Parameter(torch.ones(self.num_actions))
self.value_layer = nn.Linear(32, 1)
@@ -88,6 +88,7 @@ def compute(self, inputs, role):
cfg["value_loss_scale"] = 2.0
cfg["kl_threshold"] = 0
cfg["rewards_shaper"] = None
+cfg["time_limit_bootstrap"] = True
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
diff --git a/docs/source/examples/isaacorbit/torch_humanoid_ppo.py b/docs/source/examples/isaacorbit/torch_humanoid_ppo.py
index 09f9b974..c3e8f876 100644
--- a/docs/source/examples/isaacorbit/torch_humanoid_ppo.py
+++ b/docs/source/examples/isaacorbit/torch_humanoid_ppo.py
@@ -20,7 +20,7 @@
# define shared model (stochastic and deterministic models) using mixins
class Shared(GaussianMixin, DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions=False,
- clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
+ clip_log_std=True, min_log_std=-5, max_log_std=2, reduction="sum"):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
DeterministicMixin.__init__(self, clip_actions)
@@ -45,7 +45,7 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ return torch.tanh(self.mean_layer(self.net(inputs["states"]))), self.log_std_parameter, {}
elif role == "value":
return self.value_layer(self.net(inputs["states"])), {}
@@ -79,7 +79,7 @@ def compute(self, inputs, role):
cfg["lambda"] = 0.95
cfg["learning_rate"] = 3e-4
cfg["learning_rate_scheduler"] = KLAdaptiveRL
-cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008}
+cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}
cfg["random_timesteps"] = 0
cfg["learning_starts"] = 0
cfg["grad_norm_clip"] = 1.0
@@ -87,9 +87,10 @@ def compute(self, inputs, role):
cfg["value_clip"] = 0.2
cfg["clip_predicted_values"] = True
cfg["entropy_loss_scale"] = 0.0
-cfg["value_loss_scale"] = 1.0
+cfg["value_loss_scale"] = 4.0
cfg["kl_threshold"] = 0
-cfg["rewards_shaper"] = lambda rewards, timestep, timesteps: rewards * 0.01
+cfg["rewards_shaper"] = lambda rewards, *args, **kwargs: rewards * 0.01
+cfg["time_limit_bootstrap"] = False
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
diff --git a/docs/source/examples/isaacorbit/torch_lift_franka_ppo.py b/docs/source/examples/isaacorbit/torch_lift_franka_ppo.py
index 5f3f3745..406184b9 100644
--- a/docs/source/examples/isaacorbit/torch_lift_franka_ppo.py
+++ b/docs/source/examples/isaacorbit/torch_lift_franka_ppo.py
@@ -96,8 +96,8 @@ def compute(self, inputs, role):
cfg["value_preprocessor"] = RunningStandardScaler
cfg["value_preprocessor_kwargs"] = {"size": 1, "device": device}
# logging to TensorBoard and write checkpoints (in timesteps)
-cfg["experiment"]["write_interval"] = 800
-cfg["experiment"]["checkpoint_interval"] = 8000
+cfg["experiment"]["write_interval"] = 336
+cfg["experiment"]["checkpoint_interval"] = 3360
cfg["experiment"]["directory"] = "runs/torch/Isaac-Lift-Franka-v0"
agent = PPO(models=models,
@@ -114,3 +114,17 @@ def compute(self, inputs, role):
# start training
trainer.train()
+
+
+# # ---------------------------------------------------------
+# # comment the code above: `trainer.train()`, and...
+# # uncomment the following lines to evaluate a trained agent
+# # ---------------------------------------------------------
+# from skrl.utils.huggingface import download_model_from_huggingface
+
+# # download the trained agent's checkpoint from Hugging Face Hub and load it
+# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Lift-Franka-v0-PPO", filename="agent.pt")
+# agent.load(path)
+
+# # start evaluation
+# trainer.eval()
diff --git a/docs/source/examples/isaacorbit/torch_reach_franka_ppo.py b/docs/source/examples/isaacorbit/torch_reach_franka_ppo.py
index 7bb8be7e..bc747c5a 100644
--- a/docs/source/examples/isaacorbit/torch_reach_franka_ppo.py
+++ b/docs/source/examples/isaacorbit/torch_reach_franka_ppo.py
@@ -33,7 +33,7 @@ def __init__(self, observation_space, action_space, device, clip_actions=False,
nn.ELU())
self.mean_layer = nn.Linear(64, self.num_actions)
- self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
+ self.log_std_parameter = nn.Parameter(0.5 * torch.ones(self.num_actions))
self.value_layer = nn.Linear(64, 1)
@@ -45,7 +45,7 @@ def act(self, inputs, role):
def compute(self, inputs, role):
if role == "policy":
- return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+ return torch.tanh(self.mean_layer(self.net(inputs["states"]))), self.log_std_parameter, {}
elif role == "value":
return self.value_layer(self.net(inputs["states"])), {}
@@ -79,7 +79,7 @@ def compute(self, inputs, role):
cfg["lambda"] = 0.95
cfg["learning_rate"] = 3e-4
cfg["learning_rate_scheduler"] = KLAdaptiveRL
-cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008}
+cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}
cfg["random_timesteps"] = 0
cfg["learning_starts"] = 0
cfg["grad_norm_clip"] = 1.0
@@ -90,6 +90,7 @@ def compute(self, inputs, role):
cfg["value_loss_scale"] = 2.0
cfg["kl_threshold"] = 0
cfg["rewards_shaper"] = None
+cfg["time_limit_bootstrap"] = False
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
diff --git a/docs/source/examples/isaacorbit/torch_velocity_anymal_c_ppo.py b/docs/source/examples/isaacorbit/torch_velocity_anymal_c_ppo.py
index 4eaa7d9b..d5224a05 100644
--- a/docs/source/examples/isaacorbit/torch_velocity_anymal_c_ppo.py
+++ b/docs/source/examples/isaacorbit/torch_velocity_anymal_c_ppo.py
@@ -86,10 +86,11 @@ def compute(self, inputs, role):
cfg["ratio_clip"] = 0.2
cfg["value_clip"] = 0.2
cfg["clip_predicted_values"] = True
-cfg["entropy_loss_scale"] = 0.01
+cfg["entropy_loss_scale"] = 0.0
cfg["value_loss_scale"] = 1.0
cfg["kl_threshold"] = 0
cfg["rewards_shaper"] = None
+cfg["time_limit_bootstrap"] = False
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
diff --git a/docs/source/index.rst b/docs/source/index.rst
index f1cd8aec..5ab4e515 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -50,15 +50,19 @@ SKRL - Reinforcement Learning library (|version|)
| **Questions or discussions:** https://github.com/Toni-SM/skrl/discussions
|
-**Citing skrl:** To cite this library (created at `Mondragon Unibertsitatea `_) use the following reference to its `article `_: *"skrl: Modular and Flexible Library for Reinforcement Learning"*
+**Citing skrl:** To cite this library (created at Mondragon Unibertsitatea) use the following reference to its article: `skrl: Modular and Flexible Library for Reinforcement Learning `_.
.. code-block:: bibtex
- @article{serrano2022skrl,
- title={skrl: Modular and Flexible Library for Reinforcement Learning},
- author={Serrano-Mu{\~n}oz, Antonio and Arana-Arexolaleiba, Nestor and Chrysostomou, Dimitrios and B{\o}gh, Simon},
- journal={arXiv preprint arXiv:2202.03825},
- year={2022}
+ @article{serrano2023skrl,
+ author = {Antonio Serrano-Muñoz and Dimitrios Chrysostomou and Simon Bøgh and Nestor Arana-Arexolaleiba},
+ title = {skrl: Modular and Flexible Library for Reinforcement Learning},
+ journal = {Journal of Machine Learning Research},
+ year = {2023},
+ volume = {24},
+ number = {254},
+ pages = {1--9},
+ url = {http://jmlr.org/papers/v24/23-0112.html}
}
.. raw:: html
diff --git a/docs/source/intro/examples.rst b/docs/source/intro/examples.rst
index 6dac2e8e..2f16383e 100644
--- a/docs/source/intro/examples.rst
+++ b/docs/source/intro/examples.rst
@@ -98,12 +98,33 @@ Training/evaluation of an agent in `Gymnasium `_
- :download:`jax_gymnasium_cartpole_cem.py <../examples/gymnasium/jax_gymnasium_cartpole_cem.py>`
|br| :download:`jax_gymnasium_cartpole_dqn.py <../examples/gymnasium/jax_gymnasium_cartpole_dqn.py>`
-
+ * - FrozenLake
+ -
+ -
* - Pendulum
- :download:`jax_gymnasium_pendulum_ddpg.py <../examples/gymnasium/jax_gymnasium_pendulum_ddpg.py>`
|br| :download:`jax_gymnasium_pendulum_ppo.py <../examples/gymnasium/jax_gymnasium_pendulum_ppo.py>`
|br| :download:`jax_gymnasium_pendulum_sac.py <../examples/gymnasium/jax_gymnasium_pendulum_sac.py>`
|br| :download:`jax_gymnasium_pendulum_td3.py <../examples/gymnasium/jax_gymnasium_pendulum_td3.py>`
-
+ * - PendulumNoVel*
+ |br| (RNN / GRU / LSTM)
+ - |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ -
+ * - Taxi
+ -
+ -
.. note::
@@ -171,12 +192,33 @@ Training/evaluation of an agent in `Gymnasium `_
- :download:`jax_gym_cartpole_cem.py <../examples/gym/jax_gym_cartpole_cem.py>`
|br| :download:`jax_gym_cartpole_dqn.py <../examples/gym/jax_gym_cartpole_dqn.py>`
-
+ * - FrozenLake
+ -
+ -
* - Pendulum
- :download:`jax_gym_pendulum_ddpg.py <../examples/gym/jax_gym_pendulum_ddpg.py>`
|br| :download:`jax_gym_pendulum_ppo.py <../examples/gym/jax_gym_pendulum_ppo.py>`
|br| :download:`jax_gym_pendulum_sac.py <../examples/gym/jax_gym_pendulum_sac.py>`
|br| :download:`jax_gym_pendulum_td3.py <../examples/gym/jax_gym_pendulum_td3.py>`
-
+ * - PendulumNoVel*
+ |br| (RNN / GRU / LSTM)
+ - |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ |br|
+ -
+ * - Taxi
+ -
+ -
.. note::
@@ -235,9 +277,15 @@ Training/evaluation of an agent in `Gymnasium `_
* - CartPole
- :download:`jax_gymnasium_cartpole_vector_dqn.py <../examples/gymnasium/jax_gymnasium_cartpole_vector_dqn.py>`
-
+ * - FrozenLake
+ -
+ -
* - Pendulum
- :download:`jax_gymnasium_pendulum_vector_ddpg.py <../examples/gymnasium/jax_gymnasium_pendulum_vector_ddpg.py>`
-
+ * - Taxi
+ -
+ -
.. group-tab:: Gym
@@ -281,9 +329,15 @@ Training/evaluation of an agent in `Gymnasium `_
* - CartPole
- :download:`jax_gym_cartpole_vector_dqn.py <../examples/gym/jax_gym_cartpole_vector_dqn.py>`
-
+ * - FrozenLake
+ -
+ -
* - Pendulum
- :download:`jax_gym_pendulum_vector_ddpg.py <../examples/gym/jax_gym_pendulum_vector_ddpg.py>`
-
+ * - Taxi
+ -
+ -
.. raw:: html
@@ -835,7 +889,7 @@ The agent configuration is mapped, as far as possible, from the `Isaac Orbit con
- `IsaacOrbit-Isaac-Humanoid-v0-PPO `_
* - Isaac-Lift-Franka-v0
- :download:`torch_lift_franka_ppo.py <../examples/isaacorbit/torch_lift_franka_ppo.py>`
- -
+ - `IsaacOrbit-Isaac-Lift-Franka-v0-PPO `_
* - Isaac-Reach-Franka-v0
- :download:`torch_reach_franka_ppo.py <../examples/isaacorbit/torch_reach_franka_ppo.py>`
- `IsaacOrbit-Isaac-Reach-Franka-v0-PPO `_
@@ -859,22 +913,22 @@ The agent configuration is mapped, as far as possible, from the `Isaac Orbit con
|br| :download:`jax_ant_ddpg.py <../examples/isaacorbit/jax_ant_ddpg.py>`
|br| :download:`jax_ant_td3.py <../examples/isaacorbit/jax_ant_td3.py>`
|br| :download:`jax_ant_sac.py <../examples/isaacorbit/jax_ant_sac.py>`
- - |br|
+ - `IsaacOrbit-Isaac-Ant-v0-PPO `_
|br|
|br|
|br|
* - Isaac-Cartpole-v0
- :download:`jax_cartpole_ppo.py <../examples/isaacorbit/jax_cartpole_ppo.py>`
- -
+ - `IsaacOrbit-Isaac-Cartpole-v0-PPO `_
* - Isaac-Humanoid-v0
- :download:`jax_humanoid_ppo.py <../examples/isaacorbit/jax_humanoid_ppo.py>`
- -
+ - `IsaacOrbit-Isaac-Humanoid-v0-PPO `_
* - Isaac-Lift-Franka-v0
- :download:`jax_lift_franka_ppo.py <../examples/isaacorbit/jax_lift_franka_ppo.py>`
- -
+ - `IsaacOrbit-Isaac-Lift-Franka-v0-PPO `_
* - Isaac-Reach-Franka-v0
- :download:`jax_reach_franka_ppo.py <../examples/isaacorbit/jax_reach_franka_ppo.py>`
- -
+ - `IsaacOrbit-Isaac-Reach-Franka-v0-PPO `_
* - Isaac-Velocity-Anymal-C-v0
- :download:`jax_velocity_anymal_c_ppo.py <../examples/isaacorbit/jax_velocity_anymal_c_ppo.py>`
-
diff --git a/docs/source/snippets/wrapping.py b/docs/source/snippets/wrapping.py
index c816984e..4213ad57 100644
--- a/docs/source/snippets/wrapping.py
+++ b/docs/source/snippets/wrapping.py
@@ -297,6 +297,32 @@
# [jax-end-gymnasium-vectorized]
+# [pytorch-start-shimmy]
+# import the environment wrapper and gymnasium
+from skrl.envs.wrappers.torch import wrap_env
+import gymnasium as gym
+
+# load the environment (API conversion)
+env = gym.make("ALE/Pong-v5")
+
+# wrap the environment
+env = wrap_env(env) # or 'env = wrap_env(env, wrapper="gymnasium")'
+# [pytorch-end-shimmy]
+
+
+# [jax-start-shimmy]
+# import the environment wrapper and gymnasium
+from skrl.envs.wrappers.jax import wrap_env
+import gymnasium as gym
+
+# load the environment (API conversion)
+env = gym.make("ALE/Pong-v5")
+
+# wrap the environment
+env = wrap_env(env) # or 'env = wrap_env(env, wrapper="gymnasium")'
+# [jax-end-shimmy]
+
+
# [pytorch-start-deepmind]
# import the environment wrapper and the deepmind suite
from skrl.envs.wrappers.torch import wrap_env
diff --git a/pyproject.toml b/pyproject.toml
index aaeb7d3c..2afb84a0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "skrl"
-version = "1.0.0-rc.2"
+version = "1.0.0"
description = "Modular and flexible library for reinforcement learning on PyTorch and JAX"
readme = "README.md"
requires-python = ">=3.6"
diff --git a/skrl/models/jax/base.py b/skrl/models/jax/base.py
index fa595726..30db0e83 100644
--- a/skrl/models/jax/base.py
+++ b/skrl/models/jax/base.py
@@ -129,7 +129,8 @@ def _get_space_size(self,
:param space: Space or shape from which to obtain the number of elements
:type space: int, sequence of int, gym.Space, or gymnasium.Space
:param number_of_elements: Whether the number of elements occupied by the space is returned (default: ``True``).
- If ``False``, the shape of the space is returned. It only affects Discrete spaces
+ If ``False``, the shape of the space is returned.
+ It only affects Discrete and MultiDiscrete spaces
:type number_of_elements: bool, optional
:raises ValueError: If the space is not supported
@@ -159,6 +160,13 @@ def _get_space_size(self,
>>> model._get_space_size(space, number_of_elements=False)
1
+ # MultiDiscrete space
+ >>> space = gym.spaces.MultiDiscrete([5, 3, 2])
+ >>> model._get_space_size(space)
+ 10
+ >>> model._get_space_size(space, number_of_elements=False)
+ 3
+
# Dict space
>>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
... 'b': gym.spaces.Discrete(4)})
@@ -178,6 +186,11 @@ def _get_space_size(self,
size = space.n
else:
size = 1
+ elif issubclass(type(space), gym.spaces.MultiDiscrete):
+ if number_of_elements:
+ size = np.sum(space.nvec)
+ else:
+ size = space.nvec.shape[0]
elif issubclass(type(space), gym.spaces.Box):
size = np.prod(space.shape)
elif issubclass(type(space), gym.spaces.Dict):
@@ -188,6 +201,11 @@ def _get_space_size(self,
size = space.n
else:
size = 1
+ elif issubclass(type(space), gymnasium.spaces.MultiDiscrete):
+ if number_of_elements:
+ size = np.sum(space.nvec)
+ else:
+ size = space.nvec.shape[0]
elif issubclass(type(space), gymnasium.spaces.Box):
size = np.prod(space.shape)
elif issubclass(type(space), gymnasium.spaces.Dict):
diff --git a/skrl/models/torch/base.py b/skrl/models/torch/base.py
index 5bddd88d..757a8ba2 100644
--- a/skrl/models/torch/base.py
+++ b/skrl/models/torch/base.py
@@ -71,7 +71,8 @@ def _get_space_size(self,
:param space: Space or shape from which to obtain the number of elements
:type space: int, sequence of int, gym.Space, or gymnasium.Space
:param number_of_elements: Whether the number of elements occupied by the space is returned (default: ``True``).
- If ``False``, the shape of the space is returned. It only affects Discrete spaces
+ If ``False``, the shape of the space is returned.
+ It only affects Discrete and MultiDiscrete spaces
:type number_of_elements: bool, optional
:raises ValueError: If the space is not supported
@@ -101,6 +102,13 @@ def _get_space_size(self,
>>> model._get_space_size(space, number_of_elements=False)
1
+ # MultiDiscrete space
+ >>> space = gym.spaces.MultiDiscrete([5, 3, 2])
+ >>> model._get_space_size(space)
+ 10
+ >>> model._get_space_size(space, number_of_elements=False)
+ 3
+
# Dict space
>>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
... 'b': gym.spaces.Discrete(4)})
@@ -120,6 +128,11 @@ def _get_space_size(self,
size = space.n
else:
size = 1
+ elif issubclass(type(space), gym.spaces.MultiDiscrete):
+ if number_of_elements:
+ size = np.sum(space.nvec)
+ else:
+ size = space.nvec.shape[0]
elif issubclass(type(space), gym.spaces.Box):
size = np.prod(space.shape)
elif issubclass(type(space), gym.spaces.Dict):
@@ -130,6 +143,11 @@ def _get_space_size(self,
size = space.n
else:
size = 1
+ elif issubclass(type(space), gymnasium.spaces.MultiDiscrete):
+ if number_of_elements:
+ size = np.sum(space.nvec)
+ else:
+ size = space.nvec.shape[0]
elif issubclass(type(space), gymnasium.spaces.Box):
size = np.prod(space.shape)
elif issubclass(type(space), gymnasium.spaces.Dict):