diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1d77eb10..745e2b14 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,14 +2,34 @@
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
+## [1.0.0] - 2023-08-16
+
+Transition from pre-release versions (`1.0.0-rc.1` and `1.0.0-rc.2`) to a stable version.
+
+This release also announces the publication of the **skrl** paper in the Journal of Machine Learning Research (JMLR): https://www.jmlr.org/papers/v24/23-0112.html
+
+Summary of the most relevant features:
+- JAX support
+- New documentation theme and structure
+- Multi-agent Reinforcement Learning (MARL)
+
 ## [1.0.0-rc.2] - 2023-08-11
 
 ### Added
 - Get truncation from `time_outs` info in Isaac Gym, Isaac Orbit and Omniverse Isaac Gym environments
 - Time-limit (truncation) bootstrapping in on-policy actor-critic agents
 - Model instantiators `initial_log_std` parameter to set the log standard deviation's initial value
 
+### Changed (breaking changes)
+- Structure environment loaders and wrappers file hierarchy coherently.
+  Import statements now follow this convention:
+  - Wrappers (e.g.):
+    - `from skrl.envs.wrappers.torch import wrap_env`
+    - `from skrl.envs.wrappers.jax import wrap_env`
+  - Loaders (e.g.):
+    - `from skrl.envs.loaders.torch import load_omniverse_isaacgym_env`
+    - `from skrl.envs.loaders.jax import load_omniverse_isaacgym_env`
+
 ### Changed
-- Structure environment loaders and wrappers file hierarchy coherently [**breaking change**]
 - Drop support for versions prior to PyTorch 1.9 (1.8.0 and 1.8.1)
 
 ## [1.0.0-rc.1] - 2023-07-25
@@ -66,7 +86,7 @@ and Omniverse Isaac Gym environments when they are loaded
 ### Added
 - Support for Farama Gymnasium interface
 - Wrapper for robosuite environments
-- Weights & Biases integration (by @juhannc)
+- Weights & Biases integration
 - Set the running mode (training or evaluation) of the agents
 - Allow clipping the gradient norm for DDPG, TD3 and SAC agents
 - Initialize model biases
@@ -75,9 +95,11 @@ and Omniverse Isaac Gym environments when they are loaded
 - Farama Shimmy and robosuite examples
 - KUKA LBR iiwa real-world example
 
+### Changed (breaking changes)
+- Forward model inputs as a Python dictionary
+- Returns a Python dictionary with extra output values in model calls
+
 ### Changed
-- Forward model inputs as a Python dictionary [**breaking change**]
-- Returns a Python dictionary with extra output values in model calls [**breaking change**]
 - Adopt the implementation of `terminated` and `truncated` over `done` for all environments
 
 ### Fixed
@@ -98,7 +120,7 @@ to allow storing samples in memories during evaluation
 - Gaussian model mixin
 - Support for creating shared models
 - Parameter `role` to model methods
-- Wrapper compatibility with the new OpenAI Gym environment API (by @juhannc)
+- Wrapper compatibility with the new OpenAI Gym environment API
 - Internal library colored logger
 - Migrate checkpoints/models from other RL libraries to skrl models/agents
 - Configuration parameter `store_separately` to agent configuration dict
@@ -107,11 +129,13 @@ to allow storing samples in memories during evaluation
 - Benchmark results for Isaac Gym and Omniverse Isaac Gym on the GitHub discussion page
 - Franka Emika real-world example
 
+### Changed (breaking changes)
+- Models implementation as Python mixin
+
 ### Changed
-- Models implementation as Python mixin [**breaking change**]
 - Multivariate Gaussian model (`GaussianModel` until 0.7.0) to `MultivariateGaussianMixin`
 - Trainer's `cfg` parameter position and default values
-- Show training/evaluation display progress using `tqdm` (by @juhannc)
+- Show training/evaluation display progress using `tqdm`
 - Update Isaac Gym and Omniverse Isaac Gym examples
 
 ### Fixed
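For quick reference, a minimal sketch of the 1.0.0 import convention described by the breaking change above (assuming skrl and gymnasium are installed; the loader import is shown commented because it additionally requires an Omniverse Isaac Gym setup):

```python
import gymnasium as gym

# wrappers now live under skrl.envs.wrappers.<backend>
from skrl.envs.wrappers.torch import wrap_env
# loaders now live under skrl.envs.loaders.<backend>, e.g.:
# from skrl.envs.loaders.torch import load_omniverse_isaacgym_env

# wrap a single-agent environment with the common skrl API
env = wrap_env(gym.make("Pendulum-v1"))
```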
diff --git a/README.md b/README.md
index 49e63854..65b7b391 100644
--- a/README.md
+++ b/README.md
@@ -10,7 +10,9 @@
 [hunk body lost in extraction: the centered skrl logo image and the "SKRL - Reinforcement Learning library" heading (HTML block); two lines added]
@@ -21,7 +23,7 @@
 ### Please, visit the documentation for usage details and examples
 
-https://skrl.readthedocs.io/en/latest/
+https://skrl.readthedocs.io
@@ -34,10 +36,14 @@ https://skrl.readthedocs.io/en/latest/
 To cite this library in publications, please use the following reference:
 
 ```bibtex
-@article{serrano2022skrl,
-  title={skrl: Modular and Flexible Library for Reinforcement Learning},
-  author={Serrano-Mu{\~n}oz, Antonio and Arana-Arexolaleiba, Nestor and Chrysostomou, Dimitrios and B{\o}gh, Simon},
-  journal={arXiv preprint arXiv:2202.03825},
-  year={2022}
+@article{serrano2023skrl,
+  author  = {Antonio Serrano-Muñoz and Dimitrios Chrysostomou and Simon Bøgh and Nestor Arana-Arexolaleiba},
+  title   = {skrl: Modular and Flexible Library for Reinforcement Learning},
+  journal = {Journal of Machine Learning Research},
+  year    = {2023},
+  volume  = {24},
+  number  = {254},
+  pages   = {1--9},
+  url     = {http://jmlr.org/papers/v24/23-0112.html}
 }
 ```
diff --git a/docs/source/api/envs.rst b/docs/source/api/envs.rst
index 55e23c37..3b3091da 100644
--- a/docs/source/api/envs.rst
+++ b/docs/source/api/envs.rst
@@ -69,3 +69,6 @@ In addition, you will be able to :doc:`wrap single-agent ` and :d
    * - robosuite
      - .. centered:: :math:`\blacksquare`
      - .. centered:: :math:`\square`
+   * - Shimmy
+     - .. centered:: :math:`\blacksquare`
+     - .. centered:: :math:`\blacksquare`
diff --git a/docs/source/api/envs/wrapping.rst b/docs/source/api/envs/wrapping.rst
index a4454cd3..a725a7e9 100644
--- a/docs/source/api/envs/wrapping.rst
+++ b/docs/source/api/envs/wrapping.rst
@@ -10,6 +10,7 @@ Wrapping (single-agent)
 This library works with a common API to interact with the following RL environments:
 
 * OpenAI `Gym `_ / Farama `Gymnasium `_ (single and vectorized environments)
+* `Farama Shimmy `_
 * `DeepMind `_
 * `robosuite `_
 * `NVIDIA Isaac Gym `_ (preview 2, 3 and 4)
@@ -265,6 +266,24 @@ Usage
             :start-after: [jax-start-gymnasium-vectorized]
             :end-before: [jax-end-gymnasium-vectorized]
 
+    .. tab:: Shimmy
+
+        .. tabs::
+
+            .. group-tab:: |_4| |pytorch| |_4|
+
+                .. literalinclude:: ../../snippets/wrapping.py
+                    :language: python
+                    :start-after: [pytorch-start-shimmy]
+                    :end-before: [pytorch-end-shimmy]
+
+            .. group-tab:: |_4| |jax| |_4|
+
+                .. literalinclude:: ../../snippets/wrapping.py
+                    :language: python
+                    :start-after: [jax-start-shimmy]
+                    :end-before: [jax-end-shimmy]
+
     .. tab:: DeepMind
 
         .. tabs::
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 8882fe6d..42f37f8a 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -16,7 +16,7 @@
 if skrl.__version__ != "unknown":
     release = version = skrl.__version__
 else:
-    release = version = "1.0.0-rc.2"
+    release = version = "1.0.0"
 
 master_doc = "index"
diff --git a/docs/source/examples/isaacorbit/jax_ant_ppo.py b/docs/source/examples/isaacorbit/jax_ant_ppo.py
index 5ad99973..ba0cc9f3 100644
--- a/docs/source/examples/isaacorbit/jax_ant_ppo.py
+++ b/docs/source/examples/isaacorbit/jax_ant_ppo.py
@@ -115,7 +115,8 @@ def __call__(self, inputs, role):
 cfg["entropy_loss_scale"] = 0.0
 cfg["value_loss_scale"] = 1.0
 cfg["kl_threshold"] = 0
-cfg["rewards_shaper"] = lambda rewards, timestep, timesteps: rewards * 0.01
+cfg["rewards_shaper"] = lambda rewards, *args, **kwargs: rewards * 0.1
+cfg["time_limit_bootstrap"] = True
 cfg["state_preprocessor"] = RunningStandardScaler
 cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
 cfg["value_preprocessor"] = RunningStandardScaler
@@ -139,3 +140,17 @@ def __call__(self, inputs, role):
 
 # start training
 trainer.train()
+
+
+# # ---------------------------------------------------------
+# # comment the code above: `trainer.train()`, and...
+# # uncomment the following lines to evaluate a trained agent
+# # ---------------------------------------------------------
+# from skrl.utils.huggingface import download_model_from_huggingface
+
+# # download the trained agent's checkpoint from Hugging Face Hub and load it
+# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Ant-v0-PPO", filename="agent.pickle")
+# agent.load(path)
+
+# # start evaluation
+# trainer.eval()
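Two config changes above recur throughout these examples: the rewards shaper now accepts any extra arguments, and `time_limit_bootstrap` enables the time-limit (truncation) bootstrapping added in 1.0.0-rc.2. A conceptual sketch of both (illustrative only; the agent's internal implementation may differ):

```python
import torch

# a shaper with this signature works no matter what extra arguments
# (timestep, timesteps, ...) the agent passes when calling it
rewards_shaper = lambda rewards, *args, **kwargs: rewards * 0.1

# time-limit bootstrapping: a truncated (timed-out) episode did not really
# terminate, so the value of the next state is bootstrapped into the reward
def bootstrap_on_truncation(rewards, next_values, truncated, discount=0.99):
    return rewards + discount * next_values * truncated.float()
```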
diff --git a/docs/source/examples/isaacorbit/jax_cartpole_ppo.py b/docs/source/examples/isaacorbit/jax_cartpole_ppo.py
index fedfb602..d6d5ef1b 100644
--- a/docs/source/examples/isaacorbit/jax_cartpole_ppo.py
+++ b/docs/source/examples/isaacorbit/jax_cartpole_ppo.py
@@ -51,7 +51,7 @@ def __call__(self, inputs, role):
         x = nn.elu(nn.Dense(32)(inputs["states"]))
         x = nn.elu(nn.Dense(32)(x))
         x = nn.Dense(self.num_actions)(x)
-        log_std = self.param("log_std", lambda _: jnp.zeros(self.num_actions))
+        log_std = self.param("log_std", lambda _: jnp.ones(self.num_actions))
         return x, log_std, {}
 
 class Value(DeterministicMixin, Model):
@@ -114,6 +114,7 @@ def __call__(self, inputs, role):
 cfg["value_loss_scale"] = 2.0
 cfg["kl_threshold"] = 0
 cfg["rewards_shaper"] = None
+cfg["time_limit_bootstrap"] = True
 cfg["state_preprocessor"] = RunningStandardScaler
 cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
 cfg["value_preprocessor"] = RunningStandardScaler
@@ -137,3 +138,17 @@ def __call__(self, inputs, role):
 
 # start training
 trainer.train()
+
+
+# # ---------------------------------------------------------
+# # comment the code above: `trainer.train()`, and...
+# # uncomment the following lines to evaluate a trained agent
+# # ---------------------------------------------------------
+# from skrl.utils.huggingface import download_model_from_huggingface
+
+# # download the trained agent's checkpoint from Hugging Face Hub and load it
+# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Cartpole-v0-PPO", filename="agent.pickle")
+# agent.load(path)
+
+# # start evaluation
+# trainer.eval()
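The cartpole policies above now initialize the log standard deviation to 1 instead of 0 (cf. the `initial_log_std` parameter mentioned in the changelog); a quick sketch of what that means for initial exploration:

```python
import math

print(math.exp(0.0))  # old init: initial std = 1.00
print(math.exp(1.0))  # new init: initial std ≈ 2.72 (wider initial action noise)
```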
cfg["random_timesteps"] = 0 cfg["learning_starts"] = 0 cfg["grad_norm_clip"] = 1.0 @@ -113,9 +113,10 @@ def __call__(self, inputs, role): cfg["value_clip"] = 0.2 cfg["clip_predicted_values"] = True cfg["entropy_loss_scale"] = 0.0 -cfg["value_loss_scale"] = 1.0 +cfg["value_loss_scale"] = 4.0 cfg["kl_threshold"] = 0 -cfg["rewards_shaper"] = lambda rewards, timestep, timesteps: rewards * 0.01 +cfg["rewards_shaper"] = lambda rewards, *args, **kwargs: rewards * 0.01 +cfg["time_limit_bootstrap"] = False cfg["state_preprocessor"] = RunningStandardScaler cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device} cfg["value_preprocessor"] = RunningStandardScaler @@ -139,3 +140,17 @@ def __call__(self, inputs, role): # start training trainer.train() + + +# # --------------------------------------------------------- +# # comment the code above: `trainer.train()`, and... +# # uncomment the following lines to evaluate a trained agent +# # --------------------------------------------------------- +# from skrl.utils.huggingface import download_model_from_huggingface + +# # download the trained agent's checkpoint from Hugging Face Hub and load it +# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Humanoid-v0-PPO", filename="agent.pickle") +# agent.load(path) + +# # start evaluation +# trainer.eval() diff --git a/docs/source/examples/isaacorbit/jax_lift_franka_ppo.py b/docs/source/examples/isaacorbit/jax_lift_franka_ppo.py index ca4945df..b2d24b69 100644 --- a/docs/source/examples/isaacorbit/jax_lift_franka_ppo.py +++ b/docs/source/examples/isaacorbit/jax_lift_franka_ppo.py @@ -122,8 +122,8 @@ def __call__(self, inputs, role): cfg["value_preprocessor"] = RunningStandardScaler cfg["value_preprocessor_kwargs"] = {"size": 1, "device": device} # logging to TensorBoard and write checkpoints (in timesteps) -cfg["experiment"]["write_interval"] = 800 -cfg["experiment"]["checkpoint_interval"] = 8000 +cfg["experiment"]["write_interval"] = 336 +cfg["experiment"]["checkpoint_interval"] = 3360 cfg["experiment"]["directory"] = "runs/jax/Isaac-Lift-Franka-v0" agent = PPO(models=models, @@ -140,3 +140,17 @@ def __call__(self, inputs, role): # start training trainer.train() + + +# # --------------------------------------------------------- +# # comment the code above: `trainer.train()`, and... 
diff --git a/docs/source/examples/isaacorbit/jax_lift_franka_ppo.py b/docs/source/examples/isaacorbit/jax_lift_franka_ppo.py
index ca4945df..b2d24b69 100644
--- a/docs/source/examples/isaacorbit/jax_lift_franka_ppo.py
+++ b/docs/source/examples/isaacorbit/jax_lift_franka_ppo.py
@@ -122,8 +122,8 @@ def __call__(self, inputs, role):
 cfg["value_preprocessor"] = RunningStandardScaler
 cfg["value_preprocessor_kwargs"] = {"size": 1, "device": device}
 # logging to TensorBoard and write checkpoints (in timesteps)
-cfg["experiment"]["write_interval"] = 800
-cfg["experiment"]["checkpoint_interval"] = 8000
+cfg["experiment"]["write_interval"] = 336
+cfg["experiment"]["checkpoint_interval"] = 3360
 cfg["experiment"]["directory"] = "runs/jax/Isaac-Lift-Franka-v0"
 
 agent = PPO(models=models,
@@ -140,3 +140,17 @@ def __call__(self, inputs, role):
 
 # start training
 trainer.train()
+
+
+# # ---------------------------------------------------------
+# # comment the code above: `trainer.train()`, and...
+# # uncomment the following lines to evaluate a trained agent
+# # ---------------------------------------------------------
+# from skrl.utils.huggingface import download_model_from_huggingface
+
+# # download the trained agent's checkpoint from Hugging Face Hub and load it
+# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Lift-Franka-v0-PPO", filename="agent.pickle")
+# agent.load(path)
+
+# # start evaluation
+# trainer.eval()
diff --git a/docs/source/examples/isaacorbit/jax_reach_franka_ppo.py b/docs/source/examples/isaacorbit/jax_reach_franka_ppo.py
index 210c4ab9..2b01af30 100644
--- a/docs/source/examples/isaacorbit/jax_reach_franka_ppo.py
+++ b/docs/source/examples/isaacorbit/jax_reach_franka_ppo.py
@@ -52,8 +52,8 @@ def __call__(self, inputs, role):
         x = nn.elu(nn.Dense(128)(x))
         x = nn.elu(nn.Dense(64)(x))
         x = nn.Dense(self.num_actions)(x)
-        log_std = self.param("log_std", lambda _: jnp.zeros(self.num_actions))
-        return x, log_std, {}
+        log_std = self.param("log_std", lambda _: 0.5 * jnp.ones(self.num_actions))
+        return nn.tanh(x), log_std, {}
 
 class Value(DeterministicMixin, Model):
     def __init__(self, observation_space, action_space, device=None, clip_actions=False, **kwargs):
@@ -105,7 +105,7 @@ def __call__(self, inputs, role):
 cfg["lambda"] = 0.95
 cfg["learning_rate"] = 3e-4
 cfg["learning_rate_scheduler"] = KLAdaptiveRL
-cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008}
+cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}
 cfg["random_timesteps"] = 0
 cfg["learning_starts"] = 0
 cfg["grad_norm_clip"] = 1.0
@@ -116,6 +116,7 @@ def __call__(self, inputs, role):
 cfg["value_loss_scale"] = 2.0
 cfg["kl_threshold"] = 0
 cfg["rewards_shaper"] = None
+cfg["time_limit_bootstrap"] = False
 cfg["state_preprocessor"] = RunningStandardScaler
 cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
 cfg["value_preprocessor"] = RunningStandardScaler
@@ -139,3 +140,17 @@ def __call__(self, inputs, role):
 
 # start training
 trainer.train()
+
+
+# # ---------------------------------------------------------
+# # comment the code above: `trainer.train()`, and...
+# # uncomment the following lines to evaluate a trained agent
+# # ---------------------------------------------------------
+# from skrl.utils.huggingface import download_model_from_huggingface
+
+# # download the trained agent's checkpoint from Hugging Face Hub and load it
+# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Reach-Franka-v0-PPO", filename="agent.pickle")
+# agent.load(path)
+
+# # start evaluation
+# trainer.eval()
diff --git a/docs/source/examples/isaacorbit/jax_velocity_anymal_c_ppo.py b/docs/source/examples/isaacorbit/jax_velocity_anymal_c_ppo.py
index 6711c59b..0238a6ec 100644
--- a/docs/source/examples/isaacorbit/jax_velocity_anymal_c_ppo.py
+++ b/docs/source/examples/isaacorbit/jax_velocity_anymal_c_ppo.py
@@ -112,10 +112,11 @@ def __call__(self, inputs, role):
 cfg["ratio_clip"] = 0.2
 cfg["value_clip"] = 0.2
 cfg["clip_predicted_values"] = True
-cfg["entropy_loss_scale"] = 0.01
+cfg["entropy_loss_scale"] = 0.0
 cfg["value_loss_scale"] = 1.0
 cfg["kl_threshold"] = 0
 cfg["rewards_shaper"] = None
+cfg["time_limit_bootstrap"] = False
 cfg["state_preprocessor"] = RunningStandardScaler
 cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
 cfg["value_preprocessor"] = RunningStandardScaler
diff --git a/docs/source/examples/isaacorbit/torch_ant_ppo.py b/docs/source/examples/isaacorbit/torch_ant_ppo.py
index 92e986a6..54f337a8 100644
--- a/docs/source/examples/isaacorbit/torch_ant_ppo.py
+++ b/docs/source/examples/isaacorbit/torch_ant_ppo.py
@@ -89,7 +89,8 @@ def compute(self, inputs, role):
 cfg["entropy_loss_scale"] = 0.0
 cfg["value_loss_scale"] = 1.0
 cfg["kl_threshold"] = 0
-cfg["rewards_shaper"] = lambda rewards, timestep, timesteps: rewards * 0.01
+cfg["rewards_shaper"] = lambda rewards, *args, **kwargs: rewards * 0.1
+cfg["time_limit_bootstrap"] = True
 cfg["state_preprocessor"] = RunningStandardScaler
 cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
 cfg["value_preprocessor"] = RunningStandardScaler
diff --git a/docs/source/examples/isaacorbit/torch_cartpole_ppo.py b/docs/source/examples/isaacorbit/torch_cartpole_ppo.py
index 51da382e..cd0d3f78 100644
--- a/docs/source/examples/isaacorbit/torch_cartpole_ppo.py
+++ b/docs/source/examples/isaacorbit/torch_cartpole_ppo.py
@@ -31,7 +31,7 @@ def __init__(self, observation_space, action_space, device, clip_actions=False,
                                  nn.ELU())
 
         self.mean_layer = nn.Linear(32, self.num_actions)
-        self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
+        self.log_std_parameter = nn.Parameter(torch.ones(self.num_actions))
 
         self.value_layer = nn.Linear(32, 1)
 
@@ -88,6 +88,7 @@ def compute(self, inputs, role):
 cfg["value_loss_scale"] = 2.0
 cfg["kl_threshold"] = 0
 cfg["rewards_shaper"] = None
+cfg["time_limit_bootstrap"] = True
 cfg["state_preprocessor"] = RunningStandardScaler
 cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
 cfg["value_preprocessor"] = RunningStandardScaler
reduction="sum"): + clip_log_std=True, min_log_std=-5, max_log_std=2, reduction="sum"): Model.__init__(self, observation_space, action_space, device) GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction) DeterministicMixin.__init__(self, clip_actions) @@ -45,7 +45,7 @@ def act(self, inputs, role): def compute(self, inputs, role): if role == "policy": - return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {} + return torch.tanh(self.mean_layer(self.net(inputs["states"]))), self.log_std_parameter, {} elif role == "value": return self.value_layer(self.net(inputs["states"])), {} @@ -79,7 +79,7 @@ def compute(self, inputs, role): cfg["lambda"] = 0.95 cfg["learning_rate"] = 3e-4 cfg["learning_rate_scheduler"] = KLAdaptiveRL -cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008} +cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01} cfg["random_timesteps"] = 0 cfg["learning_starts"] = 0 cfg["grad_norm_clip"] = 1.0 @@ -87,9 +87,10 @@ def compute(self, inputs, role): cfg["value_clip"] = 0.2 cfg["clip_predicted_values"] = True cfg["entropy_loss_scale"] = 0.0 -cfg["value_loss_scale"] = 1.0 +cfg["value_loss_scale"] = 4.0 cfg["kl_threshold"] = 0 -cfg["rewards_shaper"] = lambda rewards, timestep, timesteps: rewards * 0.01 +cfg["rewards_shaper"] = lambda rewards, *args, **kwargs: rewards * 0.01 +cfg["time_limit_bootstrap"] = False cfg["state_preprocessor"] = RunningStandardScaler cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device} cfg["value_preprocessor"] = RunningStandardScaler diff --git a/docs/source/examples/isaacorbit/torch_lift_franka_ppo.py b/docs/source/examples/isaacorbit/torch_lift_franka_ppo.py index 5f3f3745..406184b9 100644 --- a/docs/source/examples/isaacorbit/torch_lift_franka_ppo.py +++ b/docs/source/examples/isaacorbit/torch_lift_franka_ppo.py @@ -96,8 +96,8 @@ def compute(self, inputs, role): cfg["value_preprocessor"] = RunningStandardScaler cfg["value_preprocessor_kwargs"] = {"size": 1, "device": device} # logging to TensorBoard and write checkpoints (in timesteps) -cfg["experiment"]["write_interval"] = 800 -cfg["experiment"]["checkpoint_interval"] = 8000 +cfg["experiment"]["write_interval"] = 336 +cfg["experiment"]["checkpoint_interval"] = 3360 cfg["experiment"]["directory"] = "runs/torch/Isaac-Lift-Franka-v0" agent = PPO(models=models, @@ -114,3 +114,17 @@ def compute(self, inputs, role): # start training trainer.train() + + +# # --------------------------------------------------------- +# # comment the code above: `trainer.train()`, and... 
diff --git a/docs/source/examples/isaacorbit/torch_lift_franka_ppo.py b/docs/source/examples/isaacorbit/torch_lift_franka_ppo.py
index 5f3f3745..406184b9 100644
--- a/docs/source/examples/isaacorbit/torch_lift_franka_ppo.py
+++ b/docs/source/examples/isaacorbit/torch_lift_franka_ppo.py
@@ -96,8 +96,8 @@ def compute(self, inputs, role):
 cfg["value_preprocessor"] = RunningStandardScaler
 cfg["value_preprocessor_kwargs"] = {"size": 1, "device": device}
 # logging to TensorBoard and write checkpoints (in timesteps)
-cfg["experiment"]["write_interval"] = 800
-cfg["experiment"]["checkpoint_interval"] = 8000
+cfg["experiment"]["write_interval"] = 336
+cfg["experiment"]["checkpoint_interval"] = 3360
 cfg["experiment"]["directory"] = "runs/torch/Isaac-Lift-Franka-v0"
 
 agent = PPO(models=models,
@@ -114,3 +114,17 @@ def compute(self, inputs, role):
 
 # start training
 trainer.train()
+
+
+# # ---------------------------------------------------------
+# # comment the code above: `trainer.train()`, and...
+# # uncomment the following lines to evaluate a trained agent
+# # ---------------------------------------------------------
+# from skrl.utils.huggingface import download_model_from_huggingface
+
+# # download the trained agent's checkpoint from Hugging Face Hub and load it
+# path = download_model_from_huggingface("skrl/IsaacOrbit-Isaac-Lift-Franka-v0-PPO", filename="agent.pt")
+# agent.load(path)
+
+# # start evaluation
+# trainer.eval()
diff --git a/docs/source/examples/isaacorbit/torch_reach_franka_ppo.py b/docs/source/examples/isaacorbit/torch_reach_franka_ppo.py
index 7bb8be7e..bc747c5a 100644
--- a/docs/source/examples/isaacorbit/torch_reach_franka_ppo.py
+++ b/docs/source/examples/isaacorbit/torch_reach_franka_ppo.py
@@ -33,7 +33,7 @@ def __init__(self, observation_space, action_space, device, clip_actions=False,
                                  nn.ELU())
 
         self.mean_layer = nn.Linear(64, self.num_actions)
-        self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
+        self.log_std_parameter = nn.Parameter(0.5 * torch.ones(self.num_actions))
 
         self.value_layer = nn.Linear(64, 1)
 
@@ -45,7 +45,7 @@ def act(self, inputs, role):
 
    def compute(self, inputs, role):
         if role == "policy":
-            return self.mean_layer(self.net(inputs["states"])), self.log_std_parameter, {}
+            return torch.tanh(self.mean_layer(self.net(inputs["states"]))), self.log_std_parameter, {}
         elif role == "value":
             return self.value_layer(self.net(inputs["states"])), {}
 
@@ -79,7 +79,7 @@ def compute(self, inputs, role):
 cfg["lambda"] = 0.95
 cfg["learning_rate"] = 3e-4
 cfg["learning_rate_scheduler"] = KLAdaptiveRL
-cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008}
+cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01}
 cfg["random_timesteps"] = 0
 cfg["learning_starts"] = 0
 cfg["grad_norm_clip"] = 1.0
@@ -90,6 +90,7 @@ def compute(self, inputs, role):
 cfg["value_loss_scale"] = 2.0
 cfg["kl_threshold"] = 0
 cfg["rewards_shaper"] = None
+cfg["time_limit_bootstrap"] = False
 cfg["state_preprocessor"] = RunningStandardScaler
 cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
 cfg["value_preprocessor"] = RunningStandardScaler
diff --git a/docs/source/examples/isaacorbit/torch_velocity_anymal_c_ppo.py b/docs/source/examples/isaacorbit/torch_velocity_anymal_c_ppo.py
index 4eaa7d9b..d5224a05 100644
--- a/docs/source/examples/isaacorbit/torch_velocity_anymal_c_ppo.py
+++ b/docs/source/examples/isaacorbit/torch_velocity_anymal_c_ppo.py
@@ -86,10 +86,11 @@ def compute(self, inputs, role):
 cfg["ratio_clip"] = 0.2
 cfg["value_clip"] = 0.2
 cfg["clip_predicted_values"] = True
-cfg["entropy_loss_scale"] = 0.01
+cfg["entropy_loss_scale"] = 0.0
 cfg["value_loss_scale"] = 1.0
 cfg["kl_threshold"] = 0
 cfg["rewards_shaper"] = None
+cfg["time_limit_bootstrap"] = False
 cfg["state_preprocessor"] = RunningStandardScaler
 cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
 cfg["value_preprocessor"] = RunningStandardScaler
diff --git a/docs/source/index.rst b/docs/source/index.rst
index f1cd8aec..5ab4e515 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -50,15 +50,19 @@ SKRL - Reinforcement Learning library (|version|)
 
 | **Questions or discussions:** https://github.com/Toni-SM/skrl/discussions
 |
 
-**Citing skrl:** To cite this library (created at `Mondragon Unibertsitatea `_) use the following reference to its `article `_: *"skrl: Modular and Flexible Library for Reinforcement Learning"*
+**Citing skrl:** To cite this library (created at Mondragon Unibertsitatea) use the following reference to its article: `skrl: Modular and Flexible Library for Reinforcement Learning <http://jmlr.org/papers/v24/23-0112.html>`_.
 
 .. code-block:: bibtex
 
-    @article{serrano2022skrl,
-      title={skrl: Modular and Flexible Library for Reinforcement Learning},
-      author={Serrano-Mu{\~n}oz, Antonio and Arana-Arexolaleiba, Nestor and Chrysostomou, Dimitrios and B{\o}gh, Simon},
-      journal={arXiv preprint arXiv:2202.03825},
-      year={2022}
+    @article{serrano2023skrl,
+      author  = {Antonio Serrano-Muñoz and Dimitrios Chrysostomou and Simon Bøgh and Nestor Arana-Arexolaleiba},
+      title   = {skrl: Modular and Flexible Library for Reinforcement Learning},
+      journal = {Journal of Machine Learning Research},
+      year    = {2023},
+      volume  = {24},
+      number  = {254},
+      pages   = {1--9},
+      url     = {http://jmlr.org/papers/v24/23-0112.html}
     }
 
 .. raw:: html
diff --git a/docs/source/intro/examples.rst b/docs/source/intro/examples.rst
index 6dac2e8e..2f16383e 100644
--- a/docs/source/intro/examples.rst
+++ b/docs/source/intro/examples.rst
@@ -98,12 +98,33 @@ Training/evaluation of an agent in `Gymnasium `_
            - :download:`jax_gymnasium_cartpole_cem.py <../examples/gymnasium/jax_gymnasium_cartpole_cem.py>`
              |br| :download:`jax_gymnasium_cartpole_dqn.py <../examples/gymnasium/jax_gymnasium_cartpole_dqn.py>`
            -
+         * - FrozenLake
+           -
+           -
          * - Pendulum
            - :download:`jax_gymnasium_pendulum_ddpg.py <../examples/gymnasium/jax_gymnasium_pendulum_ddpg.py>`
              |br| :download:`jax_gymnasium_pendulum_ppo.py <../examples/gymnasium/jax_gymnasium_pendulum_ppo.py>`
              |br| :download:`jax_gymnasium_pendulum_sac.py <../examples/gymnasium/jax_gymnasium_pendulum_sac.py>`
              |br| :download:`jax_gymnasium_pendulum_td3.py <../examples/gymnasium/jax_gymnasium_pendulum_td3.py>`
            -
+         * - PendulumNoVel*
+             |br| (RNN / GRU / LSTM)
+           - |br|
+             |br|
+             |br|
+             |br|
+             |br|
+             |br|
+             |br|
+             |br|
+             |br|
+             |br|
+             |br|
+           -
+         * - Taxi
+           -
+           -
 
 .. note::
@@ -171,12 +192,33 @@ Training/evaluation of an agent in `Gymnasium `_
            - :download:`jax_gym_cartpole_cem.py <../examples/gym/jax_gym_cartpole_cem.py>`
              |br| :download:`jax_gym_cartpole_dqn.py <../examples/gym/jax_gym_cartpole_dqn.py>`
            -
+         * - FrozenLake
+           -
+           -
          * - Pendulum
            - :download:`jax_gym_pendulum_ddpg.py <../examples/gym/jax_gym_pendulum_ddpg.py>`
              |br| :download:`jax_gym_pendulum_ppo.py <../examples/gym/jax_gym_pendulum_ppo.py>`
              |br| :download:`jax_gym_pendulum_sac.py <../examples/gym/jax_gym_pendulum_sac.py>`
              |br| :download:`jax_gym_pendulum_td3.py <../examples/gym/jax_gym_pendulum_td3.py>`
            -
+         * - PendulumNoVel*
+             |br| (RNN / GRU / LSTM)
+           - |br|
+             |br|
+             |br|
+             |br|
+             |br|
+             |br|
+             |br|
+             |br|
+             |br|
+             |br|
+             |br|
+           -
+         * - Taxi
+           -
+           -
 
 .. note::
@@ -235,9 +277,15 @@ Training/evaluation of an agent in `Gymnasium `_
          * - CartPole
            - :download:`jax_gymnasium_cartpole_vector_dqn.py <../examples/gymnasium/jax_gymnasium_cartpole_vector_dqn.py>`
            -
+         * - FrozenLake
+           -
+           -
          * - Pendulum
            - :download:`jax_gymnasium_pendulum_vector_ddpg.py <../examples/gymnasium/jax_gymnasium_pendulum_vector_ddpg.py>`
            -
+         * - Taxi
+           -
+           -
 
     .. group-tab:: Gym
@@ -281,9 +329,15 @@ Training/evaluation of an agent in `Gymnasium `_
          * - CartPole
            - :download:`jax_gym_cartpole_vector_dqn.py <../examples/gym/jax_gym_cartpole_vector_dqn.py>`
            -
+         * - FrozenLake
+           -
+           -
          * - Pendulum
            - :download:`jax_gym_pendulum_vector_ddpg.py <../examples/gym/jax_gym_pendulum_vector_ddpg.py>`
            -
+         * - Taxi
+           -
+           -
 
 .. raw:: html
@@ -835,7 +889,7 @@ The agent configuration is mapped, as far as possible, from the `Isaac Orbit con
            - `IsaacOrbit-Isaac-Humanoid-v0-PPO `_
          * - Isaac-Lift-Franka-v0
            - :download:`torch_lift_franka_ppo.py <../examples/isaacorbit/torch_lift_franka_ppo.py>`
-           -
+           - `IsaacOrbit-Isaac-Lift-Franka-v0-PPO `_
          * - Isaac-Reach-Franka-v0
            - :download:`torch_reach_franka_ppo.py <../examples/isaacorbit/torch_reach_franka_ppo.py>`
            - `IsaacOrbit-Isaac-Reach-Franka-v0-PPO `_
@@ -859,22 +913,22 @@ The agent configuration is mapped, as far as possible, from the `Isaac Orbit con
              |br| :download:`jax_ant_ddpg.py <../examples/isaacorbit/jax_ant_ddpg.py>`
              |br| :download:`jax_ant_td3.py <../examples/isaacorbit/jax_ant_td3.py>`
              |br| :download:`jax_ant_sac.py <../examples/isaacorbit/jax_ant_sac.py>`
-           - |br|
+           - `IsaacOrbit-Isaac-Ant-v0-PPO `_ |br|
              |br|
              |br|
              |br|
          * - Isaac-Cartpole-v0
            - :download:`jax_cartpole_ppo.py <../examples/isaacorbit/jax_cartpole_ppo.py>`
-           -
+           - `IsaacOrbit-Isaac-Cartpole-v0-PPO `_
          * - Isaac-Humanoid-v0
            - :download:`jax_humanoid_ppo.py <../examples/isaacorbit/jax_humanoid_ppo.py>`
-           -
+           - `IsaacOrbit-Isaac-Humanoid-v0-PPO `_
          * - Isaac-Lift-Franka-v0
            - :download:`jax_lift_franka_ppo.py <../examples/isaacorbit/jax_lift_franka_ppo.py>`
-           -
+           - `IsaacOrbit-Isaac-Lift-Franka-v0-PPO `_
          * - Isaac-Reach-Franka-v0
            - :download:`jax_reach_franka_ppo.py <../examples/isaacorbit/jax_reach_franka_ppo.py>`
-           -
+           - `IsaacOrbit-Isaac-Reach-Franka-v0-PPO `_
          * - Isaac-Velocity-Anymal-C-v0
            - :download:`jax_velocity_anymal_c_ppo.py <../examples/isaacorbit/jax_velocity_anymal_c_ppo.py>`
            -
diff --git a/docs/source/snippets/wrapping.py b/docs/source/snippets/wrapping.py
index c816984e..4213ad57 100644
--- a/docs/source/snippets/wrapping.py
+++ b/docs/source/snippets/wrapping.py
@@ -297,6 +297,32 @@
 # [jax-end-gymnasium-vectorized]
 
 
+# [pytorch-start-shimmy]
+# import the environment wrapper and gymnasium
+from skrl.envs.wrappers.torch import wrap_env
+import gymnasium as gym
+
+# load the environment (API conversion)
+env = gym.make("ALE/Pong-v5")
+
+# wrap the environment
+env = wrap_env(env)  # or 'env = wrap_env(env, wrapper="gymnasium")'
+# [pytorch-end-shimmy]
+
+
+# [jax-start-shimmy]
+# import the environment wrapper and gymnasium
+from skrl.envs.wrappers.jax import wrap_env
+import gymnasium as gym
+
+# load the environment (API conversion)
+env = gym.make("ALE/Pong-v5")
+
+# wrap the environment
+env = wrap_env(env)  # or 'env = wrap_env(env, wrapper="gymnasium")'
+# [jax-end-shimmy]
+
+
 # [pytorch-start-deepmind]
 # import the environment wrapper and the deepmind suite
 from skrl.envs.wrappers.torch import wrap_env
diff --git a/pyproject.toml b/pyproject.toml
index aaeb7d3c..2afb84a0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "skrl"
-version = "1.0.0-rc.2"
+version = "1.0.0"
 description = "Modular and flexible library for reinforcement learning on PyTorch and JAX"
 readme = "README.md"
 requires-python = ">=3.6"
diff --git a/skrl/models/jax/base.py b/skrl/models/jax/base.py
index fa595726..30db0e83 100644
--- a/skrl/models/jax/base.py
+++ b/skrl/models/jax/base.py
@@ -129,7 +129,8 @@ def _get_space_size(self,
         :param space: Space or shape from which to obtain the number of elements
         :type space: int, sequence of int, gym.Space, or gymnasium.Space
         :param number_of_elements: Whether the number of elements occupied by the space is returned (default: ``True``).
-                                   If ``False``, the shape of the space is returned. It only affects Discrete spaces
+                                   If ``False``, the shape of the space is returned.
+                                   It only affects Discrete and MultiDiscrete spaces
         :type number_of_elements: bool, optional
 
         :raises ValueError: If the space is not supported
@@ -159,6 +160,13 @@ def _get_space_size(self,
             >>> model._get_space_size(space, number_of_elements=False)
             1
 
+            # MultiDiscrete space
+            >>> space = gym.spaces.MultiDiscrete([5, 3, 2])
+            >>> model._get_space_size(space)
+            10
+            >>> model._get_space_size(space, number_of_elements=False)
+            3
+
             # Dict space
             >>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
             ...                          'b': gym.spaces.Discrete(4)})
@@ -178,6 +186,11 @@ def _get_space_size(self,
                 size = space.n
             else:
                 size = 1
+        elif issubclass(type(space), gym.spaces.MultiDiscrete):
+            if number_of_elements:
+                size = np.sum(space.nvec)
+            else:
+                size = space.nvec.shape[0]
         elif issubclass(type(space), gym.spaces.Box):
             size = np.prod(space.shape)
         elif issubclass(type(space), gym.spaces.Dict):
@@ -188,6 +201,11 @@ def _get_space_size(self,
                 size = space.n
             else:
                 size = 1
+        elif issubclass(type(space), gymnasium.spaces.MultiDiscrete):
+            if number_of_elements:
+                size = np.sum(space.nvec)
+            else:
+                size = space.nvec.shape[0]
         elif issubclass(type(space), gymnasium.spaces.Box):
             size = np.prod(space.shape)
         elif issubclass(type(space), gymnasium.spaces.Dict):
diff --git a/skrl/models/torch/base.py b/skrl/models/torch/base.py
index 5bddd88d..757a8ba2 100644
--- a/skrl/models/torch/base.py
+++ b/skrl/models/torch/base.py
@@ -71,7 +71,8 @@ def _get_space_size(self,
         :param space: Space or shape from which to obtain the number of elements
         :type space: int, sequence of int, gym.Space, or gymnasium.Space
         :param number_of_elements: Whether the number of elements occupied by the space is returned (default: ``True``).
-                                   If ``False``, the shape of the space is returned. It only affects Discrete spaces
+                                   If ``False``, the shape of the space is returned.
+                                   It only affects Discrete and MultiDiscrete spaces
         :type number_of_elements: bool, optional
 
         :raises ValueError: If the space is not supported
@@ -101,6 +102,13 @@ def _get_space_size(self,
             >>> model._get_space_size(space, number_of_elements=False)
             1
 
+            # MultiDiscrete space
+            >>> space = gym.spaces.MultiDiscrete([5, 3, 2])
+            >>> model._get_space_size(space)
+            10
+            >>> model._get_space_size(space, number_of_elements=False)
+            3
+
             # Dict space
             >>> space = gym.spaces.Dict({'a': gym.spaces.Box(low=-1, high=1, shape=(2, 3)),
             ...                          'b': gym.spaces.Discrete(4)})
@@ -120,6 +128,11 @@ def _get_space_size(self,
                 size = space.n
             else:
                 size = 1
+        elif issubclass(type(space), gym.spaces.MultiDiscrete):
+            if number_of_elements:
+                size = np.sum(space.nvec)
+            else:
+                size = space.nvec.shape[0]
         elif issubclass(type(space), gym.spaces.Box):
             size = np.prod(space.shape)
         elif issubclass(type(space), gym.spaces.Dict):
@@ -130,6 +143,11 @@ def _get_space_size(self,
                 size = space.n
             else:
                 size = 1
+        elif issubclass(type(space), gymnasium.spaces.MultiDiscrete):
+            if number_of_elements:
+                size = np.sum(space.nvec)
+            else:
+                size = space.nvec.shape[0]
         elif issubclass(type(space), gymnasium.spaces.Box):
             size = np.prod(space.shape)
         elif issubclass(type(space), gymnasium.spaces.Dict):
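To double-check the MultiDiscrete arithmetic added above, the same sizes can be computed directly with gymnasium (mirroring what `_get_space_size` returns):

```python
import numpy as np
import gymnasium as gym

space = gym.spaces.MultiDiscrete([5, 3, 2])
print(int(np.sum(space.nvec)))  # 10 -> number of elements (5 + 3 + 2)
print(space.nvec.shape[0])      # 3  -> shape of the space (one entry per sub-space)
```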