Fixes RSL-RL ONNX exporter for empirical normalization (#78)
The current ONNX exporter does not export the empirical normalization
layer. This PR adds export of the empirical normalization layer to the
JIT and ONNX exporters for RSL-RL.
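
For reference, a minimal usage sketch (mirroring the `play.py` change in this PR): the runner's observation normalizer is passed to both exporters so it is baked into the exported graphs. `ppo_runner` and `resume_path` are assumed to come from the standard RSL-RL play script.

```python
import os

from omni.isaac.orbit_tasks.utils.wrappers.rsl_rl import (
    export_policy_as_jit,
    export_policy_as_onnx,
)

# write the exported models next to the loaded checkpoint
export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")

# pass the runner's empirical observation normalizer so it is embedded in the export;
# passing None makes the exporters fall back to torch.nn.Identity
export_policy_as_jit(
    ppo_runner.alg.actor_critic, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt"
)
export_policy_as_onnx(
    ppo_runner.alg.actor_critic, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx"
)
```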

- Bug fix (non-breaking change which fixes an issue)

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [x] I have run all the tests with `./orbit.sh --test` and they pass
(some timed out)
- [x] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [ ] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there

---------

Co-authored-by: Mayank Mittal <[email protected]>
Nemantor and Mayankm96 committed May 31, 2024
1 parent 90f6fb1 commit 7af7aa8
Showing 4 changed files with 46 additions and 12 deletions.
source/extensions/omni.isaac.orbit_tasks/config/extension.toml
@@ -1,7 +1,7 @@
[package]

# Note: Semantic Versioning is used: https://semver.org/
version = "0.6.1"
version = "0.6.2"

# Description
title = "ORBIT Environments"
12 changes: 12 additions & 0 deletions source/extensions/omni.isaac.orbit_tasks/docs/CHANGELOG.rst
@@ -1,6 +1,18 @@
Changelog
---------

0.6.2 (2024-05-31)
~~~~~~~~~~~~~~~~~~

Added
^^^^^

* Added exporting of the empirical normalization layer to ONNX and JIT when exporting the model using
the :func:`export_policy_as_jit` and :func:`export_policy_as_onnx` helpers for RSL-RL. Previously, the
normalization layer was not exported to the ONNX and JIT models, so the exported model did not work
properly when used for inference.


0.6.1 (2024-04-16)
~~~~~~~~~~~~~~~~~~

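For context, the empirical normalization layer is, at inference time, just a whitening step over observations using statistics collected during training; dropping it from the export means the deployed policy receives raw observations while its weights expect whitened ones. A minimal illustrative stand-in (not RSL-RL's actual implementation, whose module and attribute names may differ) for what the exporters now deep-copy into the exported graph:

```python
import torch


class RunningObsNormalizer(torch.nn.Module):
    """Illustrative stand-in for an empirical observation normalizer.

    Whitens observations as (x - mean) / sqrt(var + eps) using running
    statistics stored in buffers. Only the inference-time behavior matters
    here, since that is what gets traced into the JIT/ONNX export.
    """

    def __init__(self, num_obs: int, eps: float = 1e-8):
        super().__init__()
        self.eps = eps
        self.register_buffer("mean", torch.zeros(num_obs))
        self.register_buffer("var", torch.ones(num_obs))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (x - self.mean) / torch.sqrt(self.var + self.eps)
```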
@@ -8,33 +8,37 @@
import torch


def export_policy_as_jit(actor_critic: object, path: str, filename="policy.pt"):
def export_policy_as_jit(actor_critic: object, normalizer: object | None, path: str, filename="policy.pt"):
"""Export policy into a Torch JIT file.
Args:
actor_critic: The actor-critic torch module.
normalizer: The empirical normalizer module. If None, Identity is used.
path: The path to the saving directory.
filename: The name of exported JIT file. Defaults to "policy.pt".
Reference:
https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/utils/helpers.py#L180
"""
policy_exporter = _TorchPolicyExporter(actor_critic)
policy_exporter = _TorchPolicyExporter(actor_critic, normalizer)
policy_exporter.export(path, filename)


def export_policy_as_onnx(actor_critic: object, path: str, filename="policy.onnx", verbose=False):
def export_policy_as_onnx(
actor_critic: object, normalizer: object | None, path: str, filename="policy.onnx", verbose=False
):
"""Export policy into a Torch ONNX file.
Args:
actor_critic: The actor-critic torch module.
normalizer: The empirical normalizer module. If None, Identity is used.
path: The path to the saving directory.
filename: The name of exported JIT file. Defaults to "policy.pt".
filename: The name of exported ONNX file. Defaults to "policy.onnx".
verbose: Whether to print the model summary. Defaults to False.
"""
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
policy_exporter = _OnnxPolicyExporter(actor_critic, verbose)
policy_exporter = _OnnxPolicyExporter(actor_critic, normalizer, verbose)
policy_exporter.export(path, filename)


@@ -50,7 +54,7 @@ class _TorchPolicyExporter(torch.nn.Module):
https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/utils/helpers.py#L193
"""

def __init__(self, actor_critic):
def __init__(self, actor_critic, normalizer=None):
super().__init__()
self.actor = copy.deepcopy(actor_critic.actor)
self.is_recurrent = actor_critic.is_recurrent
@@ -61,16 +65,22 @@ def __init__(self, actor_critic):
self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size))
self.forward = self.forward_lstm
self.reset = self.reset_memory
# copy normalizer if exists
if normalizer:
self.normalizer = copy.deepcopy(normalizer)
else:
self.normalizer = torch.nn.Identity()

def forward_lstm(self, x):
x = self.normalizer(x)
x, (h, c) = self.rnn(x.unsqueeze(0), (self.hidden_state, self.cell_state))
self.hidden_state[:] = h
self.cell_state[:] = c
x = x.squeeze(0)
return self.actor(x)

def forward(self, x):
return self.actor(x)
return self.actor(self.normalizer(x))

@torch.jit.export
def reset(self):
@@ -91,7 +101,7 @@ def export(self, path, filename):
class _OnnxPolicyExporter(torch.nn.Module):
"""Exporter of actor-critic into ONNX file."""

def __init__(self, actor_critic, verbose=False):
def __init__(self, actor_critic, normalizer=None, verbose=False):
super().__init__()
self.verbose = verbose
self.actor = copy.deepcopy(actor_critic.actor)
@@ -100,14 +110,20 @@ def __init__(self, actor_critic, verbose=False):
self.rnn = copy.deepcopy(actor_critic.memory_a.rnn)
self.rnn.cpu()
self.forward = self.forward_lstm
# copy normalizer if exists
if normalizer:
self.normalizer = copy.deepcopy(normalizer)
else:
self.normalizer = torch.nn.Identity()

def forward_lstm(self, x_in, h_in, c_in):
x_in = self.normalizer(x_in)
x, (h, c) = self.rnn(x_in.unsqueeze(0), (h_in, c_in))
x = x.squeeze(0)
return self.actor(x), h, c

def forward(self, x):
return self.actor(x)
return self.actor(self.normalizer(x))

def export(self, path, filename):
self.to("cpu")
10 changes: 8 additions & 2 deletions source/standalone/workflows/rsl_rl/play.py
@@ -46,6 +46,7 @@
from omni.isaac.orbit_tasks.utils.wrappers.rsl_rl import (
RslRlOnPolicyRunnerCfg,
RslRlVecEnvWrapper,
export_policy_as_jit,
export_policy_as_onnx,
)

@@ -78,9 +79,14 @@ def main():
# obtain the trained policy for inference
policy = ppo_runner.get_inference_policy(device=env.unwrapped.device)

# export policy to onnx
# export policy to onnx/jit
export_model_dir = os.path.join(os.path.dirname(resume_path), "exported")
export_policy_as_onnx(ppo_runner.alg.actor_critic, export_model_dir, filename="policy.onnx")
export_policy_as_jit(
ppo_runner.alg.actor_critic, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt"
)
export_policy_as_onnx(
ppo_runner.alg.actor_critic, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx"
)

# reset environment
obs, _ = env.get_observations()
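To sanity-check that normalization is now part of the exported graph, the ONNX file can be run directly, e.g. with `onnxruntime`. This is a sketch for the non-recurrent export; the input name is read from the model and the observation size is a placeholder:

```python
import numpy as np
import onnxruntime as ort

num_obs = 48  # placeholder: replace with the environment's observation size

session = ort.InferenceSession("exported/policy.onnx")
input_name = session.get_inputs()[0].name

# feed a raw (un-normalized) observation; the normalizer now runs inside the graph
obs = np.random.randn(1, num_obs).astype(np.float32)
actions = session.run(None, {input_name: obs})[0]
print(actions.shape)
```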
