diff --git a/source/extensions/omni.isaac.orbit_tasks/config/extension.toml b/source/extensions/omni.isaac.orbit_tasks/config/extension.toml
index e4ad9f7b1a..6a2916d257 100644
--- a/source/extensions/omni.isaac.orbit_tasks/config/extension.toml
+++ b/source/extensions/omni.isaac.orbit_tasks/config/extension.toml
@@ -1,7 +1,7 @@
 [package]
 
 # Note: Semantic Versioning is used: https://semver.org/
-version = "0.6.1"
+version = "0.6.2"
 
 # Description
 title = "ORBIT Environments"
diff --git a/source/extensions/omni.isaac.orbit_tasks/docs/CHANGELOG.rst b/source/extensions/omni.isaac.orbit_tasks/docs/CHANGELOG.rst
index 98b43e40be..a58f09fb80 100644
--- a/source/extensions/omni.isaac.orbit_tasks/docs/CHANGELOG.rst
+++ b/source/extensions/omni.isaac.orbit_tasks/docs/CHANGELOG.rst
@@ -1,6 +1,18 @@
 Changelog
 ---------
 
+0.6.2 (2024-05-31)
+~~~~~~~~~~~~~~~~~~
+
+Added
+^^^^^
+
+* Added exporting of the empirical normalization layer to ONNX and JIT when exporting the model with
+  :func:`omni.isaac.orbit_tasks.utils.wrappers.rsl_rl.export_policy_as_jit` and
+  :func:`omni.isaac.orbit_tasks.utils.wrappers.rsl_rl.export_policy_as_onnx`. Previously, the normalization
+  layer was not exported to the ONNX and JIT models, so the exported models did not work properly for inference.
+
+
 0.6.1 (2024-04-16)
 ~~~~~~~~~~~~~~~~~~
 
diff --git a/source/extensions/omni.isaac.orbit_tasks/omni/isaac/orbit_tasks/utils/wrappers/rsl_rl/exporter.py b/source/extensions/omni.isaac.orbit_tasks/omni/isaac/orbit_tasks/utils/wrappers/rsl_rl/exporter.py
index 7da7197493..8ef5045ff7 100644
--- a/source/extensions/omni.isaac.orbit_tasks/omni/isaac/orbit_tasks/utils/wrappers/rsl_rl/exporter.py
+++ b/source/extensions/omni.isaac.orbit_tasks/omni/isaac/orbit_tasks/utils/wrappers/rsl_rl/exporter.py
@@ -8,33 +8,37 @@
 import torch
 
 
-def export_policy_as_jit(actor_critic: object, path: str, filename="policy.pt"):
+def export_policy_as_jit(actor_critic: object, normalizer: object | None, path: str, filename="policy.pt"):
     """Export policy into a Torch JIT file.
 
     Args:
         actor_critic: The actor-critic torch module.
+        normalizer: The empirical normalizer module. If None, Identity is used.
         path: The path to the saving directory.
         filename: The name of exported JIT file. Defaults to "policy.pt".
 
     Reference:
         https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/utils/helpers.py#L180
     """
-    policy_exporter = _TorchPolicyExporter(actor_critic)
+    policy_exporter = _TorchPolicyExporter(actor_critic, normalizer)
     policy_exporter.export(path, filename)
 
 
