
Commit

Update hyperparameters for the Isaac Orbit Isaac-Lift-Franka-v0 environment
Toni-SM committed Aug 11, 2023
1 parent 4b5b98f commit 0000bdf
Showing 2 changed files with 28 additions and 26 deletions.
27 changes: 14 additions & 13 deletions docs/source/examples/isaacorbit/jax_lift_franka_ppo.py
@@ -24,7 +24,7 @@
from skrl.memories.jax import RandomMemory
from skrl.models.jax import DeterministicMixin, GaussianMixin, Model
from skrl.resources.preprocessors.jax import RunningStandardScaler
-from skrl.resources.schedulers.jax import KLAdaptiveRL
+from skrl.resources.schedulers.jax import KLAdaptiveLR
from skrl.trainers.jax import SequentialTrainer
from skrl.utils import set_seed

@@ -52,7 +52,7 @@ def __call__(self, inputs, role):
x = nn.elu(nn.Dense(128)(x))
x = nn.elu(nn.Dense(64)(x))
x = nn.Dense(self.num_actions)(x)
-log_std = self.param("log_std", lambda _: jnp.zeros(self.num_actions))
+log_std = self.param("log_std", lambda _: jnp.ones(self.num_actions))
return x, log_std, {}

class Value(DeterministicMixin, Model):
@@ -80,7 +80,7 @@ def __call__(self, inputs, role):


# instantiate a memory as rollout buffer (any memory can be used for this)
-memory = RandomMemory(memory_size=32, num_envs=env.num_envs, device=device)
+memory = RandomMemory(memory_size=96, num_envs=env.num_envs, device=device)


# instantiate the agent's models (function approximators).
@@ -98,31 +98,32 @@ def __call__(self, inputs, role):
# configure and instantiate the agent (visit its documentation to see all the options)
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html#configuration-and-hyperparameters
cfg = PPO_DEFAULT_CONFIG.copy()
cfg["rollouts"] = 32 # memory_size
cfg["rollouts"] = 96 # memory_size
cfg["learning_epochs"] = 5
cfg["mini_batches"] = 16 # 32 * 1024 / 2048
cfg["mini_batches"] = 4 # 96 * 4096 / 98304
cfg["discount_factor"] = 0.99
cfg["lambda"] = 0.95
cfg["learning_rate"] = 5e-4
cfg["learning_rate_scheduler"] = KLAdaptiveRL
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008}
cfg["learning_rate"] = 1e-3
cfg["learning_rate_scheduler"] = KLAdaptiveLR
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01, "min_lr": 1e-5}
cfg["random_timesteps"] = 0
cfg["learning_starts"] = 0
cfg["grad_norm_clip"] = 1.0
cfg["ratio_clip"] = 0.2
cfg["value_clip"] = 0.2
cfg["clip_predicted_values"] = True
cfg["entropy_loss_scale"] = 0.0
cfg["value_loss_scale"] = 2.0
cfg["entropy_loss_scale"] = 0.01
cfg["value_loss_scale"] = 1.0
cfg["kl_threshold"] = 0
cfg["rewards_shaper"] = None
cfg["time_limit_bootstrap"] = True
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
cfg["value_preprocessor_kwargs"] = {"size": 1, "device": device}
# logging to TensorBoard and write checkpoints (in timesteps)
cfg["experiment"]["write_interval"] = 120
cfg["experiment"]["checkpoint_interval"] = 1200
cfg["experiment"]["write_interval"] = 800
cfg["experiment"]["checkpoint_interval"] = 8000
cfg["experiment"]["directory"] = "runs/jax/Isaac-Lift-Franka-v0"

agent = PPO(models=models,
@@ -134,7 +135,7 @@ def __call__(self, inputs, role):


# configure and instantiate the RL trainer
cfg_trainer = {"timesteps": 24000, "headless": True}
cfg_trainer = {"timesteps": 67200, "headless": True}
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent)

# start training
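A quick sanity check of the new rollout and batch numbers (a standalone sketch, not part of the example scripts; num_envs = 4096 is assumed from the updated mini-batch comment and is not set here):

# Sanity check of the updated rollout/batch arithmetic (illustrative only;
# num_envs = 4096 is an assumption taken from the mini-batch comment in the diff).
num_envs = 4096
rollouts = 96        # cfg["rollouts"] == memory_size
mini_batches = 4     # cfg["mini_batches"]
timesteps = 67200    # cfg_trainer["timesteps"]

samples_per_update = rollouts * num_envs               # 393216 transitions per PPO update
mini_batch_size = samples_per_update // mini_batches   # 98304, matching the diff comment
ppo_updates = timesteps // rollouts                    # 700 learning iterations over training

print(samples_per_update, mini_batch_size, ppo_updates)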
27 changes: 14 additions & 13 deletions docs/source/examples/isaacorbit/torch_lift_franka_ppo.py
@@ -8,7 +8,7 @@
from skrl.memories.torch import RandomMemory
from skrl.models.torch import DeterministicMixin, GaussianMixin, Model
from skrl.resources.preprocessors.torch import RunningStandardScaler
-from skrl.resources.schedulers.torch import KLAdaptiveRL
+from skrl.resources.schedulers.torch import KLAdaptiveLR
from skrl.trainers.torch import SequentialTrainer
from skrl.utils import set_seed

@@ -33,7 +33,7 @@ def __init__(self, observation_space, action_space, device, clip_actions=False,
nn.ELU())

self.mean_layer = nn.Linear(64, self.num_actions)
-self.log_std_parameter = nn.Parameter(torch.zeros(self.num_actions))
+self.log_std_parameter = nn.Parameter(torch.ones(self.num_actions))

self.value_layer = nn.Linear(64, 1)

@@ -58,7 +58,7 @@ def compute(self, inputs, role):


# instantiate a memory as rollout buffer (any memory can be used for this)
-memory = RandomMemory(memory_size=32, num_envs=env.num_envs, device=device)
+memory = RandomMemory(memory_size=96, num_envs=env.num_envs, device=device)


# instantiate the agent's models (function approximators).
@@ -72,31 +72,32 @@ def compute(self, inputs, role):
# configure and instantiate the agent (visit its documentation to see all the options)
# https://skrl.readthedocs.io/en/latest/api/agents/ppo.html#configuration-and-hyperparameters
cfg = PPO_DEFAULT_CONFIG.copy()
cfg["rollouts"] = 32 # memory_size
cfg["rollouts"] = 96 # memory_size
cfg["learning_epochs"] = 5
cfg["mini_batches"] = 16 # 32 * 1024 / 2048
cfg["mini_batches"] = 4 # 96 * 4096 / 98304
cfg["discount_factor"] = 0.99
cfg["lambda"] = 0.95
cfg["learning_rate"] = 5e-4
cfg["learning_rate_scheduler"] = KLAdaptiveRL
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.008}
cfg["learning_rate"] = 1e-3
cfg["learning_rate_scheduler"] = KLAdaptiveLR
cfg["learning_rate_scheduler_kwargs"] = {"kl_threshold": 0.01, "min_lr": 1e-5}
cfg["random_timesteps"] = 0
cfg["learning_starts"] = 0
cfg["grad_norm_clip"] = 1.0
cfg["ratio_clip"] = 0.2
cfg["value_clip"] = 0.2
cfg["clip_predicted_values"] = True
cfg["entropy_loss_scale"] = 0.0
cfg["value_loss_scale"] = 2.0
cfg["entropy_loss_scale"] = 0.01
cfg["value_loss_scale"] = 1.0
cfg["kl_threshold"] = 0
cfg["rewards_shaper"] = None
cfg["time_limit_bootstrap"] = True
cfg["state_preprocessor"] = RunningStandardScaler
cfg["state_preprocessor_kwargs"] = {"size": env.observation_space, "device": device}
cfg["value_preprocessor"] = RunningStandardScaler
cfg["value_preprocessor_kwargs"] = {"size": 1, "device": device}
# logging to TensorBoard and write checkpoints (in timesteps)
cfg["experiment"]["write_interval"] = 120
cfg["experiment"]["checkpoint_interval"] = 1200
cfg["experiment"]["write_interval"] = 800
cfg["experiment"]["checkpoint_interval"] = 8000
cfg["experiment"]["directory"] = "runs/torch/Isaac-Lift-Franka-v0"

agent = PPO(models=models,
@@ -108,7 +109,7 @@ def compute(self, inputs, role):


# configure and instantiate the RL trainer
cfg_trainer = {"timesteps": 24000, "headless": True}
cfg_trainer = {"timesteps": 67200, "headless": True}
trainer = SequentialTrainer(cfg=cfg_trainer, env=env, agents=agent)

# start training
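The KLAdaptiveRL to KLAdaptiveLR change tracks the renamed KL-adaptive learning-rate scheduler class in skrl. The sketch below only illustrates the general adaptive-KL rule that such schedulers follow (an rl_games-style heuristic), plugged with the new kl_threshold=0.01 and min_lr=1e-5 values; it is not skrl's actual KLAdaptiveLR implementation, and factor and max_lr are illustrative choices.

# Illustration of a KL-adaptive learning-rate rule (rl_games-style heuristic);
# NOT skrl's exact KLAdaptiveLR code, only the general idea behind it.
def kl_adaptive_lr(lr, kl, kl_threshold=0.01, min_lr=1e-5, max_lr=1e-2, factor=1.5):
    # shrink the learning rate when the policy moves too fast (KL well above threshold)
    if kl > kl_threshold * 2.0:
        lr = max(lr / factor, min_lr)
    # grow it when updates are too conservative (KL well below threshold)
    elif kl < kl_threshold * 0.5:
        lr = min(lr * factor, max_lr)
    return lr

# e.g. starting from the new 1e-3 learning rate with an observed KL of 0.03
lr = kl_adaptive_lr(1e-3, kl=0.03)   # -> 1e-3 / 1.5 ~= 6.7e-4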
