Merge pull request #53 from Toni-SM/develop
Develop
Toni-SM committed Jan 22, 2023
2 parents fbb19d9 + bfcc4f8 commit 930b8d7
Showing 23 changed files with 1,211 additions and 18 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,12 @@

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [0.10.0] - 2023-01-22
### Added
- Isaac Orbit environment loader
- Wrap an Isaac Orbit environment
- Gaussian-Deterministic shared model instantiator

## [0.9.1] - 2023-01-17
### Added
- Utility for downloading models from Hugging Face Hub
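The three additions listed for 0.10.0 map to a loader and a wrapper exposed from skrl.envs.torch, plus a shared-model pattern that the new example scripts further down this diff implement by hand. A minimal sketch of the loader/wrapper pair, using only the calls that appear in those examples (the task name is simply the one used there):

from skrl.envs.torch import load_isaac_orbit_env, wrap_env

# load an Isaac Orbit task by name
env = load_isaac_orbit_env(task_name="Isaac-Cartpole-v0")

# wrap it so skrl memories, agents and trainers can consume it
env = wrap_env(env)

print(env.device, env.num_envs)                  # simulation device and number of parallel environments
print(env.observation_space, env.action_space)   # spaces used to instantiate the models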
4 changes: 2 additions & 2 deletions README.md
@@ -1,5 +1,5 @@
[![pypi](https://img.shields.io/pypi/v/skrl)](https://pypi.org/project/skrl)
[<img src="https://img.shields.io/badge/%F0%9F%A4%97%20Models-Huggingface-F8D521">](https://huggingface.co/skrl)
[<img src="https://img.shields.io/badge/%F0%9F%A4%97%20models-hugging%20face-F8D521">](https://huggingface.co/skrl)
![discussions](https://img.shields.io/github/discussions/Toni-SM/skrl)
<br>
[![license](https://img.shields.io/github/license/Toni-SM/skrl)](https://github.com/Toni-SM/skrl)
@@ -15,7 +15,7 @@
<h2 align="center" style="border-bottom: 0 !important;">SKRL - Reinforcement Learning library</h2>
<br>

**skrl** is an open-source modular library for Reinforcement Learning written in Python (using [PyTorch](https://pytorch.org/)) and designed with a focus on readability, simplicity, and transparency of algorithm implementation. In addition to supporting the OpenAI [Gym](https://www.gymlibrary.dev) / Farama [Gymnasium](https://gymnasium.farama.org) and [DeepMind](https://github.com/deepmind/dm_env) environment interfaces, it allows loading and configuring [NVIDIA Isaac Gym](https://developer.nvidia.com/isaac-gym/) and [NVIDIA Omniverse Isaac Gym](https://docs.omniverse.nvidia.com/app_isaacsim/app_isaacsim/tutorial_gym_isaac_gym.html) environments, enabling agents' simultaneous training by scopes (subsets of environments among all available environments), which may or may not share resources, in the same run
**skrl** is an open-source modular library for Reinforcement Learning written in Python (using [PyTorch](https://pytorch.org/)) and designed with a focus on readability, simplicity, and transparency of algorithm implementation. In addition to supporting the OpenAI [Gym](https://www.gymlibrary.dev) / Farama [Gymnasium](https://gymnasium.farama.org), [DeepMind](https://github.com/deepmind/dm_env) and other environment interfaces, it allows loading and configuring [NVIDIA Isaac Gym](https://developer.nvidia.com/isaac-gym/), [NVIDIA Isaac Orbit](https://isaac-orbit.github.io/orbit/index.html) and [NVIDIA Omniverse Isaac Gym](https://docs.omniverse.nvidia.com/app_isaacsim/app_isaacsim/tutorial_gym_isaac_gym.html) environments, enabling simultaneous training of agents by scopes (subsets of environments among all available environments), which may or may not share resources, in the same run.

<br>

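The README paragraph above mentions simultaneous training of multiple agents by scopes. This diff does not show such a setup, so the following is only a rough sketch: agent_a, agent_b and the half/half scope split are made up, and the agents_scope argument name is an assumption taken from skrl's trainer documentation rather than from this commit.

from skrl.trainers.torch import SequentialTrainer

# hypothetical: two previously instantiated skrl agents sharing one vectorized environment,
# each trained on its own subset (scope) of the parallel sub-environments
agents = [agent_a, agent_b]
scopes = [env.num_envs // 2, env.num_envs // 2]

trainer = SequentialTrainer(cfg={"timesteps": 8000, "headless": True},
                            env=env,
                            agents=agents,
                            agents_scope=scopes)   # assumption: number of sub-environments per agent
trainer.train()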
1 change: 1 addition & 0 deletions docs/requirements.txt
@@ -2,6 +2,7 @@ sphinx
sphinx_rtd_theme
sphinx-autobuild
sphinx-tabs==3.2.0
sphinx-copybutton
gym
gymnasium
torch
Binary file added docs/source/_static/data/favicon.ico
Binary file added docs/source/_static/imgs/example_isaac_orbit.png
12 changes: 10 additions & 2 deletions docs/source/conf.py
@@ -30,13 +30,19 @@
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx_tabs.tabs'
'sphinx_tabs.tabs',
'sphinx_copybutton'
]

# generate links to the documentation of objects in external projects
intersphinx_mapping = {
'python': ('https://docs.python.org/3/', None),
'sphinx': ('https://www.sphinx-doc.org/en/master/', None),
'gym': ('https://www.gymlibrary.dev/', None),
'gymnasium': ('https://gymnasium.farama.org/', None),
"numpy": ("http://docs.scipy.org/doc/numpy/", None),
"torch": ("https://pytorch.org/docs/stable/", None),
}

intersphinx_disabled_domains = ['std']

templates_path = ['_templates']
@@ -52,6 +58,8 @@

html_logo = '_static/data/skrl-up.png'

html_favicon = "_static/data/favicon.ico"

html_static_path = ['_static']

html_css_files = ['css/s5defs-roles.css',
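The intersphinx_mapping added above lets the documentation cross-reference the gym, gymnasium, numpy and torch docs by object name instead of hard-coded URLs. A small illustrative docstring (the function itself is hypothetical; only the cross-reference roles matter):

import torch


def to_tensor(observation):
    """Convert an environment observation to a tensor.

    :param observation: observation returned by a :py:class:`gym.Env` step
    :return: the observation as a :py:class:`torch.Tensor`
    """
    return torch.as_tensor(observation)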
115 changes: 115 additions & 0 deletions docs/source/examples/isaacorbit/ppo_ant.py
@@ -0,0 +1,115 @@
import torch
import torch.nn as nn

# Import the skrl components to build the RL system
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
from skrl.memories.torch import RandomMemory
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.resources.schedulers.torch import KLAdaptiveRL
from skrl.resources.preprocessors.torch import RunningStandardScaler
from skrl.trainers.torch import SequentialTrainer
from skrl.envs.torch import wrap_env
from skrl.envs.torch import load_isaac_orbit_env
from skrl.utils import set_seed


# set the seed for reproducibility
set_seed(42)


# Define the shared model (stochastic and deterministic models) for the agent using mixins.
class Shared(GaussianMixin, DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions=False,
clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
DeterministicMixin.__init__(self, clip_actions=False)

self.net = nn.Sequential(nn.Linear(self.num_observations, 256),
nn.ELU(),
nn.Linear(256, 128),
nn.ELU(),
nn.Linear(128, 64),
nn.ELU())

self.mean_layer = nn.Linear(64, self.num_actions)
self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))

self.value_layer = nn.Linear(64, 1)

def act(self, inputs, role):
if role == "policy":
return GaussianMixin.act(self, inputs, role)
elif role == "value":
return DeterministicMixin.act(self, inputs, role)

def compute(self, inputs, role):
if role == "policy":
return torch.tanh(self.mean_layer(self.net(inputs["states"]))), self.log_std_parameter, {}
elif role == "value":
return self.value_layer(self.net(inputs["states"])), {}


# Load and wrap the Isaac Orbit environment
env = load_isaac_orbit_env(task_name="Isaac-Ant-v0")
env = wrap_env(env)

device = env.device


# Instantiate a RandomMemory as rollout buffer (any memory can be used for this)
memory = RandomMemory(memory_size=16, num_envs=env.num_envs, device=device)


# Instantiate the agent's models (function approximators).
# PPO requires 2 models, visit its documentation for more details
# https://skrl.readthedocs.io/en/latest/modules/skrl.agents.ppo.html#spaces-and-models
models_ppo = {}
models_ppo["policy"] = Shared(env.observation_space, env.action_space, device, clip_actions=True)
models_ppo["value"] = models_ppo["policy"] # same instance: shared model


# Configure and instantiate the agent.
# Only modify some of the default configuration, visit its documentation to see all the options
# https://skrl.readthedocs.io/en/latest/modules/skrl.agents.ppo.html#configuration-and-hyperparameters
cfg_ppo = PPO_DEFAULT_CONFIG.copy()
cfg_ppo["rollouts"] = 16 # memory_size
cfg_ppo["learning_epochs"] = 8
cfg_ppo["mini_batches"] = 4 # 16 * 1024 / 4096
cfg_ppo["discount_factor"] = 0.99
cfg_ppo["lambda"] = 0.95
cfg_ppo["learning_rate"] = 3e-4
cfg_ppo["learning_rate_scheduler"] = KLAdaptiveRL
cfg_ppo["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008}
cfg_ppo["random_timesteps"] = 0
cfg_ppo["learning_starts"] = 0
cfg_ppo["grad_norm_clip"] = 1.0
cfg_ppo["ratio_clip"] = 0.2
cfg_ppo["value_clip"] = 0.2
cfg_ppo["clip_predicted_values"] = True
cfg_ppo["entropy_loss_scale"] = 0.0
cfg_ppo["value_loss_scale"] = 1.0
cfg_ppo["kl_threshold"] = 0
cfg_ppo["rewards_shaper"] = lambda rewards, timestep, timesteps: rewards * 0.01
cfg_ppo["state_preprocessor"] = RunningStandardScaler
cfg_ppo["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg_ppo["value_preprocessor"] = RunningStandardScaler
cfg_ppo["value_preprocessor_kwargs"] = {"size": 1, "device": device}
# log to TensorBoard every 40 timesteps and write checkpoints every 400 timesteps
cfg_ppo["experiment"]["write_interval"] = 40
cfg_ppo["experiment"]["checkpoint_interval"] = 400

agent = PPO(models=models_ppo,
memory=memory,
cfg=cfg_ppo,
observation_space=env.observation_space,
action_space=env.action_space,
device=device)


# Configure and instantiate the RL trainer
cfg_trainer = {"timesteps": 8000, "headless": True}
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent)

# start training
trainer.train()
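Once training finishes, the same agent and environment can be reused for evaluation. A short sketch under the assumption that a checkpoint has already been written by the experiment above; the path below is a hypothetical placeholder, and agent.load / trainer.eval are used here under the assumption they behave as in skrl's documented evaluation examples.

# load a previously written checkpoint (hypothetical path)
agent.load("./runs/ppo_ant_experiment/checkpoints/best_agent.pt")

# run the trained policy in evaluation mode (no learning updates)
cfg_eval = {"timesteps": 1600, "headless": True}
trainer = SequentialTrainer(cfg=cfg_eval, env=env, agents=agent)
trainer.eval()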
113 changes: 113 additions & 0 deletions docs/source/examples/isaacorbit/ppo_cartpole.py
@@ -0,0 +1,113 @@
import torch
import torch.nn as nn

# Import the skrl components to build the RL system
from skrl.models.torch import Model, GaussianMixin, DeterministicMixin
from skrl.memories.torch import RandomMemory
from skrl.agents.torch.ppo import PPO, PPO_DEFAULT_CONFIG
from skrl.resources.schedulers.torch import KLAdaptiveRL
from skrl.resources.preprocessors.torch import RunningStandardScaler
from skrl.trainers.torch import SequentialTrainer
from skrl.envs.torch import wrap_env
from skrl.envs.torch import load_isaac_orbit_env
from skrl.utils import set_seed


# set the seed for reproducibility
set_seed(42)


# Define the shared model (stochastic and deterministic models) for the agent using mixins.
class Shared(GaussianMixin, DeterministicMixin, Model):
def __init__(self, observation_space, action_space, device, clip_actions=False,
clip_log_std=True, min_log_std=-20, max_log_std=2, reduction="sum"):
Model.__init__(self, observation_space, action_space, device)
GaussianMixin.__init__(self, clip_actions, clip_log_std, min_log_std, max_log_std, reduction)
DeterministicMixin.__init__(self, clip_actions=False)

self.net = nn.Sequential(nn.Linear(self.num_observations, 32),
nn.ELU(),
nn.Linear(32, 32),
nn.ELU())

self.mean_layer = nn.Linear(32, self.num_actions)
self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))

self.value_layer = nn.Linear(32, 1)

def act(self, inputs, role):
if role == "policy":
return GaussianMixin.act(self, inputs, role)
elif role == "value":
return DeterministicMixin.act(self, inputs, role)

def compute(self, inputs, role):
if role == "policy":
return torch.tanh(self.mean_layer(self.net(inputs["states"]))), self.log_std_parameter, {}
elif role == "value":
return self.value_layer(self.net(inputs["states"])), {}


# Load and wrap the Isaac Orbit environment
env = load_isaac_orbit_env(task_name="Isaac-Cartpole-v0")
env = wrap_env(env)

device = env.device


# Instantiate a RandomMemory as rollout buffer (any memory can be used for this)
memory = RandomMemory(memory_size=16, num_envs=env.num_envs, device=device)


# Instantiate the agent's models (function approximators).
# PPO requires 2 models, visit its documentation for more details
# https://skrl.readthedocs.io/en/latest/modules/skrl.agents.ppo.html#spaces-and-models
models_ppo = {}
models_ppo["policy"] = Shared(env.observation_space, env.action_space, device, clip_actions=True)
models_ppo["value"] = models_ppo["policy"] # same instance: shared model


# Configure and instantiate the agent.
# Only modify some of the default configuration, visit its documentation to see all the options
# https://skrl.readthedocs.io/en/latest/modules/skrl.agents.ppo.html#configuration-and-hyperparameters
cfg_ppo = PPO_DEFAULT_CONFIG.copy()
cfg_ppo["rollouts"] = 16 # memory_size
cfg_ppo["learning_epochs"] = 8
cfg_ppo["mini_batches"] = 1 # 16 * 512 / 8192
cfg_ppo["discount_factor"] = 0.99
cfg_ppo["lambda"] = 0.95
cfg_ppo["learning_rate"] = 3e-4
cfg_ppo["learning_rate_scheduler"] = KLAdaptiveRL
cfg_ppo["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008}
cfg_ppo["random_timesteps"] = 0
cfg_ppo["learning_starts"] = 0
cfg_ppo["grad_norm_clip"] = 1.0
cfg_ppo["ratio_clip"] = 0.2
cfg_ppo["value_clip"] = 0.2
cfg_ppo["clip_predicted_values"] = True
cfg_ppo["entropy_loss_scale"] = 0.0
cfg_ppo["value_loss_scale"] = 2.0
cfg_ppo["kl_threshold"] = 0
cfg_ppo["rewards_shaper"] = None
cfg_ppo["state_preprocessor"] = RunningStandardScaler
cfg_ppo["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg_ppo["value_preprocessor"] = RunningStandardScaler
cfg_ppo["value_preprocessor_kwargs"] = {"size": 1, "device": device}
# log to TensorBoard every 16 timesteps and write checkpoints every 80 timesteps
cfg_ppo["experiment"]["write_interval"] = 16
cfg_ppo["experiment"]["checkpoint_interval"] = 80

agent = PPO(models=models_ppo,
memory=memory,
cfg=cfg_ppo,
observation_space=env.observation_space,
action_space=env.action_space,
device=device)


# Configure and instantiate the RL trainer
cfg_trainer = {"timesteps": 1600, "headless": True}
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent)

# start training
trainer.train()
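Both example scripts register the same Shared instance under the "policy" and "value" keys, so a single backbone MLP serves both the stochastic policy (GaussianMixin) and the deterministic value function (DeterministicMixin) and receives gradients from both losses. A quick sanity-check sketch of the role-based dispatch: the observation batch is made up for illustration, and starred unpacking is used because the exact arity of act()'s return value may differ between skrl versions.

import torch

# made-up batch of observations, only to exercise the two roles
obs = torch.randn(8, env.observation_space.shape[0], device=device)

actions, *_ = models_ppo["policy"].act({"states": obs}, role="policy")   # Gaussian branch
values, *_ = models_ppo["value"].act({"states": obs}, role="value")      # Deterministic branch

print(actions.shape)   # (8, number of actions)
print(values.shape)    # (8, 1)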