From 391506dae17f0b787b2a42738fdae8d94a9fbbdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Antonio=20Serrano=20Mu=C3=B1oz?= Date: Thu, 20 Jun 2024 22:43:50 -0400 Subject: [PATCH] Reduce and broadcast learning rate across all workers/processes --- .../resources/schedulers/torch/kl_adaptive.py | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/skrl/resources/schedulers/torch/kl_adaptive.py b/skrl/resources/schedulers/torch/kl_adaptive.py index 0a32eed3..4d7d5771 100644 --- a/skrl/resources/schedulers/torch/kl_adaptive.py +++ b/skrl/resources/schedulers/torch/kl_adaptive.py @@ -5,6 +5,8 @@ import torch from torch.optim.lr_scheduler import _LRScheduler +from skrl import config + class KLAdaptiveLR(_LRScheduler): def __init__(self, @@ -27,6 +29,10 @@ def __init__(self, This scheduler is only available for PPO at the moment. Applying it to other agents will not change the learning rate + .. note:: + + In distributed runs, the learning rate will be reduced and broadcasted across all workers/processes + Example:: >>> scheduler = KLAdaptiveLR(optimizer, kl_threshold=0.01) @@ -86,10 +92,25 @@ def step(self, kl: Optional[Union[torch.Tensor, float]] = None, epoch: Optional[ :type epoch: int, optional """ if kl is not None: - for group in self.optimizer.param_groups: + # reduce (collect from all workers/processes) learning rate in distributed runs + if config.torch.is_distributed: + torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM) + kl /= config.torch.world_size + + for i, group in enumerate(self.optimizer.param_groups): + # adjust the learning rate + lr = group['lr'] if kl > self.kl_threshold * self._kl_factor: - group['lr'] = max(group['lr'] / self._lr_factor, self.min_lr) + lr = max(lr / self._lr_factor, self.min_lr) elif kl < self.kl_threshold / self._kl_factor: - group['lr'] = min(group['lr'] * self._lr_factor, self.max_lr) + lr = min(lr * self._lr_factor, self.max_lr) + + # broadcast learning rate in distributed runs + if config.torch.is_distributed: + lr_tensor = torch.tensor([lr], device=config.torch.device) + torch.distributed.broadcast(lr_tensor, 0) + lr = lr_tensor.item() - self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + # update value + group['lr'] = lr + self._last_lr[i] = lr