diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5b52e4b3..71132f6e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Define the `environment_info` trainer config to log environment info (PyTorch implementation)
- Add support to automatically compute the write and checkpoint intervals and make it the default option
- Single forward-pass in shared models
+- Distributed multi-GPU and multi-node learning (PyTorch implementation)
### Changed
- Update Orbit-related source code and docs to Isaac Lab
diff --git a/docs/source/api/agents/a2c.rst b/docs/source/api/agents/a2c.rst
index c98f6448..dfb55a02 100644
--- a/docs/source/api/agents/a2c.rst
+++ b/docs/source/api/agents/a2c.rst
@@ -232,6 +232,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/amp.rst b/docs/source/api/agents/amp.rst
index c993b67a..e6937f76 100644
--- a/docs/source/api/agents/amp.rst
+++ b/docs/source/api/agents/amp.rst
@@ -237,6 +237,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/cem.rst b/docs/source/api/agents/cem.rst
index 68245818..1bcbee01 100644
--- a/docs/source/api/agents/cem.rst
+++ b/docs/source/api/agents/cem.rst
@@ -175,6 +175,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
+ * - Distributed
+ - \-
+ - .. centered:: :math:`\square`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/ddpg.rst b/docs/source/api/agents/ddpg.rst
index 7bfff2e3..461fda20 100644
--- a/docs/source/api/agents/ddpg.rst
+++ b/docs/source/api/agents/ddpg.rst
@@ -236,6 +236,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/ddqn.rst b/docs/source/api/agents/ddqn.rst
index 91744c33..f2ac0029 100644
--- a/docs/source/api/agents/ddqn.rst
+++ b/docs/source/api/agents/ddqn.rst
@@ -184,6 +184,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/dqn.rst b/docs/source/api/agents/dqn.rst
index 605e57ad..ed196366 100644
--- a/docs/source/api/agents/dqn.rst
+++ b/docs/source/api/agents/dqn.rst
@@ -184,6 +184,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/ppo.rst b/docs/source/api/agents/ppo.rst
index 0bf7b37b..f3c84f20 100644
--- a/docs/source/api/agents/ppo.rst
+++ b/docs/source/api/agents/ppo.rst
@@ -248,6 +248,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/rpo.rst b/docs/source/api/agents/rpo.rst
index 61947dff..f769e9c7 100644
--- a/docs/source/api/agents/rpo.rst
+++ b/docs/source/api/agents/rpo.rst
@@ -285,6 +285,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/sac.rst b/docs/source/api/agents/sac.rst
index c55720cc..2c17fc3a 100644
--- a/docs/source/api/agents/sac.rst
+++ b/docs/source/api/agents/sac.rst
@@ -244,6 +244,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/td3.rst b/docs/source/api/agents/td3.rst
index 4c494f5c..5d54ae4f 100644
--- a/docs/source/api/agents/td3.rst
+++ b/docs/source/api/agents/td3.rst
@@ -258,6 +258,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/agents/trpo.rst b/docs/source/api/agents/trpo.rst
index 21460fe9..c482bf87 100644
--- a/docs/source/api/agents/trpo.rst
+++ b/docs/source/api/agents/trpo.rst
@@ -282,6 +282,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/config/frameworks.rst b/docs/source/api/config/frameworks.rst
index eff8bb77..f25e492f 100644
--- a/docs/source/api/config/frameworks.rst
+++ b/docs/source/api/config/frameworks.rst
@@ -7,6 +7,65 @@ Configurations for behavior modification of Machine Learning (ML) frameworks.
+PyTorch
+-------
+
+PyTorch specific configuration
+
+.. raw:: html
+
+    <br><hr>
+
+API
+^^^
+
+.. py:data:: skrl.config.torch.device
+ :type: torch.device
+ :value: "cuda:${LOCAL_RANK}" | "cpu"
+
+ Default device
+
+ The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise
+
+.. py:data:: skrl.config.torch.local_rank
+ :type: int
+ :value: 0
+
+ The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
+
+ This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist).
+    See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details
+
+.. py:data:: skrl.config.torch.rank
+ :type: int
+ :value: 0
+
+ The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)
+
+ This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist).
+    See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details
+
+.. py:data:: skrl.config.torch.world_size
+ :type: int
+ :value: 1
+
+    The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes)
+
+ This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist).
+    See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details
+
+.. py:data:: skrl.config.torch.is_distributed
+ :type: bool
+ :value: False
+
+    Whether running in a distributed environment
+
+    This property is ``True`` when the PyTorch distributed environment variable ``WORLD_SIZE`` is greater than ``1``
+
+.. raw:: html
+
+    <br><hr>
+
JAX
---
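The properties documented above are driven entirely by environment variables, so they can be inspected without a GPU. A minimal sketch, assuming this branch of skrl is installed: the ``LOCAL_RANK``/``RANK``/``WORLD_SIZE`` variables are normally exported by a launcher such as torchrun, and must be set before skrl is imported because the configuration reads them once at construction time.

```python
import os

# simulate what a launcher such as torchrun would export (hypothetical
# single-node run with 2 workers); must happen before importing skrl
os.environ.setdefault("LOCAL_RANK", "0")   # rank within this node
os.environ.setdefault("RANK", "0")         # rank across all nodes
os.environ.setdefault("WORLD_SIZE", "2")   # total number of workers

from skrl import config

print(config.torch.local_rank)      # 0
print(config.torch.rank)            # 0
print(config.torch.world_size)      # 2
print(config.torch.is_distributed)  # True, since WORLD_SIZE > 1
print(config.torch.device)          # cuda:0 if CUDA is available, else cpu
```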
diff --git a/docs/source/api/multi_agents/ippo.rst b/docs/source/api/multi_agents/ippo.rst
index 9a259326..862fd2b9 100644
--- a/docs/source/api/multi_agents/ippo.rst
+++ b/docs/source/api/multi_agents/ippo.rst
@@ -239,6 +239,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/multi_agents/mappo.rst b/docs/source/api/multi_agents/mappo.rst
index c875ac6a..7ae18de3 100644
--- a/docs/source/api/multi_agents/mappo.rst
+++ b/docs/source/api/multi_agents/mappo.rst
@@ -240,6 +240,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
+ * - Distributed
+ - Single Program Multi Data (SPMD) multi-GPU
+ - .. centered:: :math:`\blacksquare`
+ - .. centered:: :math:`\square`
.. raw:: html
diff --git a/docs/source/api/utils.rst b/docs/source/api/utils.rst
index a86fb2da..ff050aaa 100644
--- a/docs/source/api/utils.rst
+++ b/docs/source/api/utils.rst
@@ -25,7 +25,7 @@ A set of utilities and configurations for managing an RL setup is provided as pa
- .. centered:: |_4| |pytorch| |_4|
- .. centered:: |_4| |jax| |_4|
 * - :doc:`ML frameworks <config/frameworks>` configuration |_5| |_5| |_5| |_5| |_5| |_2|
- - .. centered:: :math:`\square`
+ - .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\blacksquare`
.. list-table::
diff --git a/skrl/__init__.py b/skrl/__init__.py
index 068206ee..e3e4fa0c 100644
--- a/skrl/__init__.py
+++ b/skrl/__init__.py
@@ -1,6 +1,7 @@
from typing import Union
import logging
+import os
import sys
import numpy as np
@@ -43,6 +44,69 @@ class _Config(object):
def __init__(self) -> None:
"""Machine learning framework specific configuration
"""
+
+ class PyTorch(object):
+ def __init__(self) -> None:
+ """PyTorch configuration
+ """
+ self._device = None
+ # torch.distributed config
+ self._local_rank = int(os.getenv("LOCAL_RANK", "0"))
+ self._rank = int(os.getenv("RANK", "0"))
+ self._world_size = int(os.getenv("WORLD_SIZE", "1"))
+ self._is_distributed = self._world_size > 1
+
+ @property
+ def local_rank(self) -> int:
+ """The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
+
+ This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist)
+ """
+ return self._local_rank
+
+ @property
+ def rank(self) -> int:
+ """The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)
+
+ This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist)
+ """
+ return self._rank
+
+ @property
+ def world_size(self) -> int:
+ """The total number of workers/process (e.g.: GPUs) in a worker group (e.g.: across all nodes)
+
+ This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist)
+ """
+ return self._world_size
+
+ @property
+ def is_distributed(self) -> bool:
+ """Whether if running in a distributed environment
+
+ This property is ``True`` when the PyTorch distributed environment variable ``WORLD_SIZE`` is greater than ``1``
+ """
+ return self._is_distributed
+
+ @property
+ def device(self) -> "torch.device":
+ """Default device
+
+ The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment)
+ if CUDA is available, ``cpu`` otherwise
+ """
+ try:
+ import torch
+ if self._device is None:
+ return torch.device(f"cuda:{self._local_rank}" if torch.cuda.is_available() else "cpu")
+ return torch.device(self._device)
+ except ImportError:
+ return self._device
+
+ @device.setter
+ def device(self, device: Union[str, "torch.device"]) -> None:
+ self._device = device
+
class JAX(object):
def __init__(self) -> None:
"""JAX configuration
@@ -91,5 +155,6 @@ def key(self, value: Union[int, "jax.Array"]) -> None:
self._key = value
self.jax = JAX()
+ self.torch = PyTorch()
config = _Config()
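Note that the ``device`` property resolves lazily: the setter only stores the override, while the getter builds a ``torch.device`` on each access (falling back to the raw value if torch cannot be imported). A short sketch of overriding the default device:

```python
# a minimal sketch, assuming skrl and PyTorch are installed: override
# the default device; components created afterwards pick it up
from skrl import config

config.torch.device = "cpu"   # accepts a string or a torch.device
print(config.torch.device)    # cpu, returned as a torch.device instance
```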
diff --git a/skrl/agents/torch/a2c/a2c.py b/skrl/agents/torch/a2c/a2c.py
index 84f28846..fd495516 100644
--- a/skrl/agents/torch/a2c/a2c.py
+++ b/skrl/agents/torch/a2c/a2c.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -104,6 +105,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None and self.policy is not self.value:
+ self.value.broadcast_parameters()
+
# configuration
self._mini_batches = self.cfg["mini_batches"]
self._rollouts = self.cfg["rollouts"]
@@ -391,6 +400,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+ if self.policy is not self.value:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -407,7 +420,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
- self.scheduler.step(torch.tensor(kl_divergences).mean())
+ self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()
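The broadcast-at-construction / reduce-after-backward pattern added here for A2C is repeated in the remaining agents below (AMP, DDPG, DQN/DDQN, PPO, RPO, SAC, TD3, TRPO and the IPPO/MAPPO multi-agents). A condensed sketch of the pattern with hypothetical names, not any agent's literal code:

```python
from skrl import config

def distributed_optimization_step(optimizer, loss, *models):
    """Hypothetical helper mirroring the agents' distributed update logic."""
    optimizer.zero_grad()
    loss.backward()
    if config.torch.is_distributed:
        for model in models:           # e.g. policy and value, when separate
            model.reduce_parameters()  # all-reduce gradients / world_size
    optimizer.step()                   # every rank applies the same averaged update
```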
diff --git a/skrl/agents/torch/a2c/a2c_rnn.py b/skrl/agents/torch/a2c/a2c_rnn.py
index 1053b3de..8cecc41e 100644
--- a/skrl/agents/torch/a2c/a2c_rnn.py
+++ b/skrl/agents/torch/a2c/a2c_rnn.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -104,6 +105,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None and self.policy is not self.value:
+ self.value.broadcast_parameters()
+
# configuration
self._mini_batches = self.cfg["mini_batches"]
self._rollouts = self.cfg["rollouts"]
@@ -462,6 +471,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+ if self.policy is not self.value:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -478,7 +491,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
- self.scheduler.step(torch.tensor(kl_divergences).mean())
+ self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()
diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py
index 743fa804..1a648c22 100644
--- a/skrl/agents/torch/amp/amp.py
+++ b/skrl/agents/torch/amp/amp.py
@@ -10,6 +10,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -147,6 +148,16 @@ def __init__(self,
self.checkpoint_modules["value"] = self.value
self.checkpoint_modules["discriminator"] = self.discriminator
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None:
+ self.value.broadcast_parameters()
+ if self.discriminator is not None:
+ self.discriminator.broadcast_parameters()
+
# configuration
self._learning_epochs = self.cfg["learning_epochs"]
self._mini_batches = self.cfg["mini_batches"]
@@ -554,6 +565,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss + discriminator_loss).backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+ self.value.reduce_parameters()
+ self.discriminator.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(),
self.value.parameters(),
@@ -571,7 +586,7 @@ def compute_gae(rewards: torch.Tensor,
if self._learning_rate_scheduler:
self.scheduler.step()
- # update AMP repaly buffer
+ # update AMP replay buffer
self.reply_buffer.add_samples(states=amp_states.view(-1, amp_states.shape[-1]))
# record data
diff --git a/skrl/agents/torch/base.py b/skrl/agents/torch/base.py
index db95e020..d23f27cc 100644
--- a/skrl/agents/torch/base.py
+++ b/skrl/agents/torch/base.py
@@ -11,7 +11,7 @@
import torch
from torch.utils.tensorboard import SummaryWriter
-from skrl import logger
+from skrl import config, logger
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -85,6 +85,12 @@ def __init__(self,
experiment_name = "{}_{}".format(datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f"), self.__class__.__name__)
self.experiment_dir = os.path.join(directory, experiment_name)
+ # set up distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Distributed (rank: {config.torch.rank}, local rank: {config.torch.local_rank}, world size: {config.torch.world_size})")
+ torch.distributed.init_process_group("nccl", rank=config.torch.rank, world_size=config.torch.world_size)
+ torch.cuda.set_device(config.torch.local_rank)
+
def __str__(self) -> str:
"""Generate a representation of the agent as string
@@ -129,26 +135,33 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None:
"""Initialize the agent
This method should be called before the agent is used.
- It will initialize the TensoBoard writer (and optionally Weights & Biases) and create the checkpoints directory
+ It will initialize the TensorBoard writer (and optionally Weights & Biases) and create the checkpoints directory
:param trainer_cfg: Trainer configuration
:type trainer_cfg: dict, optional
"""
trainer_cfg = trainer_cfg if trainer_cfg is not None else {}
+
+ # update agent configuration to avoid duplicated logging/checkpointing in distributed runs
+ if config.torch.is_distributed and config.torch.rank:
+ self.write_interval = 0
+ self.checkpoint_interval = 0
+ # TODO: disable wandb
+
# setup Weights & Biases
if self.cfg.get("experiment", {}).get("wandb", False):
- # save experiment config
+ # save experiment configuration
try:
models_cfg = {k: v.net._modules for (k, v) in self.models.items()}
except AttributeError:
models_cfg = {k: v._modules for (k, v) in self.models.items()}
- config={**self.cfg, **trainer_cfg, **models_cfg}
+ wandb_config = {**self.cfg, **trainer_cfg, **models_cfg}
# set default values
wandb_kwargs = copy.deepcopy(self.cfg.get("experiment", {}).get("wandb_kwargs", {}))
wandb_kwargs.setdefault("name", os.path.split(self.experiment_dir)[-1])
wandb_kwargs.setdefault("sync_tensorboard", True)
wandb_kwargs.setdefault("config", {})
- wandb_kwargs["config"].update(config)
+ wandb_kwargs["config"].update(wandb_config)
# init Weights & Biases
import wandb
wandb.init(**wandb_kwargs)
@@ -386,7 +399,7 @@ def migrate(self,
name_map: Mapping[str, Mapping[str, str]] = {},
auto_mapping: bool = True,
verbose: bool = False) -> bool:
- """Migrate the specified extrernal checkpoint to the current agent
+ """Migrate the specified external checkpoint to the current agent
The final storage device is determined by the constructor of the agent.
Only files generated by the *rl_games* library are supported at the moment
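Since the agent's constructor now initializes the NCCL process group from the environment variables, a distributed run only needs a standard launcher invocation, e.g. ``torchrun --nnodes=1 --nproc_per_node=2 train.py`` (hypothetical script name). User code can apply the same rank gating the agent uses for writing and checkpointing; a minimal sketch:

```python
from skrl import config

# emit a message once per run instead of once per worker/process
if config.torch.rank == 0:
    print(f"distributed run with {config.torch.world_size} worker(s)")
```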
diff --git a/skrl/agents/torch/ddpg/ddpg.py b/skrl/agents/torch/ddpg/ddpg.py
index ddcf0280..88f6ccc9 100644
--- a/skrl/agents/torch/ddpg/ddpg.py
+++ b/skrl/agents/torch/ddpg/ddpg.py
@@ -8,6 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -109,6 +110,14 @@ def __init__(self,
self.checkpoint_modules["critic"] = self.critic
self.checkpoint_modules["target_critic"] = self.target_critic
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.critic is not None:
+ self.critic.broadcast_parameters()
+
if self.target_policy is not None and self.target_critic is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_policy.freeze_parameters(True)
@@ -341,6 +350,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
+ if config.torch.is_distributed:
+ self.critic.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.critic.parameters(), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -354,6 +365,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
diff --git a/skrl/agents/torch/ddpg/ddpg_rnn.py b/skrl/agents/torch/ddpg/ddpg_rnn.py
index 400873ef..436184c1 100644
--- a/skrl/agents/torch/ddpg/ddpg_rnn.py
+++ b/skrl/agents/torch/ddpg/ddpg_rnn.py
@@ -8,6 +8,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -109,6 +110,14 @@ def __init__(self,
self.checkpoint_modules["critic"] = self.critic
self.checkpoint_modules["target_critic"] = self.target_critic
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.critic is not None:
+ self.critic.broadcast_parameters()
+
if self.target_policy is not None and self.target_critic is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_policy.freeze_parameters(True)
@@ -382,6 +391,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
+ if config.torch.is_distributed:
+ self.critic.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.critic.parameters(), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -395,6 +406,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
diff --git a/skrl/agents/torch/dqn/ddqn.py b/skrl/agents/torch/dqn/ddqn.py
index ac18c750..8d181ac8 100644
--- a/skrl/agents/torch/dqn/ddqn.py
+++ b/skrl/agents/torch/dqn/ddqn.py
@@ -8,6 +8,7 @@
import torch
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -104,6 +105,12 @@ def __init__(self,
self.checkpoint_modules["q_network"] = self.q_network
self.checkpoint_modules["target_q_network"] = self.target_q_network
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.q_network is not None:
+ self.q_network.broadcast_parameters()
+
if self.target_q_network is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_q_network.freeze_parameters(True)
@@ -303,6 +310,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimize Q-network
self.optimizer.zero_grad()
q_network_loss.backward()
+ if config.torch.is_distributed:
+ self.q_network.reduce_parameters()
self.optimizer.step()
# update target network
diff --git a/skrl/agents/torch/dqn/dqn.py b/skrl/agents/torch/dqn/dqn.py
index c23c611f..c7f4f709 100644
--- a/skrl/agents/torch/dqn/dqn.py
+++ b/skrl/agents/torch/dqn/dqn.py
@@ -8,6 +8,7 @@
import torch
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -104,6 +105,12 @@ def __init__(self,
self.checkpoint_modules["q_network"] = self.q_network
self.checkpoint_modules["target_q_network"] = self.target_q_network
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.q_network is not None:
+ self.q_network.broadcast_parameters()
+
if self.target_q_network is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_q_network.freeze_parameters(True)
@@ -303,6 +310,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimize Q-network
self.optimizer.zero_grad()
q_network_loss.backward()
+ if config.torch.is_distributed:
+ self.q_network.reduce_parameters()
self.optimizer.step()
# update target network
diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py
index 7e1dc24f..cb68ac2a 100644
--- a/skrl/agents/torch/ppo/ppo.py
+++ b/skrl/agents/torch/ppo/ppo.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -111,6 +112,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None and self.policy is not self.value:
+ self.value.broadcast_parameters()
+
# configuration
self._learning_epochs = self.cfg["learning_epochs"]
self._mini_batches = self.cfg["mini_batches"]
@@ -418,6 +427,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+ if self.policy is not self.value:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -434,7 +447,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
- self.scheduler.step(torch.tensor(kl_divergences).mean())
+ self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()
diff --git a/skrl/agents/torch/ppo/ppo_rnn.py b/skrl/agents/torch/ppo/ppo_rnn.py
index c74ca937..1ec4e7ca 100644
--- a/skrl/agents/torch/ppo/ppo_rnn.py
+++ b/skrl/agents/torch/ppo/ppo_rnn.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -111,6 +112,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None and self.policy is not self.value:
+ self.value.broadcast_parameters()
+
# configuration
self._learning_epochs = self.cfg["learning_epochs"]
self._mini_batches = self.cfg["mini_batches"]
@@ -490,6 +499,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+ if self.policy is not self.value:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -506,7 +519,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
- self.scheduler.step(torch.tensor(kl_divergences).mean())
+ self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()
diff --git a/skrl/agents/torch/rpo/rpo.py b/skrl/agents/torch/rpo/rpo.py
index 93a59435..dfb6fa9f 100644
--- a/skrl/agents/torch/rpo/rpo.py
+++ b/skrl/agents/torch/rpo/rpo.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -112,6 +113,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None and self.policy is not self.value:
+ self.value.broadcast_parameters()
+
# configuration
self._learning_epochs = self.cfg["learning_epochs"]
self._mini_batches = self.cfg["mini_batches"]
@@ -420,6 +429,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+ if self.policy is not self.value:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -436,7 +449,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
- self.scheduler.step(torch.tensor(kl_divergences).mean())
+ self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()
diff --git a/skrl/agents/torch/rpo/rpo_rnn.py b/skrl/agents/torch/rpo/rpo_rnn.py
index f6d507f9..94c7c27a 100644
--- a/skrl/agents/torch/rpo/rpo_rnn.py
+++ b/skrl/agents/torch/rpo/rpo_rnn.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -112,6 +113,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None and self.policy is not self.value:
+ self.value.broadcast_parameters()
+
# configuration
self._learning_epochs = self.cfg["learning_epochs"]
self._mini_batches = self.cfg["mini_batches"]
@@ -492,6 +501,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+ if self.policy is not self.value:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -508,7 +521,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
- self.scheduler.step(torch.tensor(kl_divergences).mean())
+ self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()
diff --git a/skrl/agents/torch/sac/sac.py b/skrl/agents/torch/sac/sac.py
index d6b96270..dc8678a6 100644
--- a/skrl/agents/torch/sac/sac.py
+++ b/skrl/agents/torch/sac/sac.py
@@ -10,6 +10,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -111,6 +112,16 @@ def __init__(self,
self.checkpoint_modules["target_critic_1"] = self.target_critic_1
self.checkpoint_modules["target_critic_2"] = self.target_critic_2
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.critic_1 is not None:
+ self.critic_1.broadcast_parameters()
+ if self.critic_2 is not None:
+ self.critic_2.broadcast_parameters()
+
if self.target_critic_1 is not None and self.target_critic_2 is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_critic_1.freeze_parameters(True)
@@ -325,6 +336,9 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
+ if config.torch.is_distributed:
+ self.critic_1.reduce_parameters()
+ self.critic_2.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -339,6 +353,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
diff --git a/skrl/agents/torch/sac/sac_rnn.py b/skrl/agents/torch/sac/sac_rnn.py
index 822f836c..6553d7aa 100644
--- a/skrl/agents/torch/sac/sac_rnn.py
+++ b/skrl/agents/torch/sac/sac_rnn.py
@@ -10,6 +10,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -111,6 +112,16 @@ def __init__(self,
self.checkpoint_modules["target_critic_1"] = self.target_critic_1
self.checkpoint_modules["target_critic_2"] = self.target_critic_2
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.critic_1 is not None:
+ self.critic_1.broadcast_parameters()
+ if self.critic_2 is not None:
+ self.critic_2.broadcast_parameters()
+
if self.target_critic_1 is not None and self.target_critic_2 is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_critic_1.freeze_parameters(True)
@@ -367,6 +378,9 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
+ if config.torch.is_distributed:
+ self.critic_1.reduce_parameters()
+ self.critic_2.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -381,6 +395,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
diff --git a/skrl/agents/torch/td3/td3.py b/skrl/agents/torch/td3/td3.py
index c276129b..abbcc467 100644
--- a/skrl/agents/torch/td3/td3.py
+++ b/skrl/agents/torch/td3/td3.py
@@ -9,7 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
-from skrl import logger
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -119,6 +119,16 @@ def __init__(self,
self.checkpoint_modules["target_critic_1"] = self.target_critic_1
self.checkpoint_modules["target_critic_2"] = self.target_critic_2
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.critic_1 is not None:
+ self.critic_1.broadcast_parameters()
+ if self.critic_2 is not None:
+ self.critic_2.broadcast_parameters()
+
if self.target_policy is not None and self.target_critic_1 is not None and self.target_critic_2 is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_policy.freeze_parameters(True)
@@ -372,6 +382,9 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
+ if config.torch.is_distributed:
+ self.critic_1.reduce_parameters()
+ self.critic_2.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -389,6 +402,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
diff --git a/skrl/agents/torch/td3/td3_rnn.py b/skrl/agents/torch/td3/td3_rnn.py
index 5af27e06..abd6a922 100644
--- a/skrl/agents/torch/td3/td3_rnn.py
+++ b/skrl/agents/torch/td3/td3_rnn.py
@@ -9,7 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
-from skrl import logger
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -119,6 +119,16 @@ def __init__(self,
self.checkpoint_modules["target_critic_1"] = self.target_critic_1
self.checkpoint_modules["target_critic_2"] = self.target_critic_2
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.critic_1 is not None:
+ self.critic_1.broadcast_parameters()
+ if self.critic_2 is not None:
+ self.critic_2.broadcast_parameters()
+
if self.target_policy is not None and self.target_critic_1 is not None and self.target_critic_2 is not None:
# freeze target networks with respect to optimizers (update via .update_parameters())
self.target_policy.freeze_parameters(True)
@@ -413,6 +423,9 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (critic)
self.critic_optimizer.zero_grad()
critic_loss.backward()
+ if config.torch.is_distributed:
+ self.critic_1.reduce_parameters()
+ self.critic_2.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip)
self.critic_optimizer.step()
@@ -430,6 +443,8 @@ def _update(self, timestep: int, timesteps: int) -> None:
# optimization step (policy)
self.policy_optimizer.zero_grad()
policy_loss.backward()
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
self.policy_optimizer.step()
diff --git a/skrl/agents/torch/trpo/trpo.py b/skrl/agents/torch/trpo/trpo.py
index b42d9b72..96565bef 100644
--- a/skrl/agents/torch/trpo/trpo.py
+++ b/skrl/agents/torch/trpo/trpo.py
@@ -9,6 +9,7 @@
import torch.nn.functional as F
from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -112,6 +113,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None:
+ self.value.broadcast_parameters()
+
# configuration
self._learning_epochs = self.cfg["learning_epochs"]
self._mini_batches = self.cfg["mini_batches"]
@@ -521,6 +530,9 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor
if restore_policy_flag:
self.policy.update_parameters(self.backup_policy)
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+
# sample mini-batches from memory
sampled_batches = self.memory.sample_all(names=self._tensors_names_value, mini_batches=self._mini_batches)
@@ -542,6 +554,8 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor
# optimization step (value)
self.value_optimizer.zero_grad()
value_loss.backward()
+ if config.torch.is_distributed:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.value.parameters(), self._grad_norm_clip)
self.value_optimizer.step()
diff --git a/skrl/agents/torch/trpo/trpo_rnn.py b/skrl/agents/torch/trpo/trpo_rnn.py
index 429bae78..4b3ad05a 100644
--- a/skrl/agents/torch/trpo/trpo_rnn.py
+++ b/skrl/agents/torch/trpo/trpo_rnn.py
@@ -9,6 +9,7 @@
import torch.nn.functional as F
from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -112,6 +113,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policy is not None:
+ self.policy.broadcast_parameters()
+ if self.value is not None:
+ self.value.broadcast_parameters()
+
# configuration
self._learning_epochs = self.cfg["learning_epochs"]
self._mini_batches = self.cfg["mini_batches"]
@@ -591,6 +600,9 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor
if restore_policy_flag:
self.policy.update_parameters(self.backup_policy)
+ if config.torch.is_distributed:
+ self.policy.reduce_parameters()
+
# sample mini-batches from memory
sampled_batches = self.memory.sample_all(names=self._tensors_names_value, mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length)
@@ -622,6 +634,8 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor
# optimization step (value)
self.value_optimizer.zero_grad()
value_loss.backward()
+ if config.torch.is_distributed:
+ self.value.reduce_parameters()
if self._grad_norm_clip > 0:
nn.utils.clip_grad_norm_(self.value.parameters(), self._grad_norm_clip)
self.value_optimizer.step()
diff --git a/skrl/models/torch/base.py b/skrl/models/torch/base.py
index 757a8ba2..bf90ba8a 100644
--- a/skrl/models/torch/base.py
+++ b/skrl/models/torch/base.py
@@ -7,7 +7,7 @@
import numpy as np
import torch
-from skrl import logger
+from skrl import config, logger
class Model(torch.nn.Module):
@@ -743,3 +743,48 @@ def update_parameters(self, model: torch.nn.Module, polyak: float = 1) -> None:
for parameters, model_parameters in zip(self.parameters(), model.parameters()):
parameters.data.mul_(1 - polyak)
parameters.data.add_(polyak * model_parameters.data)
+
+ def broadcast_parameters(self, rank: int = 0):
+ """Broadcast model parameters to the whole group (e.g.: across all nodes) in distributed runs
+
+ After calling this method, the distributed model will contain the broadcasted parameters from ``rank``
+
+ :param rank: Worker/process rank from which to broadcast model parameters (default: ``0``)
+ :type rank: int
+
+ Example::
+
+ # broadcast model parameters from the worker/process with rank 1
+ >>> if config.torch.is_distributed:
+ ... model.broadcast_parameters(rank=1)
+ """
+ object_list = [self.state_dict()]
+ torch.distributed.broadcast_object_list(object_list, rank)
+ self.load_state_dict(object_list[0])
+
+ def reduce_parameters(self):
+ """Reduce model parameters across all workers/processes in the whole group (e.g.: across all nodes)
+
+ After calling this method, the gradients of the model parameters will be averaged (bitwise identical) across all workers/processes
+
+ Example::
+
+ # average the gradients of the model parameters across all workers/processes
+ >>> if config.torch.is_distributed:
+ ... model.reduce_parameters()
+ """
+ # batch all_reduce ops: https://github.com/entity-neural-network/incubator/pull/220
+ gradients = []
+ for parameters in self.parameters():
+ if parameters.grad is not None:
+ gradients.append(parameters.grad.view(-1))
+ gradients = torch.cat(gradients)
+
+ torch.distributed.all_reduce(gradients, op=torch.distributed.ReduceOp.SUM)
+
+ offset = 0
+ for parameters in self.parameters():
+ if parameters.grad is not None:
+ parameters.grad.data.copy_(gradients[offset:offset + parameters.numel()] \
+ .view_as(parameters.grad.data) / config.torch.world_size)
+ offset += parameters.numel()
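These two helpers are the primitives the agents build on: ``broadcast_parameters`` synchronizes weights and ``reduce_parameters`` averages gradients. A hedged sketch of calling the broadcast helper directly, e.g. to re-synchronize after rank 0 restores a checkpoint; it assumes a torchrun launch with the process group already initialized (the Agent base class does this), and the helper name and path are hypothetical:

```python
from skrl import config
from skrl.models.torch import Model

def restore_and_sync(model: Model, path: str) -> None:
    """Hypothetical helper: rank 0 loads a checkpoint, all ranks adopt it."""
    if config.torch.is_distributed:
        if config.torch.rank == 0:
            model.load(path)                # only rank 0 reads the file
        model.broadcast_parameters(rank=0)  # every rank now holds rank 0's weights
    else:
        model.load(path)
```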
diff --git a/skrl/multi_agents/torch/ippo/ippo.py b/skrl/multi_agents/torch/ippo/ippo.py
index 45913edd..3a1f7932 100644
--- a/skrl/multi_agents/torch/ippo/ippo.py
+++ b/skrl/multi_agents/torch/ippo/ippo.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.memories.torch import Memory
from skrl.models.torch import Model
from skrl.multi_agents.torch import MultiAgent
@@ -109,9 +110,18 @@ def __init__(self,
self.values = {uid: self.models[uid].get("value", None) for uid in self.possible_agents}
for uid in self.possible_agents:
+ # checkpoint models
self.checkpoint_modules[uid]["policy"] = self.policies[uid]
self.checkpoint_modules[uid]["value"] = self.values[uid]
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policies[uid] is not None:
+ self.policies[uid].broadcast_parameters()
+ if self.values[uid] is not None and self.policies[uid] is not self.values[uid]:
+ self.values[uid].broadcast_parameters()
+
# configuration
self._learning_epochs = self._as_dict(self.cfg["learning_epochs"])
self._mini_batches = self._as_dict(self.cfg["mini_batches"])
@@ -437,6 +447,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizers[uid].zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ policy.reduce_parameters()
+ if policy is not value:
+ value.reduce_parameters()
if self._grad_norm_clip[uid] > 0:
if policy is value:
nn.utils.clip_grad_norm_(policy.parameters(), self._grad_norm_clip[uid])
@@ -453,7 +467,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler[uid]:
if isinstance(self.schedulers[uid], KLAdaptiveLR):
- self.schedulers[uid].step(torch.tensor(kl_divergences).mean())
+ self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.schedulers[uid].step()
diff --git a/skrl/multi_agents/torch/mappo/mappo.py b/skrl/multi_agents/torch/mappo/mappo.py
index 98fff05c..ef0e7bc2 100644
--- a/skrl/multi_agents/torch/mappo/mappo.py
+++ b/skrl/multi_agents/torch/mappo/mappo.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F
+from skrl import config, logger
from skrl.memories.torch import Memory
from skrl.models.torch import Model
from skrl.multi_agents.torch import MultiAgent
@@ -116,9 +117,18 @@ def __init__(self,
self.values = {uid: self.models[uid].get("value", None) for uid in self.possible_agents}
for uid in self.possible_agents:
+ # checkpoint models
self.checkpoint_modules[uid]["policy"] = self.policies[uid]
self.checkpoint_modules[uid]["value"] = self.values[uid]
+ # broadcast models' parameters in distributed runs
+ if config.torch.is_distributed:
+ logger.info(f"Broadcasting models' parameters")
+ if self.policies[uid] is not None:
+ self.policies[uid].broadcast_parameters()
+ if self.values[uid] is not None and self.policies[uid] is not self.values[uid]:
+ self.values[uid].broadcast_parameters()
+
# configuration
self._learning_epochs = self._as_dict(self.cfg["learning_epochs"])
self._mini_batches = self._as_dict(self.cfg["mini_batches"])
@@ -457,6 +467,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizers[uid].zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
+ if config.torch.is_distributed:
+ policy.reduce_parameters()
+ if policy is not value:
+ value.reduce_parameters()
if self._grad_norm_clip[uid] > 0:
if policy is value:
nn.utils.clip_grad_norm_(policy.parameters(), self._grad_norm_clip[uid])
@@ -473,7 +487,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler[uid]:
if isinstance(self.schedulers[uid], KLAdaptiveLR):
- self.schedulers[uid].step(torch.tensor(kl_divergences).mean())
+ self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.schedulers[uid].step()
diff --git a/skrl/resources/schedulers/torch/kl_adaptive.py b/skrl/resources/schedulers/torch/kl_adaptive.py
index 0c97f50f..4d7d5771 100644
--- a/skrl/resources/schedulers/torch/kl_adaptive.py
+++ b/skrl/resources/schedulers/torch/kl_adaptive.py
@@ -1,8 +1,12 @@
from typing import Optional, Union
+from packaging import version
+
import torch
from torch.optim.lr_scheduler import _LRScheduler
+from skrl import config
+
class KLAdaptiveLR(_LRScheduler):
def __init__(self,
@@ -25,6 +29,10 @@ def __init__(self,
This scheduler is only available for PPO at the moment.
Applying it to other agents will not change the learning rate
+ .. note::
+
+ In distributed runs, the KL divergence is averaged across all workers/processes and the resulting learning rate is broadcast from rank 0 to keep it synchronized
+
Example::
>>> scheduler = KLAdaptiveLR(optimizer, kl_threshold=0.01)
@@ -50,6 +58,8 @@ def __init__(self,
:param verbose: Verbose mode (default: ``False``)
:type verbose: bool, optional
"""
+ if version.parse(torch.__version__) >= version.parse("2.2"):
+ verbose = "deprecated"
super().__init__(optimizer, last_epoch, verbose)
self.kl_threshold = kl_threshold
@@ -82,10 +92,25 @@ def step(self, kl: Optional[Union[torch.Tensor, float]] = None, epoch: Optional[
:type epoch: int, optional
"""
if kl is not None:
- for group in self.optimizer.param_groups:
+ # reduce (average across all workers/processes) the KL divergence in distributed runs
+ if config.torch.is_distributed:
+ torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM)
+ kl /= config.torch.world_size
+
+ for i, group in enumerate(self.optimizer.param_groups):
+ # adjust the learning rate
+ lr = group['lr']
if kl > self.kl_threshold * self._kl_factor:
- group['lr'] = max(group['lr'] / self._lr_factor, self.min_lr)
+ lr = max(lr / self._lr_factor, self.min_lr)
elif kl < self.kl_threshold / self._kl_factor:
- group['lr'] = min(group['lr'] * self._lr_factor, self.max_lr)
+ lr = min(lr * self._lr_factor, self.max_lr)
+
+ # broadcast learning rate in distributed runs
+ if config.torch.is_distributed:
+ lr_tensor = torch.tensor([lr], device=config.torch.device)
+ torch.distributed.broadcast(lr_tensor, 0)
+ lr = lr_tensor.item()
- self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
+ # update value
+ group['lr'] = lr
+ self._last_lr[i] = lr
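Because the distributed path all-reduces the KL value in place, ``step()`` must receive a tensor on the right device, which is why the agents above now build it with ``torch.tensor(kl_divergences, device=self.device).mean()``. A single-process sketch (with ``is_distributed`` false, the reduce/broadcast branches are skipped):

```python
import torch
from skrl.resources.schedulers.torch import KLAdaptiveLR

optimizer = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
scheduler = KLAdaptiveLR(optimizer, kl_threshold=0.01)

scheduler.step(torch.tensor(0.05))      # KL above threshold: lr is reduced
print(optimizer.param_groups[0]["lr"])  # less than the initial 1e-3
```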
diff --git a/skrl/trainers/torch/base.py b/skrl/trainers/torch/base.py
index 232e01ee..8784185b 100644
--- a/skrl/trainers/torch/base.py
+++ b/skrl/trainers/torch/base.py
@@ -6,7 +6,7 @@
import torch
-from skrl import logger
+from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.envs.wrappers.torch import Wrapper
@@ -75,6 +75,11 @@ def close_env():
self.env.close()
logger.info("Environment closed")
+ # update trainer configuration to avoid duplicated output (e.g. progress bars) in distributed runs
+ if config.torch.is_distributed and config.torch.rank:
+ self.disable_progressbar = True
+
def __str__(self) -> str:
"""Generate a string representation of the trainer
diff --git a/skrl/utils/__init__.py b/skrl/utils/__init__.py
index e67e66cf..6ebdccf1 100644
--- a/skrl/utils/__init__.py
+++ b/skrl/utils/__init__.py
@@ -14,8 +14,14 @@ def set_seed(seed: Optional[int] = None, deterministic: bool = False) -> int:
"""
Set the seed for the random number generators
- Due to NumPy's legacy seeding constraint the seed must be between 0 and 2**32 - 1.
- Otherwise a NumPy exception (``ValueError: Seed must be between 0 and 2**32 - 1``) will be raised
+ .. note::
+
+ In distributed runs, each worker/process uses a different seed (the defined value incremented by the worker's rank)
+
+ .. warning::
+
+ Due to NumPy's legacy seeding constraint the seed must be between 0 and 2**32 - 1.
+ Otherwise a NumPy exception (``ValueError: Seed must be between 0 and 2**32 - 1``) will be raised
Modified packages:
@@ -65,8 +71,12 @@ def set_seed(seed: Optional[int] = None, deterministic: bool = False) -> int:
except NotImplementedError:
seed = int(time.time() * 1000)
seed %= 2 ** 31 # NumPy's legacy seeding seed must be between 0 and 2**32 - 1
-
seed = int(seed)
+
+ # set different seeds in distributed runs
+ if config.torch.is_distributed:
+ seed += config.torch.rank
+
logger.info(f"Seed: {seed}")
# numpy
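With this change, a torchrun launch gives every worker a distinct random stream while remaining reproducible from a single base seed; a brief sketch:

```python
from skrl.utils import set_seed

# in a distributed run, rank 0 seeds with 42, rank 1 with 43, and so on
set_seed(42)
```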