Move the KL reduction implementation to each agent using it in distributed runs (#169)

* Remove the all-reduction and broadcasting implementation in distributed runs

* All-reduce the KL on each agent implementation

* Increase MINOR version and update CHANGELOG
Toni-SM committed Jul 11, 2024
1 parent 636936f commit 3f1c734
Showing 10 changed files with 47 additions and 34 deletions.
CHANGELOG.md (4 changes: 4 additions & 0 deletions)

@@ -2,6 +2,10 @@

 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

+## [1.3.0] - Unreleased
+### Changed
+- Move the KL reduction in distributed runs from the `KLAdaptiveLR` class to each agent using it
+
 ## [1.2.0] - 2024-06-23
 ### Added
 - Define the `environment_info` trainer config to log environment info (PyTorch implementation)
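The same change is applied to each agent that uses the KL-adaptive scheduler (A2C, PPO, and RPO, plus their RNN variants) in the diffs below: the per-minibatch KL divergences are averaged locally, summed across all processes with an all-reduce in distributed runs, divided by the world size, and only then passed to the scheduler. A minimal sketch of that pattern; the helper name and the `is_distributed`/`world_size` arguments are illustrative stand-ins for skrl's `config.torch.is_distributed` and `config.torch.world_size`, not part of the commit.

```python
import torch
import torch.distributed as dist

def mean_kl_across_workers(kl_divergences, device, is_distributed, world_size):
    """Average the locally collected KL values, then across processes."""
    kl = torch.tensor(kl_divergences, device=device).mean()
    if is_distributed:
        # sum the per-process means and divide by the number of processes,
        # so every worker steps its scheduler with the same global mean KL
        dist.all_reduce(kl, op=dist.ReduceOp.SUM)
        kl /= world_size
    return kl.item()

# single-process call: reduces to a plain mean
print(mean_kl_across_workers([0.02, 0.04, 0.03], "cpu", False, 1))  # ~0.03
```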
docs/source/conf.py (4 changes: 2 additions & 2 deletions)

@@ -10,13 +10,13 @@

 # project information
 project = "skrl"
-copyright = "2021, Toni-SM"
+copyright = "2021-2024, Toni-SM"
 author = "Toni-SM"

 if skrl.__version__ != "unknown":
     release = version = skrl.__version__
 else:
-    release = version = "1.2.0"
+    release = version = "1.3.0"

 master_doc = "index"
pyproject.toml (2 changes: 1 addition & 1 deletion)

@@ -1,6 +1,6 @@
 [project]
 name = "skrl"
-version = "1.2.0"
+version = "1.3.0"
 description = "Modular and flexible library for reinforcement learning on PyTorch and JAX"
 readme = "README.md"
 requires-python = ">=3.6"
skrl/agents/torch/a2c/a2c.py (7 changes: 6 additions & 1 deletion)

@@ -420,7 +420,12 @@ def compute_gae(rewards: torch.Tensor,
         # update learning rate
         if self._learning_rate_scheduler:
             if isinstance(self.scheduler, KLAdaptiveLR):
-                self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
+                kl = torch.tensor(kl_divergences, device=self.device).mean()
+                # reduce (collect from all workers/processes) KL in distributed runs
+                if config.torch.is_distributed:
+                    torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                    kl /= config.torch.world_size
+                self.scheduler.step(kl.item())
             else:
                 self.scheduler.step()
skrl/agents/torch/a2c/a2c_rnn.py (7 changes: 6 additions & 1 deletion)

@@ -491,7 +491,12 @@ def compute_gae(rewards: torch.Tensor,
         # update learning rate
         if self._learning_rate_scheduler:
             if isinstance(self.scheduler, KLAdaptiveLR):
-                self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
+                kl = torch.tensor(kl_divergences, device=self.device).mean()
+                # reduce (collect from all workers/processes) KL in distributed runs
+                if config.torch.is_distributed:
+                    torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                    kl /= config.torch.world_size
+                self.scheduler.step(kl.item())
             else:
                 self.scheduler.step()
skrl/agents/torch/ppo/ppo.py (7 changes: 6 additions & 1 deletion)

@@ -447,7 +447,12 @@ def compute_gae(rewards: torch.Tensor,
         # update learning rate
         if self._learning_rate_scheduler:
             if isinstance(self.scheduler, KLAdaptiveLR):
-                self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
+                kl = torch.tensor(kl_divergences, device=self.device).mean()
+                # reduce (collect from all workers/processes) KL in distributed runs
+                if config.torch.is_distributed:
+                    torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                    kl /= config.torch.world_size
+                self.scheduler.step(kl.item())
             else:
                 self.scheduler.step()
skrl/agents/torch/ppo/ppo_rnn.py (7 changes: 6 additions & 1 deletion)

@@ -519,7 +519,12 @@ def compute_gae(rewards: torch.Tensor,
         # update learning rate
         if self._learning_rate_scheduler:
             if isinstance(self.scheduler, KLAdaptiveLR):
-                self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
+                kl = torch.tensor(kl_divergences, device=self.device).mean()
+                # reduce (collect from all workers/processes) KL in distributed runs
+                if config.torch.is_distributed:
+                    torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                    kl /= config.torch.world_size
+                self.scheduler.step(kl.item())
             else:
                 self.scheduler.step()
skrl/agents/torch/rpo/rpo.py (7 changes: 6 additions & 1 deletion)

@@ -449,7 +449,12 @@ def compute_gae(rewards: torch.Tensor,
         # update learning rate
         if self._learning_rate_scheduler:
             if isinstance(self.scheduler, KLAdaptiveLR):
-                self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
+                kl = torch.tensor(kl_divergences, device=self.device).mean()
+                # reduce (collect from all workers/processes) KL in distributed runs
+                if config.torch.is_distributed:
+                    torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                    kl /= config.torch.world_size
+                self.scheduler.step(kl.item())
             else:
                 self.scheduler.step()
skrl/agents/torch/rpo/rpo_rnn.py (7 changes: 6 additions & 1 deletion)

@@ -521,7 +521,12 @@ def compute_gae(rewards: torch.Tensor,
         # update learning rate
         if self._learning_rate_scheduler:
             if isinstance(self.scheduler, KLAdaptiveLR):
-                self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
+                kl = torch.tensor(kl_divergences, device=self.device).mean()
+                # reduce (collect from all workers/processes) KL in distributed runs
+                if config.torch.is_distributed:
+                    torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+                    kl /= config.torch.world_size
+                self.scheduler.step(kl.item())
             else:
                 self.scheduler.step()
skrl/resources/schedulers/torch/kl_adaptive.py (29 changes: 4 additions & 25 deletions)

@@ -5,8 +5,6 @@
 import torch
 from torch.optim.lr_scheduler import _LRScheduler

-from skrl import config
-

 class KLAdaptiveLR(_LRScheduler):
     def __init__(self,
@@ -29,10 +27,6 @@ def __init__(self,
         This scheduler is only available for PPO at the moment.
         Applying it to other agents will not change the learning rate

-        .. note::
-
-            In distributed runs, the learning rate will be reduced and broadcasted across all workers/processes
-
         Example::

             >>> scheduler = KLAdaptiveLR(optimizer, kl_threshold=0.01)
@@ -92,25 +86,10 @@ def step(self, kl: Optional[Union[torch.Tensor, float]] = None, epoch: Optional[
         :type epoch: int, optional
         """
         if kl is not None:
-            # reduce (collect from all workers/processes) learning rate in distributed runs
-            if config.torch.is_distributed:
-                torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
-                kl /= config.torch.world_size
-
-            for i, group in enumerate(self.optimizer.param_groups):
-                # adjust the learning rate
-                lr = group['lr']
+            for group in self.optimizer.param_groups:
                 if kl > self.kl_threshold * self._kl_factor:
-                    lr = max(lr / self._lr_factor, self.min_lr)
+                    group['lr'] = max(group['lr'] / self._lr_factor, self.min_lr)
                 elif kl < self.kl_threshold / self._kl_factor:
-                    lr = min(lr * self._lr_factor, self.max_lr)
-
-                # broadcast learning rate in distributed runs
-                if config.torch.is_distributed:
-                    lr_tensor = torch.tensor([lr], device=config.torch.device)
-                    torch.distributed.broadcast(lr_tensor, 0)
-                    lr = lr_tensor.item()
+                    group['lr'] = min(group['lr'] * self._lr_factor, self.max_lr)

-                # update value
-                group['lr'] = lr
-                self._last_lr[i] = lr
+            self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
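With the reduction and broadcast removed, `KLAdaptiveLR.step` simply rescales the learning rate from whatever KL value it receives, so the caller (the agent) is responsible for passing an already-reduced value. A single-process usage sketch, assuming the usual `from skrl.resources.schedulers.torch import KLAdaptiveLR` import; the model, threshold, and KL value are arbitrary illustration values.

```python
import torch
from skrl.resources.schedulers.torch import KLAdaptiveLR

# dummy model/optimizer just to drive the scheduler
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = KLAdaptiveLR(optimizer, kl_threshold=0.01)

# the agent now passes the (globally averaged, in distributed runs) KL itself
kl = 0.046  # above kl_threshold * kl_factor, so the learning rate is lowered
scheduler.step(kl)
print(optimizer.param_groups[0]["lr"])  # roughly 1e-3 / lr_factor
```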