-def export_policy_as_onnx(actor_critic: object, path: str, filename="policy.onnx", verbose=False):
+def export_policy_as_onnx(
+    actor_critic: object, normalizer: object | None, path: str, filename="policy.onnx", verbose=False
+):
     """Export policy into a Torch ONNX file.
 
     Args:
         actor_critic: The actor-critic torch module.
+        normalizer: The empirical normalizer module. If None, Identity is used.
         path: The path to the saving directory.
-        filename: The name of exported JIT file. Defaults to "policy.pt".
+        filename: The name of exported ONNX file. Defaults to "policy.onnx".
         verbose: Whether to print the model summary. Defaults to False.
""" if not os.path.exists(path): os.makedirs(path, exist_ok=True) - policy_exporter = _OnnxPolicyExporter(actor_critic, verbose) + policy_exporter = _OnnxPolicyExporter(actor_critic, normalizer, verbose) policy_exporter.export(path, filename) @@ -50,7 +54,7 @@ class _TorchPolicyExporter(torch.nn.Module): https://github.com/leggedrobotics/legged_gym/blob/master/legged_gym/utils/helpers.py#L193 """ - def __init__(self, actor_critic): + def __init__(self, actor_critic, normalizer=None): super().__init__() self.actor = copy.deepcopy(actor_critic.actor) self.is_recurrent = actor_critic.is_recurrent @@ -61,8 +65,14 @@ def __init__(self, actor_critic): self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)) self.forward = self.forward_lstm self.reset = self.reset_memory + # copy normalizer if exists + if normalizer: + self.normalizer = copy.deepcopy(normalizer) + else: + self.normalizer = torch.nn.Identity() def forward_lstm(self, x): + x = self.normalizer(x) x, (h, c) = self.rnn(x.unsqueeze(0), (self.hidden_state, self.cell_state)) self.hidden_state[:] = h self.cell_state[:] = c @@ -70,7 +80,7 @@ def forward_lstm(self, x): return self.actor(x) def forward(self, x): - return self.actor(x) + return self.actor(self.normalizer(x)) @torch.jit.export def reset(self): @@ -91,7 +101,7 @@ def export(self, path, filename): class _OnnxPolicyExporter(torch.nn.Module): """Exporter of actor-critic into ONNX file.""" - def __init__(self, actor_critic, verbose=False): + def __init__(self, actor_critic, normalizer=None, verbose=False): super().__init__() self.verbose = verbose self.actor = copy.deepcopy(actor_critic.actor) @@ -100,14 +110,20 @@ def __init__(self, actor_critic, verbose=False): self.rnn = copy.deepcopy(actor_critic.memory_a.rnn) self.rnn.cpu() self.forward = self.forward_lstm + # copy normalizer if exists + if normalizer: + self.normalizer = copy.deepcopy(normalizer) + else: + self.normalizer = torch.nn.Identity() def forward_lstm(self, x_in, h_in, c_in): + x_in = self.normalizer(x_in) x, (h, c) = self.rnn(x_in.unsqueeze(0), (h_in, c_in)) x = x.squeeze(0) return self.actor(x), h, c def forward(self, x): - return self.actor(x) + return self.actor(self.normalizer(x)) def export(self, path, filename): self.to("cpu") diff --git a/source/standalone/workflows/rsl_rl/play.py b/source/standalone/workflows/rsl_rl/play.py index f03df2a703..20384b5d41 100644 --- a/source/standalone/workflows/rsl_rl/play.py +++ b/source/standalone/workflows/rsl_rl/play.py @@ -46,6 +46,7 @@ from omni.isaac.orbit_tasks.utils.wrappers.rsl_rl import ( RslRlOnPolicyRunnerCfg, RslRlVecEnvWrapper, + export_policy_as_jit, export_policy_as_onnx, ) @@ -78,9 +79,14 @@ def main(): # obtain the trained policy for inference policy = ppo_runner.get_inference_policy(device=env.unwrapped.device) - # export policy to onnx + # export policy to onnx/jit export_model_dir = os.path.join(os.path.dirname(resume_path), "exported") - export_policy_as_onnx(ppo_runner.alg.actor_critic, export_model_dir, filename="policy.onnx") + export_policy_as_jit( + ppo_runner.alg.actor_critic, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.pt" + ) + export_policy_as_onnx( + ppo_runner.alg.actor_critic, ppo_runner.obs_normalizer, path=export_model_dir, filename="policy.onnx" + ) # reset environment obs, _ = env.get_observations()