From 3f1c73447a803a091538396d246976143cdb4603 Mon Sep 17 00:00:00 2001
From: Toni-SM
Date: Wed, 10 Jul 2024 21:59:26 -0400
Subject: [PATCH] Move the KL reduction implementation to each agent using it in distributed runs (#169)

* Remove the all-reduction and broadcasting implementation in distributed runs

* All-reduce the KL on each agent implementation

* Increase MINOR version and update CHANGELOG
---
 CHANGELOG.md                                  |  4 +++
 docs/source/conf.py                           |  4 +--
 pyproject.toml                                |  2 +-
 skrl/agents/torch/a2c/a2c.py                  |  7 ++++-
 skrl/agents/torch/a2c/a2c_rnn.py              |  7 ++++-
 skrl/agents/torch/ppo/ppo.py                  |  7 ++++-
 skrl/agents/torch/ppo/ppo_rnn.py              |  7 ++++-
 skrl/agents/torch/rpo/rpo.py                  |  7 ++++-
 skrl/agents/torch/rpo/rpo_rnn.py              |  7 ++++-
 .../resources/schedulers/torch/kl_adaptive.py | 29 +++----------------
 10 files changed, 47 insertions(+), 34 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index cb058434..1c151019 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,10 @@
 
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
+## [1.3.0] - Unreleased
+### Changed
+- Move the KL reduction in distributed runs from the `KLAdaptiveLR` class to each agent using it
+
 ## [1.2.0] - 2024-06-23
 ### Added
 - Define the `environment_info` trainer config to log environment info (PyTorch implementation)
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 42df4615..0e8aa659 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -10,13 +10,13 @@
 
 # project information
 project = "skrl"
-copyright = "2021, Toni-SM"
+copyright = "2021-2024, Toni-SM"
 author = "Toni-SM"
 
 if skrl.__version__ != "unknown":
     release = version = skrl.__version__
 else:
-    release = version = "1.2.0"
+    release = version = "1.3.0"
 
 master_doc = "index"
 
diff --git a/pyproject.toml b/pyproject.toml
index cceeb007..9fe1cf9a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "skrl"
-version = "1.2.0"
+version = "1.3.0"
 description = "Modular and flexible library for reinforcement learning on PyTorch and JAX"
 readme = "README.md"
 requires-python = ">=3.6"
diff --git a/skrl/agents/torch/a2c/a2c.py b/skrl/agents/torch/a2c/a2c.py
index fd495516..d5b80687 100644
--- a/skrl/agents/torch/a2c/a2c.py
+++ b/skrl/agents/torch/a2c/a2c.py
@@ -420,7 +420,12 @@ def compute_gae(rewards: torch.Tensor,
         # update learning rate
         if self._learning_rate_scheduler:
             if isinstance(self.scheduler, KLAdaptiveLR):
-                self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
+                kl = torch.tensor(kl_divergences, device=self.device).mean()
+                # reduce (collect from all workers/processes) KL in distributed runs
+                if config.torch.is_distributed:
+                    torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                    kl /= config.torch.world_size
+                self.scheduler.step(kl.item())
             else:
                 self.scheduler.step()
 
diff --git a/skrl/agents/torch/a2c/a2c_rnn.py b/skrl/agents/torch/a2c/a2c_rnn.py
index 8cecc41e..b21f39e7 100644
--- a/skrl/agents/torch/a2c/a2c_rnn.py
+++ b/skrl/agents/torch/a2c/a2c_rnn.py
@@ -491,7 +491,12 @@ def compute_gae(rewards: torch.Tensor,
         # update learning rate
         if self._learning_rate_scheduler:
             if isinstance(self.scheduler, KLAdaptiveLR):
-                self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
+                kl = torch.tensor(kl_divergences, device=self.device).mean()
+                # reduce (collect from all workers/processes) KL in distributed runs
+                if config.torch.is_distributed:
+                    torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                    kl /= config.torch.world_size
+                self.scheduler.step(kl.item())
             else:
                 self.scheduler.step()
 
diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py
index cb68ac2a..3143505f 100644
--- a/skrl/agents/torch/ppo/ppo.py
+++ b/skrl/agents/torch/ppo/ppo.py
@@ -447,7 +447,12 @@ def compute_gae(rewards: torch.Tensor,
             # update learning rate
             if self._learning_rate_scheduler:
                 if isinstance(self.scheduler, KLAdaptiveLR):
-                    self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
+                    kl = torch.tensor(kl_divergences, device=self.device).mean()
+                    # reduce (collect from all workers/processes) KL in distributed runs
+                    if config.torch.is_distributed:
+                        torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                        kl /= config.torch.world_size
+                    self.scheduler.step(kl.item())
                 else:
                     self.scheduler.step()
 
diff --git a/skrl/agents/torch/ppo/ppo_rnn.py b/skrl/agents/torch/ppo/ppo_rnn.py
index 1ec4e7ca..48a488ff 100644
--- a/skrl/agents/torch/ppo/ppo_rnn.py
+++ b/skrl/agents/torch/ppo/ppo_rnn.py
@@ -519,7 +519,12 @@ def compute_gae(rewards: torch.Tensor,
             # update learning rate
             if self._learning_rate_scheduler:
                 if isinstance(self.scheduler, KLAdaptiveLR):
-                    self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
+                    kl = torch.tensor(kl_divergences, device=self.device).mean()
+                    # reduce (collect from all workers/processes) KL in distributed runs
+                    if config.torch.is_distributed:
+                        torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                        kl /= config.torch.world_size
+                    self.scheduler.step(kl.item())
                 else:
                     self.scheduler.step()
 
diff --git a/skrl/agents/torch/rpo/rpo.py b/skrl/agents/torch/rpo/rpo.py
index dfb6fa9f..a0024d13 100644
--- a/skrl/agents/torch/rpo/rpo.py
+++ b/skrl/agents/torch/rpo/rpo.py
@@ -449,7 +449,12 @@ def compute_gae(rewards: torch.Tensor,
             # update learning rate
             if self._learning_rate_scheduler:
                 if isinstance(self.scheduler, KLAdaptiveLR):
-                    self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
+                    kl = torch.tensor(kl_divergences, device=self.device).mean()
+                    # reduce (collect from all workers/processes) KL in distributed runs
+                    if config.torch.is_distributed:
+                        torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                        kl /= config.torch.world_size
+                    self.scheduler.step(kl.item())
                 else:
                     self.scheduler.step()
 
diff --git a/skrl/agents/torch/rpo/rpo_rnn.py b/skrl/agents/torch/rpo/rpo_rnn.py
index 94c7c27a..efda1fed 100644
--- a/skrl/agents/torch/rpo/rpo_rnn.py
+++ b/skrl/agents/torch/rpo/rpo_rnn.py
@@ -521,7 +521,12 @@ def compute_gae(rewards: torch.Tensor,
             # update learning rate
             if self._learning_rate_scheduler:
                 if isinstance(self.scheduler, KLAdaptiveLR):
-                    self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
+                    kl = torch.tensor(kl_divergences, device=self.device).mean()
+                    # reduce (collect from all workers/processes) KL in distributed runs
+                    if config.torch.is_distributed:
+                        torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                        kl /= config.torch.world_size
+                    self.scheduler.step(kl.item())
                 else:
                     self.scheduler.step()
 
diff --git a/skrl/resources/schedulers/torch/kl_adaptive.py b/skrl/resources/schedulers/torch/kl_adaptive.py
index 4d7d5771..0a32eed3 100644
--- a/skrl/resources/schedulers/torch/kl_adaptive.py
+++ b/skrl/resources/schedulers/torch/kl_adaptive.py
@@ -5,8 +5,6 @@
 import torch
 from torch.optim.lr_scheduler import _LRScheduler
 
-from skrl import config
-
 
 class KLAdaptiveLR(_LRScheduler):
     def __init__(self,
@@ -29,10 +27,6 @@ def __init__(self,
         This scheduler is only available for PPO at the moment.
         Applying it to other agents will not change the learning rate
 
-        .. note::
-
-            In distributed runs, the learning rate will be reduced and broadcasted across all workers/processes
-
         Example::
 
             >>> scheduler = KLAdaptiveLR(optimizer, kl_threshold=0.01)
@@ -92,25 +86,10 @@ def step(self, kl: Optional[Union[torch.Tensor, float]] = None, epoch: Optional[
         :type epoch: int, optional
         """
         if kl is not None:
-            # reduce (collect from all workers/processes) learning rate in distributed runs
-            if config.torch.is_distributed:
-                torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
-                kl /= config.torch.world_size
-
-            for i, group in enumerate(self.optimizer.param_groups):
-                # adjust the learning rate
-                lr = group['lr']
+            for group in self.optimizer.param_groups:
                 if kl > self.kl_threshold * self._kl_factor:
-                    lr = max(lr / self._lr_factor, self.min_lr)
+                    group['lr'] = max(group['lr'] / self._lr_factor, self.min_lr)
                 elif kl < self.kl_threshold / self._kl_factor:
-                    lr = min(lr * self._lr_factor, self.max_lr)
-
-                # broadcast learning rate in distributed runs
-                if config.torch.is_distributed:
-                    lr_tensor = torch.tensor([lr], device=config.torch.device)
-                    torch.distributed.broadcast(lr_tensor, 0)
-                    lr = lr_tensor.item()
+                    group['lr'] = min(group['lr'] * self._lr_factor, self.max_lr)
 
-                # update value
-                group['lr'] = lr
-                self._last_lr[i] = lr
+            self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
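
All six agent hunks above apply the same pattern: average the mean KL divergence across workers/processes, then step the KLAdaptiveLR scheduler with that shared value so every process adjusts its learning rate identically. The snippet below is a minimal standalone sketch of that pattern, not the agents' actual code: it uses plain torch.distributed queries in place of skrl's config.torch.is_distributed / config.torch.world_size helpers, and the names optimizer, scheduler and kl_divergences are illustrative stand-ins for an agent's internals.

    import torch
    import torch.distributed

    from skrl.resources.schedulers.torch import KLAdaptiveLR

    # illustrative stand-ins for an agent's optimizer, scheduler and the
    # per-mini-batch KL values collected during an update
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params, lr=1e-3)
    scheduler = KLAdaptiveLR(optimizer, kl_threshold=0.01)
    kl_divergences = [0.012, 0.009, 0.015]

    kl = torch.tensor(kl_divergences).mean()
    # in distributed runs, sum the KL over all workers/processes and divide by the
    # world size so every worker steps the scheduler with the same averaged value
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
        kl /= torch.distributed.get_world_size()
    scheduler.step(kl.item())

Because each agent now performs this reduction itself, the scheduler's step() no longer needs to know about distributed execution, which is why the all-reduce and broadcast logic (and the skrl config import) are removed from kl_adaptive.py in the last hunk.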