
Commit

Reduce and broadcast learning rate across all workers/processes
Toni-SM committed Jun 21, 2024
1 parent 9f265af commit 391506d
Showing 1 changed file with 25 additions and 4 deletions.
skrl/resources/schedulers/torch/kl_adaptive.py (25 additions, 4 deletions)
@@ -5,6 +5,8 @@
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
 
+from skrl import config
+
 
 class KLAdaptiveLR(_LRScheduler):
     def __init__(self,
@@ -27,6 +29,10 @@ def __init__(self,
             This scheduler is only available for PPO at the moment.
             Applying it to other agents will not change the learning rate
 
+        .. note::
+
+            In distributed runs, the learning rate will be reduced and broadcast across all workers/processes
+
         Example::
 
             >>> scheduler = KLAdaptiveLR(optimizer, kl_threshold=0.01)
@@ -86,10 +92,25 @@ def step(self, kl: Optional[Union[torch.Tensor, float]] = None, epoch: Optional[
         :type epoch: int, optional
         """
         if kl is not None:
-            for group in self.optimizer.param_groups:
+            # reduce (collect from all workers/processes) KL divergence in distributed runs
+            if config.torch.is_distributed:
+                torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                kl /= config.torch.world_size
+
+            for i, group in enumerate(self.optimizer.param_groups):
                 # adjust the learning rate
+                lr = group['lr']
                 if kl > self.kl_threshold * self._kl_factor:
-                    group['lr'] = max(group['lr'] / self._lr_factor, self.min_lr)
+                    lr = max(lr / self._lr_factor, self.min_lr)
                 elif kl < self.kl_threshold / self._kl_factor:
-                    group['lr'] = min(group['lr'] * self._lr_factor, self.max_lr)
+                    lr = min(lr * self._lr_factor, self.max_lr)
+
+                # broadcast learning rate in distributed runs
+                if config.torch.is_distributed:
+                    lr_tensor = torch.tensor([lr], device=config.torch.device)
+                    torch.distributed.broadcast(lr_tensor, 0)
+                    lr = lr_tensor.item()
+
-        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+                # update value
+                group['lr'] = lr
+                self._last_lr[i] = lr
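The change follows a common two-step synchronization pattern: each worker's KL divergence is summed with all_reduce and divided by the world size so every rank decides from the same average, and the resulting learning rate is then broadcast from rank 0 so the per-worker optimizers cannot drift apart. Below is a minimal, self-contained sketch of that pattern outside of skrl; it assumes the script is launched with torchrun against a gloo (CPU) process group, and the function name, threshold, and factors are illustrative only.

import torch
import torch.distributed as dist


def synchronized_lr_update(local_kl: float, lr: float, kl_threshold: float = 0.01) -> float:
    device = torch.device("cpu")  # the gloo backend operates on CPU tensors

    # reduce: sum the per-worker KL values, then divide by the world size to average them
    kl = torch.tensor([local_kl], device=device)
    dist.all_reduce(kl, op=dist.ReduceOp.SUM)
    kl /= dist.get_world_size()

    # every rank now sees the same averaged KL and therefore takes the same branch
    if kl.item() > kl_threshold * 2:
        lr = lr / 1.5
    elif kl.item() < kl_threshold / 2:
        lr = lr * 1.5

    # broadcast: overwrite every rank's value with rank 0's so numerical drift cannot accumulate
    lr_tensor = torch.tensor([lr], device=device)
    dist.broadcast(lr_tensor, src=0)
    return lr_tensor.item()


if __name__ == "__main__":
    dist.init_process_group(backend="gloo")  # torchrun provides the rank/world-size environment variables
    rank = dist.get_rank()
    new_lr = synchronized_lr_update(local_kl=0.02 + 0.001 * rank, lr=1e-3)
    print(f"rank {rank}: lr -> {new_lr}")
    dist.destroy_process_group()

Launched with, for example, torchrun --nproc_per_node=2 sketch.py, every rank prints the same learning rate even though each one started from a different local KL value.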

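For context, the scheduler is driven from the training loop by passing the measured KL divergence to step() after each policy update; the adjustment of the optimizer's learning rate (and, with this commit, its synchronization across workers) happens entirely inside that call. A hypothetical single-process sketch with a placeholder model, loss, and KL value follows; only KLAdaptiveLR(optimizer, kl_threshold=...) and step(kl) come from the changed file.

import torch

from skrl.resources.schedulers.torch.kl_adaptive import KLAdaptiveLR

policy = torch.nn.Linear(8, 2)  # stand-in for a policy network
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
scheduler = KLAdaptiveLR(optimizer, kl_threshold=0.01)

for epoch in range(10):
    optimizer.zero_grad()
    loss = policy(torch.randn(32, 8)).pow(2).mean()  # placeholder surrogate loss
    loss.backward()
    optimizer.step()

    kl = torch.tensor(0.03)  # placeholder mean KL, above the threshold here
    scheduler.step(kl)       # lowers the learning rate while kl stays this high
    print(epoch, optimizer.param_groups[0]["lr"])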