Updates cartpole environment to new framework (#241)
# Description

This PR provides the Cartpole environment reimplemented with the new `RLTaskEnv`.

Replaces #177

## Type of change

- Bug fix (non-breaking change which fixes an issue)

## Screenshot

Training results with SB-3 (orange), RSL-RL (purple), and RL-Games (black):


![sb3](https://github.com/isaac-orbit/orbit/assets/12863862/c4b7a50e-099a-4391-8142-e2924e633429)

![rsl-rl](https://github.com/isaac-orbit/orbit/assets/12863862/aa223a46-d4a8-454a-b26d-b6359e173bbc)

![rlg](https://github.com/isaac-orbit/orbit/assets/12863862/2f20d837-40af-4fa2-ac58-7f520c4a3517)


## Checklist

- [x] I have run the [`pre-commit` checks](https://pre-commit.com/) with
`./orbit.sh --format`
- [ ] I have made corresponding changes to the documentation
- [x] My changes generate no new warnings
- [ ] I have added tests that prove my fix is effective or that my
feature works
- [ ] I have updated the changelog and the corresponding version in the
extension's `config/extension.toml` file
- [x] I have added my name to the `CONTRIBUTORS.md` or my name already
exists there
Mayankm96 authored Nov 10, 2023
1 parent 5d44141 commit 541dd06
Showing 18 changed files with 390 additions and 354 deletions.
@@ -169,14 +169,14 @@ def push_by_setting_velocity(
asset.write_root_velocity_to_sim(vel_w, env_ids=env_ids)


def reset_root_state(
def reset_root_state_uniform(
env: BaseEnv,
env_ids: torch.Tensor,
pose_range: dict[str, tuple[float, float]],
velocity_range: dict[str, tuple[float, float]],
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
):
"""Reset the asset root state to a random position and velocity within the given ranges.
"""Reset the asset root state to a random position and velocity uniformly within the given ranges.
This function randomizes the root position and velocity of the asset.
@@ -245,6 +245,32 @@ def reset_joints_by_scale(
asset.write_joint_state_to_sim(joint_pos, joint_vel, env_ids=env_ids)


def reset_joints_by_offset(
env: BaseEnv,
env_ids: torch.Tensor,
position_range: tuple[float, float],
velocity_range: tuple[float, float],
asset_cfg: SceneEntityCfg = SceneEntityCfg("robot"),
):
"""Reset the robot joints with offsets around the default position and velocity by the given ranges.
This function samples random values from the given ranges and biases the default joint positions and velocities
by these values. The biased values are then set into the physics simulation.
"""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]

# get default joint state
joint_pos = asset.data.default_joint_pos[env_ids].clone()
joint_vel = asset.data.default_joint_vel[env_ids].clone()
# bias these values randomly
joint_pos += sample_uniform(*position_range, joint_pos.shape, joint_pos.device)
joint_vel += sample_uniform(*velocity_range, joint_vel.shape, joint_vel.device)

# set into the physics simulation
asset.write_joint_state_to_sim(joint_pos, joint_vel, env_ids=env_ids)


def reset_scene_to_default(env: BaseEnv, env_ids: torch.Tensor):
"""Reset the scene to the default state specified in the scene configuration."""
# root states
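For readers new to these randomization terms: the new `reset_joints_by_offset` above samples uniform offsets and adds them to the default joint state. A small self-contained PyTorch sketch of that sample-and-bias pattern (shapes and ranges below are illustrative, not taken from the cartpole config):

```python
import torch

# illustrative default joint state: 4 environments, 2 joints (e.g. cart slider + pole hinge)
default_joint_pos = torch.zeros(4, 2)
default_joint_vel = torch.zeros(4, 2)

position_range = (-0.25, 0.25)
velocity_range = (-0.25, 0.25)

# sample uniform offsets with the same shape as the joint state and bias the defaults,
# mirroring the sample_uniform(*range, shape, device) calls in the function above
joint_pos = default_joint_pos + torch.empty_like(default_joint_pos).uniform_(*position_range)
joint_vel = default_joint_vel + torch.empty_like(default_joint_vel).uniform_(*velocity_range)

# each entry now lies within +/-0.25 of its default value
print(joint_pos, joint_vel)
```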
@@ -26,6 +26,11 @@
"""


def alive_bonus(env: RLTaskEnv) -> torch.Tensor:
"""Reward for being alive."""
return ~env.reset_buf * 1.0


def termination_penalty(env: RLTaskEnv) -> torch.Tensor:
"""Penalize terminated episodes that don't correspond to episodic timeouts."""
return env.reset_buf * (~env.termination_manager.time_outs)
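To make the two reward terms above concrete, a tiny PyTorch illustration (boolean buffers are assumed here purely for readability; the framework's actual buffer dtypes are not shown in this diff):

```python
import torch

# reset_buf marks environments that terminated this step; time_outs marks episodic timeouts
reset_buf = torch.tensor([False, True, True])
time_outs = torch.tensor([False, False, True])

alive_bonus = ~reset_buf * 1.0                # tensor([1., 0., 0.]) -> reward only while alive
termination_penalty = reset_buf * ~time_outs  # tensor([False, True, False]) -> penalize env 1 only
print(alive_bonus, termination_penalty * 1.0)
```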
@@ -86,14 +91,21 @@ def body_lin_acc_l2(env: RLTaskEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg(


def joint_torques_l2(env: RLTaskEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize torques applied on the articulation using L2-kernel."""
"""Penalize joint torques applied on the articulation using L2-kernel."""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
return torch.sum(torch.square(asset.data.applied_torque), dim=1)


def joint_vel_l1(env: RLTaskEnv, asset_cfg: SceneEntityCfg) -> torch.Tensor:
"""Penalize joint velocities on the articulation using an L1-kernel."""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
return torch.sum(torch.abs(asset.data.joint_vel[:, asset_cfg.joint_ids]), dim=1)


def joint_vel_l2(env: RLTaskEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Penalize joint velocities on the articulation."""
"""Penalize joint velocities on the articulation using L2-kernel."""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
return torch.sum(torch.square(asset.data.joint_vel), dim=1)
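The new `joint_vel_l1` differs from `joint_vel_l2` only in the kernel applied to the joint velocities; a quick numerical comparison:

```python
import torch

joint_vel = torch.tensor([[0.5, -2.0]])  # one environment, two joints

l1 = torch.sum(torch.abs(joint_vel), dim=1)     # tensor([2.5000]) -- linear penalty
l2 = torch.sum(torch.square(joint_vel), dim=1)  # tensor([4.2500]) -- quadratic penalty
print(l1, l2)
```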
@@ -82,14 +82,30 @@ def joint_pos_limit(env: RLTaskEnv, asset_cfg: SceneEntityCfg = SceneEntityCfg(
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
# compute any violations
out_of_upper_limits = torch.any(asset.data.joint_pos > asset.data.soft_joint_pos_limits[:, 0], dim=1)
out_of_lower_limits = torch.any(asset.data.joint_pos < asset.data.soft_joint_pos_limits[:, 1], dim=1)
out_of_upper_limits = torch.any(asset.data.joint_pos > asset.data.soft_joint_pos_limits[..., 1], dim=1)
out_of_lower_limits = torch.any(asset.data.joint_pos < asset.data.soft_joint_pos_limits[..., 0], dim=1)
return torch.logical_or(out_of_upper_limits, out_of_lower_limits)


def joint_velocity_limit(
env: RLTaskEnv, max_velocity, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
def joint_pos_manual_limit(
env: RLTaskEnv, bounds: tuple[float, float], asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")
) -> torch.Tensor:
"""Terminate when the asset's joint positions are outside of the configured bounds.
Note:
This function is similar to :func:`joint_pos_limit` but allows the user to specify the bounds manually.
"""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
if asset_cfg.joint_ids is None:
asset_cfg.joint_ids = slice(None)
# compute any violations
out_of_upper_limits = torch.any(asset.data.joint_pos[:, asset_cfg.joint_ids] > bounds[1], dim=1)
out_of_lower_limits = torch.any(asset.data.joint_pos[:, asset_cfg.joint_ids] < bounds[0], dim=1)
return torch.logical_or(out_of_upper_limits, out_of_lower_limits)


def joint_vel_limit(env: RLTaskEnv, max_velocity, asset_cfg: SceneEntityCfg = SceneEntityCfg("robot")) -> torch.Tensor:
"""Terminate when the asset's joint velocities are outside of the soft joint limits."""
# extract the used quantities (to enable type-hinting)
asset: Articulation = env.scene[asset_cfg.name]
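The `joint_pos_limit` fix above swaps the limit indices; it is consistent with a layout where the last dimension of `soft_joint_pos_limits` holds `(lower, upper)` per joint (that layout is an inference from the fix, not stated in the diff). A small PyTorch check of the corrected logic:

```python
import torch

# assumed layout: (num_envs, num_joints, 2) with [..., 0] = lower and [..., 1] = upper bound
soft_limits = torch.tensor([[[-1.0, 1.0], [-0.5, 0.5]]])  # 1 env, 2 joints
joint_pos = torch.tensor([[0.0, 0.7]])                     # second joint exceeds its upper bound

out_of_upper = torch.any(joint_pos > soft_limits[..., 1], dim=1)  # tensor([True])
out_of_lower = torch.any(joint_pos < soft_limits[..., 0], dim=1)  # tensor([False])
print(torch.logical_or(out_of_upper, out_of_lower))               # tensor([True]) -> terminate
```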
@@ -49,6 +49,6 @@
from .utils import import_packages

# The blacklist is used to prevent importing configs from sub-packages
_BLACKLIST_PKGS = ["locomotion.velocity.config.anymal_d", "classic", "manipulation", "utils"]
_BLACKLIST_PKGS = ["locomotion.velocity.config.anymal_d", "manipulation", "utils", "classic.ant", "classic.humanoid"]
# Import all configs in this package
import_packages(__name__, _BLACKLIST_PKGS)
@@ -10,9 +10,3 @@
Reference:
https://github.com/openai/gym/tree/master/gym/envs/mujoco
"""

from .ant import AntEnv
from .cartpole import CartpoleEnv
from .humanoid import HumanoidEnv

__all__ = ["CartpoleEnv", "AntEnv", "HumanoidEnv"]
@@ -7,21 +7,21 @@
Ant locomotion environment (similar to OpenAI Gym Ant-v2).
"""

import gymnasium as gym
# import gymnasium as gym

from . import agents
# from . import agents

##
# Register Gym environments.
##

gym.register(
id="Isaac-Ant-v0",
entry_point="omni.isaac.orbit.envs:RLTaskEnv",
kwargs={
"env_cfg_entry_point": f"{__name__}:ant_env_cfg.yaml",
"rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml",
"skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml",
"sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml",
},
)
# gym.register(
# id="Isaac-Ant-v0",
# entry_point="omni.isaac.orbit.envs:RLTaskEnv",
# kwargs={
# "env_cfg_entry_point": f"{__name__}:ant_env_cfg.yaml",
# "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml",
# "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml",
# "sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml",
# },
# )
@@ -10,6 +10,7 @@
import gymnasium as gym

from . import agents
from .cartpole_env_cfg import CartpoleEnvCfg

##
# Register Gym environments.
@@ -18,10 +19,11 @@
gym.register(
id="Isaac-Cartpole-v0",
entry_point="omni.isaac.orbit.envs:RLTaskEnv",
disable_env_checker=True,
kwargs={
"env_cfg_entry_point": f"{__name__}:ant_env_cfg.yaml",
"env_cfg_entry_point": CartpoleEnvCfg,
"rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml",
"rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:CARTPOLE_RSL_RL_PPO_CFG",
"rsl_rl_cfg_entry_point": agents.rsl_rl_ppo_cfg.CartpolePPORunnerCfg,
"skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml",
"sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml",
},
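With `env_cfg_entry_point` now pointing at the `CartpoleEnvCfg` class instead of a YAML file, the configuration can be instantiated and overridden directly in Python before the environment is created. A hedged sketch; the import path and the `cfg=` keyword mirror the usual manager-based workflow and are assumptions here, not shown in this diff:

```python
import gymnasium as gym

# assumed import path for the registered config class
from omni.isaac.orbit_tasks.classic.cartpole import CartpoleEnvCfg

env_cfg = CartpoleEnvCfg()
env_cfg.scene.num_envs = 64  # override a field on the config object before creating the env

# assumption: the RLTaskEnv entry point accepts the config instance via the cfg keyword
env = gym.make("Isaac-Cartpole-v0", cfg=env_cfg)
```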
@@ -0,0 +1,6 @@
# Copyright (c) 2022-2023, The ORBIT Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

from . import rsl_rl_ppo_cfg # noqa: F401, F403
@@ -50,9 +50,9 @@ params:
multi_gpu: False
ppo: True
mixed_precision: False
normalize_input: True
normalize_value: True
num_actors: -1
normalize_input: False
normalize_value: False
num_actors: -1 # configured from the script (based on num_envs)
reward_shaper:
scale_value: 1.0
normalize_advantage: False
@@ -62,7 +62,7 @@ params:
lr_schedule: adaptive
kl_threshold: 0.008
score_to_win: 20000
max_epochs: 450
max_epochs: 150
save_best_after: 50
save_frequency: 25
grad_norm: 1.0
@@ -74,5 +74,5 @@ params:
mini_epochs: 8
critic_coef: 4
clip_value: True
seq_len: 4
bounds_loss_coef: 0
seq_length: 4
bounds_loss_coef: 0.0001
@@ -3,40 +3,39 @@
#
# SPDX-License-Identifier: BSD-3-Clause

from omni.isaac.orbit.utils import configclass

from omni.isaac.orbit_tasks.utils.wrappers.rsl_rl import (
RslRlOnPolicyRunnerCfg,
RslRlPpoActorCriticCfg,
RslRlPpoAlgorithmCfg,
)

CARTPOLE_RSL_RL_PPO_CFG = RslRlOnPolicyRunnerCfg(
num_steps_per_env=16,
max_iterations=500,
save_interval=50,
experiment_name="cartpole",
run_name="",
resume=False,
load_run=-1,
load_checkpoint=-1,
empirical_normalization=False,
policy=RslRlPpoActorCriticCfg(

@configclass
class CartpolePPORunnerCfg(RslRlOnPolicyRunnerCfg):
num_steps_per_env = 16
max_iterations = 150
save_interval = 50
experiment_name = "cartpole"
empirical_normalization = False
policy = RslRlPpoActorCriticCfg(
init_noise_std=1.0,
actor_hidden_dims=[256, 128, 64],
critic_hidden_dims=[256, 128, 64],
actor_hidden_dims=[32, 32],
critic_hidden_dims=[32, 32],
activation="elu",
),
algorithm=RslRlPpoAlgorithmCfg(
)
algorithm = RslRlPpoAlgorithmCfg(
value_loss_coef=1.0,
use_clipped_value_loss=True,
clip_param=0.2,
entropy_coef=0.01,
num_learning_epochs=8,
entropy_coef=0.005,
num_learning_epochs=5,
num_mini_batches=4,
learning_rate=1.0e-3,
schedule="adaptive",
gamma=0.99,
lam=0.95,
desired_kl=0.01,
max_grad_norm=1.0,
),
)
)
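Moving the RSL-RL settings from a module-level `RslRlOnPolicyRunnerCfg(...)` instance to a `@configclass` subclass keeps the defaults in one place while still allowing per-run overrides. A brief sketch; the import path is an assumption based on the agents package added above:

```python
# assumed import path for the runner config defined above
from omni.isaac.orbit_tasks.classic.cartpole.agents.rsl_rl_ppo_cfg import CartpolePPORunnerCfg

agent_cfg = CartpolePPORunnerCfg()
agent_cfg.max_iterations = 300                 # override a default without editing the class
agent_cfg.policy.actor_hidden_dims = [64, 64]  # nested fields can be overridden the same way
print(agent_cfg.experiment_name)               # -> "cartpole"
```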
@@ -1,30 +1,20 @@
# Reference: https://github.com/DLR-RM/rl-baselines3-zoo/blob/master/hyperparams/ppo.yml#L32
seed: 42

# 512×500×16
n_timesteps: !!float 2e6
n_timesteps: !!float 1e6
policy: 'MlpPolicy'
n_steps: 16
batch_size: 8192
batch_size: 4096
gae_lambda: 0.95
gamma: 0.99
n_epochs: 8
ent_coef: 0.0
n_epochs: 20
ent_coef: 0.01
learning_rate: !!float 3e-4
clip_range: 0.2
clip_range: !!float 0.2
policy_kwargs: "dict(
log_std_init=-2,
ortho_init=False,
activation_fn=nn.ELU,
net_arch=[32, 32]
net_arch=[32, 32],
squash_output=False,
)"


target_kl: 0.008
vf_coef: 1.0
max_grad_norm: 1.0

# Uses VecNormalize class to normalize obs
normalize_input: True
# Uses VecNormalize class to normalize rew
normalize_value: True
clip_obs: 5

This file was deleted.
