Fix multi-agent learning rate scheduler in jax
Toni-SM committed Jan 15, 2025
1 parent 16224b0 commit 46526b5
Showing 6 changed files with 54 additions and 54 deletions.
6 changes: 3 additions & 3 deletions skrl/multi_agents/jax/base.py
@@ -66,7 +66,7 @@ def __init__(
pass

self.tracking_data = collections.defaultdict(list)
- self.write_interval = self.cfg.get("experiment", {}).get("write_interval", 1000)
+ self.write_interval = self.cfg.get("experiment", {}).get("write_interval", "auto")

self._track_rewards = collections.deque(maxlen=100)
self._track_timesteps = collections.deque(maxlen=100)
@@ -77,9 +77,9 @@ def __init__(

# checkpoint
self.checkpoint_modules = {uid: {} for uid in self.possible_agents}
- self.checkpoint_interval = self.cfg.get("experiment", {}).get("checkpoint_interval", 1000)
+ self.checkpoint_interval = self.cfg.get("experiment", {}).get("checkpoint_interval", "auto")
self.checkpoint_store_separately = self.cfg.get("experiment", {}).get("store_separately", False)
- self.checkpoint_best_modules = {"timestep": 0, "reward": -(2**31), "saved": True, "modules": {}}
+ self.checkpoint_best_modules = {"timestep": 0, "reward": -(2**31), "saved": False, "modules": {}}

# experiment directory
directory = self.cfg.get("experiment", {}).get("directory", "")
44 changes: 22 additions & 22 deletions skrl/multi_agents/jax/ippo/ippo.py
@@ -26,7 +26,7 @@
"lambda": 0.95, # TD(lambda) coefficient (lam) for computing returns and advantages

"learning_rate": 1e-3, # learning rate
- "learning_rate_scheduler": None, # learning rate scheduler class (see torch.optim.lr_scheduler)
+ "learning_rate_scheduler": None, # learning rate scheduler function (see optax.schedules)
"learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

"state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors)
@@ -53,9 +53,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)

- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately

"wandb": False, # whether to use Weights & Biases
@@ -286,7 +286,6 @@ def __init__(

self._learning_rate = self._as_dict(self.cfg["learning_rate"])
self._learning_rate_scheduler = self._as_dict(self.cfg["learning_rate_scheduler"])
- self._learning_rate_scheduler_kwargs = self._as_dict(self.cfg["learning_rate_scheduler_kwargs"])

self._state_preprocessor = self._as_dict(self.cfg["state_preprocessor"])
self._state_preprocessor_kwargs = self._as_dict(self.cfg["state_preprocessor_kwargs"])
@@ -313,24 +312,23 @@ def __init__(

if policy is not None and value is not None:
# scheduler
- scale = True
self.schedulers[uid] = None
- if self._learning_rate_scheduler[uid] is not None:
- if self._learning_rate_scheduler[uid] == KLAdaptiveLR:
- scale = False
- self.schedulers[uid] = self._learning_rate_scheduler[uid](
- self._learning_rate[uid], **self._learning_rate_scheduler_kwargs[uid]
- )
- else:
- self._learning_rate[uid] = self._learning_rate_scheduler[uid](
- self._learning_rate[uid], **self._learning_rate_scheduler_kwargs[uid]
- )
+ if self._learning_rate_scheduler[uid]:
+ self.schedulers[uid] = self._learning_rate_scheduler[uid](
+ **self._as_dict(self.cfg["learning_rate_scheduler_kwargs"])[uid]
+ )
# optimizer
self.policy_optimizer[uid] = Adam(
- model=policy, lr=self._learning_rate[uid], grad_norm_clip=self._grad_norm_clip[uid], scale=scale
+ model=policy,
+ lr=self._learning_rate[uid],
+ grad_norm_clip=self._grad_norm_clip[uid],
+ scale=not self._learning_rate_scheduler[uid],
)
self.value_optimizer[uid] = Adam(
- model=value, lr=self._learning_rate[uid], grad_norm_clip=self._grad_norm_clip[uid], scale=scale
+ model=value,
+ lr=self._learning_rate[uid],
+ grad_norm_clip=self._grad_norm_clip[uid],
+ scale=not self._learning_rate_scheduler[uid],
)

self.checkpoint_modules[uid]["policy_optimizer"] = self.policy_optimizer[uid]
@@ -606,7 +604,7 @@ def _update(self, timestep: int, timesteps: int) -> None:
if config.jax.is_distributed:
grad = policy.reduce_parameters(grad)
self.policy_optimizer[uid] = self.policy_optimizer[uid].step(
- grad, policy, self.schedulers[uid]._lr if self.schedulers[uid] else None
+ grad, policy, self._learning_rate[uid] if self._learning_rate_scheduler[uid] else None
)

# compute value loss
@@ -625,7 +623,7 @@ def _update(self, timestep: int, timesteps: int) -> None:
if config.jax.is_distributed:
grad = value.reduce_parameters(grad)
self.value_optimizer[uid] = self.value_optimizer[uid].step(
- grad, value, self.schedulers[uid]._lr if self.schedulers[uid] else None
+ grad, value, self._learning_rate[uid] if self._learning_rate_scheduler[uid] else None
)

# update cumulative losses
@@ -636,13 +634,15 @@ def _update(self, timestep: int, timesteps: int) -> None:

# update learning rate
if self._learning_rate_scheduler[uid]:
- if isinstance(self.schedulers[uid], KLAdaptiveLR):
+ if self._learning_rate_scheduler[uid] is KLAdaptiveLR:
kl = np.mean(kl_divergences)
# reduce (collect from all workers/processes) KL in distributed runs
if config.jax.is_distributed:
kl = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(kl.reshape(1)).item()
kl /= config.jax.world_size
- self.schedulers[uid].step(kl)
+ self._learning_rate[uid] = self.schedulers[uid](timestep, self._learning_rate[uid], kl)
+ else:
+ self._learning_rate[uid] *= self.schedulers[uid](timestep)

# record data
self.track_data(
@@ -662,4 +662,4 @@ def _update(self, timestep: int, timesteps: int) -> None:
self.track_data(f"Policy / Standard deviation ({uid})", stddev.mean().item())

if self._learning_rate_scheduler[uid]:
- self.track_data(f"Learning / Learning rate ({uid})", self.schedulers[uid]._lr)
+ self.track_data(f"Learning / Learning rate ({uid})", self._learning_rate[uid])
44 changes: 22 additions & 22 deletions skrl/multi_agents/jax/mappo/mappo.py
@@ -26,7 +26,7 @@
"lambda": 0.95, # TD(lambda) coefficient (lam) for computing returns and advantages

"learning_rate": 1e-3, # learning rate
- "learning_rate_scheduler": None, # learning rate scheduler class (see torch.optim.lr_scheduler)
+ "learning_rate_scheduler": None, # learning rate scheduler function (see optax.schedules)
"learning_rate_scheduler_kwargs": {}, # learning rate scheduler's kwargs (e.g. {"step_size": 1e-3})

"state_preprocessor": None, # state preprocessor class (see skrl.resources.preprocessors)
@@ -55,9 +55,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)

- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately

"wandb": False, # whether to use Weights & Biases
@@ -293,7 +293,6 @@ def __init__(

self._learning_rate = self._as_dict(self.cfg["learning_rate"])
self._learning_rate_scheduler = self._as_dict(self.cfg["learning_rate_scheduler"])
- self._learning_rate_scheduler_kwargs = self._as_dict(self.cfg["learning_rate_scheduler_kwargs"])

self._state_preprocessor = self._as_dict(self.cfg["state_preprocessor"])
self._state_preprocessor_kwargs = self._as_dict(self.cfg["state_preprocessor_kwargs"])
@@ -322,24 +321,23 @@ def __init__(

if policy is not None and value is not None:
# scheduler
- scale = True
self.schedulers[uid] = None
- if self._learning_rate_scheduler[uid] is not None:
- if self._learning_rate_scheduler[uid] == KLAdaptiveLR:
- scale = False
- self.schedulers[uid] = self._learning_rate_scheduler[uid](
- self._learning_rate[uid], **self._learning_rate_scheduler_kwargs[uid]
- )
- else:
- self._learning_rate[uid] = self._learning_rate_scheduler[uid](
- self._learning_rate[uid], **self._learning_rate_scheduler_kwargs[uid]
- )
+ if self._learning_rate_scheduler[uid]:
+ self.schedulers[uid] = self._learning_rate_scheduler[uid](
+ **self._as_dict(self.cfg["learning_rate_scheduler_kwargs"])[uid]
+ )
# optimizer
self.policy_optimizer[uid] = Adam(
- model=policy, lr=self._learning_rate[uid], grad_norm_clip=self._grad_norm_clip[uid], scale=scale
+ model=policy,
+ lr=self._learning_rate[uid],
+ grad_norm_clip=self._grad_norm_clip[uid],
+ scale=not self._learning_rate_scheduler[uid],
)
self.value_optimizer[uid] = Adam(
- model=value, lr=self._learning_rate[uid], grad_norm_clip=self._grad_norm_clip[uid], scale=scale
+ model=value,
+ lr=self._learning_rate[uid],
+ grad_norm_clip=self._grad_norm_clip[uid],
+ scale=not self._learning_rate_scheduler[uid],
)

self.checkpoint_modules[uid]["policy_optimizer"] = self.policy_optimizer[uid]
@@ -638,7 +636,7 @@ def _update(self, timestep: int, timesteps: int) -> None:
if config.jax.is_distributed:
grad = policy.reduce_parameters(grad)
self.policy_optimizer[uid] = self.policy_optimizer[uid].step(
- grad, policy, self.schedulers[uid]._lr if self.schedulers[uid] else None
+ grad, policy, self._learning_rate[uid] if self._learning_rate_scheduler[uid] else None
)

# compute value loss
@@ -657,7 +655,7 @@ def _update(self, timestep: int, timesteps: int) -> None:
if config.jax.is_distributed:
grad = value.reduce_parameters(grad)
self.value_optimizer[uid] = self.value_optimizer[uid].step(
- grad, value, self.schedulers[uid]._lr if self.schedulers[uid] else None
+ grad, value, self._learning_rate[uid] if self._learning_rate_scheduler[uid] else None
)

# update cumulative losses
@@ -668,13 +666,15 @@ def _update(self, timestep: int, timesteps: int) -> None:

# update learning rate
if self._learning_rate_scheduler[uid]:
- if isinstance(self.schedulers[uid], KLAdaptiveLR):
+ if self._learning_rate_scheduler[uid] is KLAdaptiveLR:
kl = np.mean(kl_divergences)
# reduce (collect from all workers/processes) KL in distributed runs
if config.jax.is_distributed:
kl = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(kl.reshape(1)).item()
kl /= config.jax.world_size
- self.schedulers[uid].step(kl)
+ self._learning_rate[uid] = self.schedulers[uid](timestep, self._learning_rate[uid], kl)
+ else:
+ self._learning_rate[uid] *= self.schedulers[uid](timestep)

# record data
self.track_data(
@@ -694,4 +694,4 @@ def _update(self, timestep: int, timesteps: int) -> None:
self.track_data(f"Policy / Standard deviation ({uid})", stddev.mean().item())

if self._learning_rate_scheduler[uid]:
- self.track_data(f"Learning / Learning rate ({uid})", self.schedulers[uid]._lr)
+ self.track_data(f"Learning / Learning rate ({uid})", self._learning_rate[uid])
6 changes: 3 additions & 3 deletions skrl/multi_agents/torch/base.py
@@ -64,7 +64,7 @@ def __init__(
model.to(model.device)

self.tracking_data = collections.defaultdict(list)
- self.write_interval = self.cfg.get("experiment", {}).get("write_interval", 1000)
+ self.write_interval = self.cfg.get("experiment", {}).get("write_interval", "auto")

self._track_rewards = collections.deque(maxlen=100)
self._track_timesteps = collections.deque(maxlen=100)
@@ -75,9 +75,9 @@ def __init__(

# checkpoint
self.checkpoint_modules = {uid: {} for uid in self.possible_agents}
- self.checkpoint_interval = self.cfg.get("experiment", {}).get("checkpoint_interval", 1000)
+ self.checkpoint_interval = self.cfg.get("experiment", {}).get("checkpoint_interval", "auto")
self.checkpoint_store_separately = self.cfg.get("experiment", {}).get("store_separately", False)
- self.checkpoint_best_modules = {"timestep": 0, "reward": -(2**31), "saved": True, "modules": {}}
+ self.checkpoint_best_modules = {"timestep": 0, "reward": -(2**31), "saved": False, "modules": {}}

# experiment directory
directory = self.cfg.get("experiment", {}).get("directory", "")
4 changes: 2 additions & 2 deletions skrl/multi_agents/torch/ippo/ippo.py
@@ -53,9 +53,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)

- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately

"wandb": False, # whether to use Weights & Biases
4 changes: 2 additions & 2 deletions skrl/multi_agents/torch/mappo/mappo.py
@@ -55,9 +55,9 @@
"experiment": {
"directory": "", # experiment's parent directory
"experiment_name": "", # experiment name
- "write_interval": 250, # TensorBoard writing interval (timesteps)
+ "write_interval": "auto", # TensorBoard writing interval (timesteps)

- "checkpoint_interval": 1000, # interval for checkpoints (timesteps)
+ "checkpoint_interval": "auto", # interval for checkpoints (timesteps)
"store_separately": False, # whether to store checkpoints separately

"wandb": False, # whether to use Weights & Biases
