diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b52e4b3..71132f6e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). - Define the `environment_info` trainer config to log environment info (PyTorch implementation) - Add support to automatically compute the write and checkpoint intervals and make it the default option - Single forward-pass in shared models +- Distributed multi-GPU and multi-node learning (PyTorch implementation) ### Changed - Update Orbit-related source code and docs to Isaac Lab diff --git a/docs/source/api/agents/a2c.rst b/docs/source/api/agents/a2c.rst index c98f6448..dfb55a02 100644 --- a/docs/source/api/agents/a2c.rst +++ b/docs/source/api/agents/a2c.rst @@ -232,6 +232,10 @@ Support for advanced features is described in the next table - RNN, LSTM, GRU and any other variant - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\square` + * - Distributed + - Single Program Multi Data (SPMD) multi-GPU + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` .. raw:: html diff --git a/docs/source/api/agents/amp.rst b/docs/source/api/agents/amp.rst index c993b67a..e6937f76 100644 --- a/docs/source/api/agents/amp.rst +++ b/docs/source/api/agents/amp.rst @@ -237,6 +237,10 @@ Support for advanced features is described in the next table - \- - .. centered:: :math:`\square` - .. centered:: :math:`\square` + * - Distributed + - Single Program Multi Data (SPMD) multi-GPU + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` .. raw:: html diff --git a/docs/source/api/agents/cem.rst b/docs/source/api/agents/cem.rst index 68245818..1bcbee01 100644 --- a/docs/source/api/agents/cem.rst +++ b/docs/source/api/agents/cem.rst @@ -175,6 +175,10 @@ Support for advanced features is described in the next table - \- - .. centered:: :math:`\square` - .. centered:: :math:`\square` + * - Distributed + - \- + - .. centered:: :math:`\square` + - .. centered:: :math:`\square` .. raw:: html diff --git a/docs/source/api/agents/ddpg.rst b/docs/source/api/agents/ddpg.rst index 7bfff2e3..461fda20 100644 --- a/docs/source/api/agents/ddpg.rst +++ b/docs/source/api/agents/ddpg.rst @@ -236,6 +236,10 @@ Support for advanced features is described in the next table - RNN, LSTM, GRU and any other variant - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\square` + * - Distributed + - Single Program Multi Data (SPMD) multi-GPU + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` .. raw:: html diff --git a/docs/source/api/agents/ddqn.rst b/docs/source/api/agents/ddqn.rst index 91744c33..f2ac0029 100644 --- a/docs/source/api/agents/ddqn.rst +++ b/docs/source/api/agents/ddqn.rst @@ -184,6 +184,10 @@ Support for advanced features is described in the next table - \- - .. centered:: :math:`\square` - .. centered:: :math:`\square` + * - Distributed + - Single Program Multi Data (SPMD) multi-GPU + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` .. raw:: html diff --git a/docs/source/api/agents/dqn.rst b/docs/source/api/agents/dqn.rst index 605e57ad..ed196366 100644 --- a/docs/source/api/agents/dqn.rst +++ b/docs/source/api/agents/dqn.rst @@ -184,6 +184,10 @@ Support for advanced features is described in the next table - \- - .. centered:: :math:`\square` - .. centered:: :math:`\square` + * - Distributed + - Single Program Multi Data (SPMD) multi-GPU + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` .. 
raw:: html diff --git a/docs/source/api/agents/ppo.rst b/docs/source/api/agents/ppo.rst index 0bf7b37b..f3c84f20 100644 --- a/docs/source/api/agents/ppo.rst +++ b/docs/source/api/agents/ppo.rst @@ -248,6 +248,10 @@ Support for advanced features is described in the next table - RNN, LSTM, GRU and any other variant - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\square` + * - Distributed + - Single Program Multi Data (SPMD) multi-GPU + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` .. raw:: html diff --git a/docs/source/api/agents/rpo.rst b/docs/source/api/agents/rpo.rst index 61947dff..f769e9c7 100644 --- a/docs/source/api/agents/rpo.rst +++ b/docs/source/api/agents/rpo.rst @@ -285,6 +285,10 @@ Support for advanced features is described in the next table - RNN, LSTM, GRU and any other variant - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\square` + * - Distributed + - Single Program Multi Data (SPMD) multi-GPU + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` .. raw:: html diff --git a/docs/source/api/agents/sac.rst b/docs/source/api/agents/sac.rst index c55720cc..2c17fc3a 100644 --- a/docs/source/api/agents/sac.rst +++ b/docs/source/api/agents/sac.rst @@ -244,6 +244,10 @@ Support for advanced features is described in the next table - RNN, LSTM, GRU and any other variant - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\square` + * - Distributed + - Single Program Multi Data (SPMD) multi-GPU + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` .. raw:: html diff --git a/docs/source/api/agents/td3.rst b/docs/source/api/agents/td3.rst index 4c494f5c..5d54ae4f 100644 --- a/docs/source/api/agents/td3.rst +++ b/docs/source/api/agents/td3.rst @@ -258,6 +258,10 @@ Support for advanced features is described in the next table - RNN, LSTM, GRU and any other variant - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\square` + * - Distributed + - Single Program Multi Data (SPMD) multi-GPU + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` .. raw:: html diff --git a/docs/source/api/agents/trpo.rst b/docs/source/api/agents/trpo.rst index 21460fe9..c482bf87 100644 --- a/docs/source/api/agents/trpo.rst +++ b/docs/source/api/agents/trpo.rst @@ -282,6 +282,10 @@ Support for advanced features is described in the next table - RNN, LSTM, GRU and any other variant - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\square` + * - Distributed + - Single Program Multi Data (SPMD) multi-GPU + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` .. raw:: html diff --git a/docs/source/api/config/frameworks.rst b/docs/source/api/config/frameworks.rst index eff8bb77..f25e492f 100644 --- a/docs/source/api/config/frameworks.rst +++ b/docs/source/api/config/frameworks.rst @@ -7,6 +7,65 @@ Configurations for behavior modification of Machine Learning (ML) frameworks.

+PyTorch +------- + +PyTorch specific configuration + +.. raw:: html + +
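As context for the API documented below, the following is a minimal, self-contained sketch (not part of this diff) of the Single Program Multi Data flow these changes implement: every worker runs the same script, parameters are broadcast from rank 0 at start-up, and gradients are summed and averaged across workers before each optimizer step. The ``nn.Linear`` model, the optimizer and the launch command are illustrative placeholders, not skrl API::

    # illustrative stand-alone script; typically launched with something like:
    #   torchrun --nnodes=1 --nproc_per_node=2 example.py
    import os
    import torch
    import torch.distributed as dist
    import torch.nn as nn

    local_rank = int(os.getenv("LOCAL_RANK", "0"))
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    is_distributed = world_size > 1

    if is_distributed:
        dist.init_process_group("nccl", rank=rank, world_size=world_size)
        torch.cuda.set_device(local_rank)
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")

    model = nn.Linear(8, 2).to(device)  # stand-in for a skrl Model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # broadcast initial parameters from rank 0 so every worker starts identically
    if is_distributed:
        for p in model.parameters():
            dist.broadcast(p.data, src=0)

    # one update: each worker computes gradients on its own data,
    # then gradients are summed across workers and averaged before stepping
    optimizer.zero_grad()
    loss = model(torch.randn(32, 8, device=device)).square().mean()
    loss.backward()
    if is_distributed:
        for p in model.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
                p.grad /= world_size
    optimizer.step()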
+ +API +^^^ + +.. py:data:: skrl.config.torch.device + :type: torch.device + :value: "cuda:${LOCAL_RANK}" | "cpu" + + Default device + + The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise + +.. py:data:: skrl.config.torch.local_rank + :type: int + :value: 0 + + The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node) + + This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist). + See `torch.distributed `_ for more details + +.. py:data:: skrl.config.torch.rank + :type: int + :value: 0 + + The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes) + + This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist). + See `torch.distributed `_ for more details + +.. py:data:: skrl.config.torch.world_size + :type: int + :value: 1 + + The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes) + + This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist). + See `torch.distributed `_ for more details + +.. py:data:: skrl.config.torch.is_distributed + :type: bool + :value: False + + Whether running in a distributed environment + + This property is ``True`` when PyTorch's distributed environment variable ``WORLD_SIZE > 1`` + +.. raw:: html + +
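For reference (not part of this diff), the properties above are typically read directly from ``skrl.config`` inside a training script; the fragment below is hypothetical, and the environment variables are assumed to be set by the launcher (e.g. ``torchrun``)::

    # hypothetical fragment of a training script launched e.g. with:
    #   torchrun --nproc_per_node=2 train.py
    from skrl import config

    # per-process default device: "cuda:LOCAL_RANK" when CUDA is available
    device = config.torch.device

    if config.torch.is_distributed:
        # rank 0 is the worker expected to keep full logging/checkpointing
        print(f"rank {config.torch.rank}/{config.torch.world_size} "
              f"(local rank {config.torch.local_rank}) on {device}")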
+ JAX --- diff --git a/docs/source/api/multi_agents/ippo.rst b/docs/source/api/multi_agents/ippo.rst index 9a259326..862fd2b9 100644 --- a/docs/source/api/multi_agents/ippo.rst +++ b/docs/source/api/multi_agents/ippo.rst @@ -239,6 +239,10 @@ Support for advanced features is described in the next table - \- - .. centered:: :math:`\square` - .. centered:: :math:`\square` + * - Distributed + - Single Program Multi Data (SPMD) multi-GPU + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` .. raw:: html diff --git a/docs/source/api/multi_agents/mappo.rst b/docs/source/api/multi_agents/mappo.rst index c875ac6a..7ae18de3 100644 --- a/docs/source/api/multi_agents/mappo.rst +++ b/docs/source/api/multi_agents/mappo.rst @@ -240,6 +240,10 @@ Support for advanced features is described in the next table - \- - .. centered:: :math:`\square` - .. centered:: :math:`\square` + * - Distributed + - Single Program Multi Data (SPMD) multi-GPU + - .. centered:: :math:`\blacksquare` + - .. centered:: :math:`\square` .. raw:: html diff --git a/docs/source/api/utils.rst b/docs/source/api/utils.rst index a86fb2da..ff050aaa 100644 --- a/docs/source/api/utils.rst +++ b/docs/source/api/utils.rst @@ -25,7 +25,7 @@ A set of utilities and configurations for managing an RL setup is provided as pa - .. centered:: |_4| |pytorch| |_4| - .. centered:: |_4| |jax| |_4| * - :doc:`ML frameworks ` configuration |_5| |_5| |_5| |_5| |_5| |_2| - - .. centered:: :math:`\square` + - .. centered:: :math:`\blacksquare` - .. centered:: :math:`\blacksquare` .. list-table:: diff --git a/skrl/__init__.py b/skrl/__init__.py index 068206ee..e3e4fa0c 100644 --- a/skrl/__init__.py +++ b/skrl/__init__.py @@ -1,6 +1,7 @@ from typing import Union import logging +import os import sys import numpy as np @@ -43,6 +44,69 @@ class _Config(object): def __init__(self) -> None: """Machine learning framework specific configuration """ + + class PyTorch(object): + def __init__(self) -> None: + """PyTorch configuration + """ + self._device = None + # torch.distributed config + self._local_rank = int(os.getenv("LOCAL_RANK", "0")) + self._rank = int(os.getenv("RANK", "0")) + self._world_size = int(os.getenv("WORLD_SIZE", "1")) + self._is_distributed = self._world_size > 1 + + @property + def local_rank(self) -> int: + """The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node) + + This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist) + """ + return self._local_rank + + @property + def rank(self) -> int: + """The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes) + + This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist) + """ + return self._rank + + @property + def world_size(self) -> int: + """The total number of workers/process (e.g.: GPUs) in a worker group (e.g.: across all nodes) + + This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist) + """ + return self._world_size + + @property + def is_distributed(self) -> bool: + """Whether if running in a distributed environment + + This property is ``True`` when the PyTorch's distributed environment variable ``WORLD_SIZE > 1`` + """ + return self._is_distributed + + @property + def device(self) -> "torch.device": + """Default device + + The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) + if CUDA is available, ``cpu`` otherwise + """ + try: + import torch + if self._device 
is None: + return torch.device(f"cuda:{self._local_rank}" if torch.cuda.is_available() else "cpu") + return torch.device(self._device) + except ImportError: + return self._device + + @device.setter + def device(self, device: Union[str, "torch.device"]) -> None: + self._device = device + class JAX(object): def __init__(self) -> None: """JAX configuration @@ -91,5 +155,6 @@ def key(self, value: Union[int, "jax.Array"]) -> None: self._key = value self.jax = JAX() + self.torch = PyTorch() config = _Config() diff --git a/skrl/agents/torch/a2c/a2c.py b/skrl/agents/torch/a2c/a2c.py index 84f28846..fd495516 100644 --- a/skrl/agents/torch/a2c/a2c.py +++ b/skrl/agents/torch/a2c/a2c.py @@ -9,6 +9,7 @@ import torch.nn as nn import torch.nn.functional as F +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -104,6 +105,14 @@ def __init__(self, self.checkpoint_modules["policy"] = self.policy self.checkpoint_modules["value"] = self.value + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.value is not None and self.policy is not self.value: + self.value.broadcast_parameters() + # configuration self._mini_batches = self.cfg["mini_batches"] self._rollouts = self.cfg["rollouts"] @@ -391,6 +400,10 @@ def compute_gae(rewards: torch.Tensor, # optimization step self.optimizer.zero_grad() (policy_loss + entropy_loss + value_loss).backward() + if config.torch.is_distributed: + self.policy.reduce_parameters() + if self.policy is not self.value: + self.value.reduce_parameters() if self._grad_norm_clip > 0: if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) @@ -407,7 +420,7 @@ def compute_gae(rewards: torch.Tensor, # update learning rate if self._learning_rate_scheduler: if isinstance(self.scheduler, KLAdaptiveLR): - self.scheduler.step(torch.tensor(kl_divergences).mean()) + self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean()) else: self.scheduler.step() diff --git a/skrl/agents/torch/a2c/a2c_rnn.py b/skrl/agents/torch/a2c/a2c_rnn.py index 1053b3de..8cecc41e 100644 --- a/skrl/agents/torch/a2c/a2c_rnn.py +++ b/skrl/agents/torch/a2c/a2c_rnn.py @@ -9,6 +9,7 @@ import torch.nn as nn import torch.nn.functional as F +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -104,6 +105,14 @@ def __init__(self, self.checkpoint_modules["policy"] = self.policy self.checkpoint_modules["value"] = self.value + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.value is not None and self.policy is not self.value: + self.value.broadcast_parameters() + # configuration self._mini_batches = self.cfg["mini_batches"] self._rollouts = self.cfg["rollouts"] @@ -462,6 +471,10 @@ def compute_gae(rewards: torch.Tensor, # optimization step self.optimizer.zero_grad() (policy_loss + entropy_loss + value_loss).backward() + if config.torch.is_distributed: + self.policy.reduce_parameters() + if self.policy is not self.value: + self.value.reduce_parameters() if self._grad_norm_clip > 0: if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), 
self._grad_norm_clip) @@ -478,7 +491,7 @@ def compute_gae(rewards: torch.Tensor, # update learning rate if self._learning_rate_scheduler: if isinstance(self.scheduler, KLAdaptiveLR): - self.scheduler.step(torch.tensor(kl_divergences).mean()) + self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean()) else: self.scheduler.step() diff --git a/skrl/agents/torch/amp/amp.py b/skrl/agents/torch/amp/amp.py index 743fa804..1a648c22 100644 --- a/skrl/agents/torch/amp/amp.py +++ b/skrl/agents/torch/amp/amp.py @@ -10,6 +10,7 @@ import torch.nn as nn import torch.nn.functional as F +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -147,6 +148,16 @@ def __init__(self, self.checkpoint_modules["value"] = self.value self.checkpoint_modules["discriminator"] = self.discriminator + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.value is not None: + self.value.broadcast_parameters() + if self.discriminator is not None: + self.discriminator.broadcast_parameters() + # configuration self._learning_epochs = self.cfg["learning_epochs"] self._mini_batches = self.cfg["mini_batches"] @@ -554,6 +565,10 @@ def compute_gae(rewards: torch.Tensor, # optimization step self.optimizer.zero_grad() (policy_loss + entropy_loss + value_loss + discriminator_loss).backward() + if config.torch.is_distributed: + self.policy.reduce_parameters() + self.value.reduce_parameters() + self.discriminator.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(itertools.chain(self.policy.parameters(), self.value.parameters(), @@ -571,7 +586,7 @@ def compute_gae(rewards: torch.Tensor, if self._learning_rate_scheduler: self.scheduler.step() - # update AMP repaly buffer + # update AMP replay buffer self.reply_buffer.add_samples(states=amp_states.view(-1, amp_states.shape[-1])) # record data diff --git a/skrl/agents/torch/base.py b/skrl/agents/torch/base.py index db95e020..d23f27cc 100644 --- a/skrl/agents/torch/base.py +++ b/skrl/agents/torch/base.py @@ -11,7 +11,7 @@ import torch from torch.utils.tensorboard import SummaryWriter -from skrl import logger +from skrl import config, logger from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -85,6 +85,12 @@ def __init__(self, experiment_name = "{}_{}".format(datetime.datetime.now().strftime("%y-%m-%d_%H-%M-%S-%f"), self.__class__.__name__) self.experiment_dir = os.path.join(directory, experiment_name) + # set up distributed runs + if config.torch.is_distributed: + logger.info(f"Distributed (rank: {config.torch.rank}, local rank: {config.torch.local_rank}, world size: {config.torch.world_size})") + torch.distributed.init_process_group("nccl", rank=config.torch.rank, world_size=config.torch.world_size) + torch.cuda.set_device(config.torch.local_rank) + def __str__(self) -> str: """Generate a representation of the agent as string @@ -129,26 +135,33 @@ def init(self, trainer_cfg: Optional[Mapping[str, Any]] = None) -> None: """Initialize the agent This method should be called before the agent is used. 
- It will initialize the TensoBoard writer (and optionally Weights & Biases) and create the checkpoints directory + It will initialize the TensorBoard writer (and optionally Weights & Biases) and create the checkpoints directory :param trainer_cfg: Trainer configuration :type trainer_cfg: dict, optional """ trainer_cfg = trainer_cfg if trainer_cfg is not None else {} + + # update agent configuration to avoid duplicated logging/checking in distributed runs + if config.torch.is_distributed and config.torch.rank: + self.write_interval = 0 + self.checkpoint_interval = 0 + # TODO: disable wandb + # setup Weights & Biases if self.cfg.get("experiment", {}).get("wandb", False): - # save experiment config + # save experiment configuration try: models_cfg = {k: v.net._modules for (k, v) in self.models.items()} except AttributeError: models_cfg = {k: v._modules for (k, v) in self.models.items()} - config={**self.cfg, **trainer_cfg, **models_cfg} + wandb_config={**self.cfg, **trainer_cfg, **models_cfg} # set default values wandb_kwargs = copy.deepcopy(self.cfg.get("experiment", {}).get("wandb_kwargs", {})) wandb_kwargs.setdefault("name", os.path.split(self.experiment_dir)[-1]) wandb_kwargs.setdefault("sync_tensorboard", True) wandb_kwargs.setdefault("config", {}) - wandb_kwargs["config"].update(config) + wandb_kwargs["config"].update(wandb_config) # init Weights & Biases import wandb wandb.init(**wandb_kwargs) @@ -386,7 +399,7 @@ def migrate(self, name_map: Mapping[str, Mapping[str, str]] = {}, auto_mapping: bool = True, verbose: bool = False) -> bool: - """Migrate the specified extrernal checkpoint to the current agent + """Migrate the specified external checkpoint to the current agent The final storage device is determined by the constructor of the agent. Only files generated by the *rl_games* library are supported at the moment diff --git a/skrl/agents/torch/ddpg/ddpg.py b/skrl/agents/torch/ddpg/ddpg.py index ddcf0280..88f6ccc9 100644 --- a/skrl/agents/torch/ddpg/ddpg.py +++ b/skrl/agents/torch/ddpg/ddpg.py @@ -8,6 +8,7 @@ import torch.nn as nn import torch.nn.functional as F +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -109,6 +110,14 @@ def __init__(self, self.checkpoint_modules["critic"] = self.critic self.checkpoint_modules["target_critic"] = self.target_critic + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.critic is not None: + self.critic.broadcast_parameters() + if self.target_policy is not None and self.target_critic is not None: # freeze target networks with respect to optimizers (update via .update_parameters()) self.target_policy.freeze_parameters(True) @@ -341,6 +350,8 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization step (critic) self.critic_optimizer.zero_grad() critic_loss.backward() + if config.torch.is_distributed: + self.critic.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(self.critic.parameters(), self._grad_norm_clip) self.critic_optimizer.step() @@ -354,6 +365,8 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization step (policy) self.policy_optimizer.zero_grad() policy_loss.backward() + if config.torch.is_distributed: + self.policy.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(self.policy.parameters(), 
self._grad_norm_clip) self.policy_optimizer.step() diff --git a/skrl/agents/torch/ddpg/ddpg_rnn.py b/skrl/agents/torch/ddpg/ddpg_rnn.py index 400873ef..436184c1 100644 --- a/skrl/agents/torch/ddpg/ddpg_rnn.py +++ b/skrl/agents/torch/ddpg/ddpg_rnn.py @@ -8,6 +8,7 @@ import torch.nn as nn import torch.nn.functional as F +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -109,6 +110,14 @@ def __init__(self, self.checkpoint_modules["critic"] = self.critic self.checkpoint_modules["target_critic"] = self.target_critic + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.critic is not None: + self.critic.broadcast_parameters() + if self.target_policy is not None and self.target_critic is not None: # freeze target networks with respect to optimizers (update via .update_parameters()) self.target_policy.freeze_parameters(True) @@ -382,6 +391,8 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization step (critic) self.critic_optimizer.zero_grad() critic_loss.backward() + if config.torch.is_distributed: + self.critic.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(self.critic.parameters(), self._grad_norm_clip) self.critic_optimizer.step() @@ -395,6 +406,8 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization step (policy) self.policy_optimizer.zero_grad() policy_loss.backward() + if config.torch.is_distributed: + self.policy.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) self.policy_optimizer.step() diff --git a/skrl/agents/torch/dqn/ddqn.py b/skrl/agents/torch/dqn/ddqn.py index ac18c750..8d181ac8 100644 --- a/skrl/agents/torch/dqn/ddqn.py +++ b/skrl/agents/torch/dqn/ddqn.py @@ -8,6 +8,7 @@ import torch import torch.nn.functional as F +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -104,6 +105,12 @@ def __init__(self, self.checkpoint_modules["q_network"] = self.q_network self.checkpoint_modules["target_q_network"] = self.target_q_network + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.q_network is not None: + self.q_network.broadcast_parameters() + if self.target_q_network is not None: # freeze target networks with respect to optimizers (update via .update_parameters()) self.target_q_network.freeze_parameters(True) @@ -303,6 +310,8 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimize Q-network self.optimizer.zero_grad() q_network_loss.backward() + if config.torch.is_distributed: + self.q_network.reduce_parameters() self.optimizer.step() # update target network diff --git a/skrl/agents/torch/dqn/dqn.py b/skrl/agents/torch/dqn/dqn.py index c23c611f..c7f4f709 100644 --- a/skrl/agents/torch/dqn/dqn.py +++ b/skrl/agents/torch/dqn/dqn.py @@ -8,6 +8,7 @@ import torch import torch.nn.functional as F +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -104,6 +105,12 @@ def __init__(self, self.checkpoint_modules["q_network"] = self.q_network self.checkpoint_modules["target_q_network"] = self.target_q_network + # 
broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.q_network is not None: + self.q_network.broadcast_parameters() + if self.target_q_network is not None: # freeze target networks with respect to optimizers (update via .update_parameters()) self.target_q_network.freeze_parameters(True) @@ -303,6 +310,8 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimize Q-network self.optimizer.zero_grad() q_network_loss.backward() + if config.torch.is_distributed: + self.q_network.reduce_parameters() self.optimizer.step() # update target network diff --git a/skrl/agents/torch/ppo/ppo.py b/skrl/agents/torch/ppo/ppo.py index 7e1dc24f..cb68ac2a 100644 --- a/skrl/agents/torch/ppo/ppo.py +++ b/skrl/agents/torch/ppo/ppo.py @@ -9,6 +9,7 @@ import torch.nn as nn import torch.nn.functional as F +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -111,6 +112,14 @@ def __init__(self, self.checkpoint_modules["policy"] = self.policy self.checkpoint_modules["value"] = self.value + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.value is not None and self.policy is not self.value: + self.value.broadcast_parameters() + # configuration self._learning_epochs = self.cfg["learning_epochs"] self._mini_batches = self.cfg["mini_batches"] @@ -418,6 +427,10 @@ def compute_gae(rewards: torch.Tensor, # optimization step self.optimizer.zero_grad() (policy_loss + entropy_loss + value_loss).backward() + if config.torch.is_distributed: + self.policy.reduce_parameters() + if self.policy is not self.value: + self.value.reduce_parameters() if self._grad_norm_clip > 0: if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) @@ -434,7 +447,7 @@ def compute_gae(rewards: torch.Tensor, # update learning rate if self._learning_rate_scheduler: if isinstance(self.scheduler, KLAdaptiveLR): - self.scheduler.step(torch.tensor(kl_divergences).mean()) + self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean()) else: self.scheduler.step() diff --git a/skrl/agents/torch/ppo/ppo_rnn.py b/skrl/agents/torch/ppo/ppo_rnn.py index c74ca937..1ec4e7ca 100644 --- a/skrl/agents/torch/ppo/ppo_rnn.py +++ b/skrl/agents/torch/ppo/ppo_rnn.py @@ -9,6 +9,7 @@ import torch.nn as nn import torch.nn.functional as F +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -111,6 +112,14 @@ def __init__(self, self.checkpoint_modules["policy"] = self.policy self.checkpoint_modules["value"] = self.value + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.value is not None and self.policy is not self.value: + self.value.broadcast_parameters() + # configuration self._learning_epochs = self.cfg["learning_epochs"] self._mini_batches = self.cfg["mini_batches"] @@ -490,6 +499,10 @@ def compute_gae(rewards: torch.Tensor, # optimization step self.optimizer.zero_grad() (policy_loss + entropy_loss + value_loss).backward() + if config.torch.is_distributed: + self.policy.reduce_parameters() + if self.policy is not 
self.value: + self.value.reduce_parameters() if self._grad_norm_clip > 0: if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) @@ -506,7 +519,7 @@ def compute_gae(rewards: torch.Tensor, # update learning rate if self._learning_rate_scheduler: if isinstance(self.scheduler, KLAdaptiveLR): - self.scheduler.step(torch.tensor(kl_divergences).mean()) + self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean()) else: self.scheduler.step() diff --git a/skrl/agents/torch/rpo/rpo.py b/skrl/agents/torch/rpo/rpo.py index 93a59435..dfb6fa9f 100644 --- a/skrl/agents/torch/rpo/rpo.py +++ b/skrl/agents/torch/rpo/rpo.py @@ -9,6 +9,7 @@ import torch.nn as nn import torch.nn.functional as F +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -112,6 +113,14 @@ def __init__(self, self.checkpoint_modules["policy"] = self.policy self.checkpoint_modules["value"] = self.value + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.value is not None and self.policy is not self.value: + self.value.broadcast_parameters() + # configuration self._learning_epochs = self.cfg["learning_epochs"] self._mini_batches = self.cfg["mini_batches"] @@ -420,6 +429,10 @@ def compute_gae(rewards: torch.Tensor, # optimization step self.optimizer.zero_grad() (policy_loss + entropy_loss + value_loss).backward() + if config.torch.is_distributed: + self.policy.reduce_parameters() + if self.policy is not self.value: + self.value.reduce_parameters() if self._grad_norm_clip > 0: if self.policy is self.value: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) @@ -436,7 +449,7 @@ def compute_gae(rewards: torch.Tensor, # update learning rate if self._learning_rate_scheduler: if isinstance(self.scheduler, KLAdaptiveLR): - self.scheduler.step(torch.tensor(kl_divergences).mean()) + self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean()) else: self.scheduler.step() diff --git a/skrl/agents/torch/rpo/rpo_rnn.py b/skrl/agents/torch/rpo/rpo_rnn.py index f6d507f9..94c7c27a 100644 --- a/skrl/agents/torch/rpo/rpo_rnn.py +++ b/skrl/agents/torch/rpo/rpo_rnn.py @@ -9,6 +9,7 @@ import torch.nn as nn import torch.nn.functional as F +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -112,6 +113,14 @@ def __init__(self, self.checkpoint_modules["policy"] = self.policy self.checkpoint_modules["value"] = self.value + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.value is not None and self.policy is not self.value: + self.value.broadcast_parameters() + # configuration self._learning_epochs = self.cfg["learning_epochs"] self._mini_batches = self.cfg["mini_batches"] @@ -492,6 +501,10 @@ def compute_gae(rewards: torch.Tensor, # optimization step self.optimizer.zero_grad() (policy_loss + entropy_loss + value_loss).backward() + if config.torch.is_distributed: + self.policy.reduce_parameters() + if self.policy is not self.value: + self.value.reduce_parameters() if self._grad_norm_clip > 0: if self.policy is self.value: 
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) @@ -508,7 +521,7 @@ def compute_gae(rewards: torch.Tensor, # update learning rate if self._learning_rate_scheduler: if isinstance(self.scheduler, KLAdaptiveLR): - self.scheduler.step(torch.tensor(kl_divergences).mean()) + self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean()) else: self.scheduler.step() diff --git a/skrl/agents/torch/sac/sac.py b/skrl/agents/torch/sac/sac.py index d6b96270..dc8678a6 100644 --- a/skrl/agents/torch/sac/sac.py +++ b/skrl/agents/torch/sac/sac.py @@ -10,6 +10,7 @@ import torch.nn as nn import torch.nn.functional as F +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -111,6 +112,16 @@ def __init__(self, self.checkpoint_modules["target_critic_1"] = self.target_critic_1 self.checkpoint_modules["target_critic_2"] = self.target_critic_2 + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.critic_1 is not None: + self.critic_1.broadcast_parameters() + if self.critic_2 is not None: + self.critic_2.broadcast_parameters() + if self.target_critic_1 is not None and self.target_critic_2 is not None: # freeze target networks with respect to optimizers (update via .update_parameters()) self.target_critic_1.freeze_parameters(True) @@ -325,6 +336,9 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization step (critic) self.critic_optimizer.zero_grad() critic_loss.backward() + if config.torch.is_distributed: + self.critic_1.reduce_parameters() + self.critic_2.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip) self.critic_optimizer.step() @@ -339,6 +353,8 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization step (policy) self.policy_optimizer.zero_grad() policy_loss.backward() + if config.torch.is_distributed: + self.policy.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) self.policy_optimizer.step() diff --git a/skrl/agents/torch/sac/sac_rnn.py b/skrl/agents/torch/sac/sac_rnn.py index 822f836c..6553d7aa 100644 --- a/skrl/agents/torch/sac/sac_rnn.py +++ b/skrl/agents/torch/sac/sac_rnn.py @@ -10,6 +10,7 @@ import torch.nn as nn import torch.nn.functional as F +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -111,6 +112,16 @@ def __init__(self, self.checkpoint_modules["target_critic_1"] = self.target_critic_1 self.checkpoint_modules["target_critic_2"] = self.target_critic_2 + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.critic_1 is not None: + self.critic_1.broadcast_parameters() + if self.critic_2 is not None: + self.critic_2.broadcast_parameters() + if self.target_critic_1 is not None and self.target_critic_2 is not None: # freeze target networks with respect to optimizers (update via .update_parameters()) self.target_critic_1.freeze_parameters(True) @@ -367,6 +378,9 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization 
step (critic) self.critic_optimizer.zero_grad() critic_loss.backward() + if config.torch.is_distributed: + self.critic_1.reduce_parameters() + self.critic_2.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip) self.critic_optimizer.step() @@ -381,6 +395,8 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization step (policy) self.policy_optimizer.zero_grad() policy_loss.backward() + if config.torch.is_distributed: + self.policy.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) self.policy_optimizer.step() diff --git a/skrl/agents/torch/td3/td3.py b/skrl/agents/torch/td3/td3.py index c276129b..abbcc467 100644 --- a/skrl/agents/torch/td3/td3.py +++ b/skrl/agents/torch/td3/td3.py @@ -9,7 +9,7 @@ import torch.nn as nn import torch.nn.functional as F -from skrl import logger +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -119,6 +119,16 @@ def __init__(self, self.checkpoint_modules["target_critic_1"] = self.target_critic_1 self.checkpoint_modules["target_critic_2"] = self.target_critic_2 + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.critic_1 is not None: + self.critic_1.broadcast_parameters() + if self.critic_2 is not None: + self.critic_2.broadcast_parameters() + if self.target_policy is not None and self.target_critic_1 is not None and self.target_critic_2 is not None: # freeze target networks with respect to optimizers (update via .update_parameters()) self.target_policy.freeze_parameters(True) @@ -372,6 +382,9 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization step (critic) self.critic_optimizer.zero_grad() critic_loss.backward() + if config.torch.is_distributed: + self.critic_1.reduce_parameters() + self.critic_2.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip) self.critic_optimizer.step() @@ -389,6 +402,8 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization step (policy) self.policy_optimizer.zero_grad() policy_loss.backward() + if config.torch.is_distributed: + self.policy.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) self.policy_optimizer.step() diff --git a/skrl/agents/torch/td3/td3_rnn.py b/skrl/agents/torch/td3/td3_rnn.py index 5af27e06..abd6a922 100644 --- a/skrl/agents/torch/td3/td3_rnn.py +++ b/skrl/agents/torch/td3/td3_rnn.py @@ -9,7 +9,7 @@ import torch.nn as nn import torch.nn.functional as F -from skrl import logger +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -119,6 +119,16 @@ def __init__(self, self.checkpoint_modules["target_critic_1"] = self.target_critic_1 self.checkpoint_modules["target_critic_2"] = self.target_critic_2 + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.critic_1 is not None: + 
self.critic_1.broadcast_parameters() + if self.critic_2 is not None: + self.critic_2.broadcast_parameters() + if self.target_policy is not None and self.target_critic_1 is not None and self.target_critic_2 is not None: # freeze target networks with respect to optimizers (update via .update_parameters()) self.target_policy.freeze_parameters(True) @@ -413,6 +423,9 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization step (critic) self.critic_optimizer.zero_grad() critic_loss.backward() + if config.torch.is_distributed: + self.critic_1.reduce_parameters() + self.critic_2.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(itertools.chain(self.critic_1.parameters(), self.critic_2.parameters()), self._grad_norm_clip) self.critic_optimizer.step() @@ -430,6 +443,8 @@ def _update(self, timestep: int, timesteps: int) -> None: # optimization step (policy) self.policy_optimizer.zero_grad() policy_loss.backward() + if config.torch.is_distributed: + self.policy.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip) self.policy_optimizer.step() diff --git a/skrl/agents/torch/trpo/trpo.py b/skrl/agents/torch/trpo/trpo.py index b42d9b72..96565bef 100644 --- a/skrl/agents/torch/trpo/trpo.py +++ b/skrl/agents/torch/trpo/trpo.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -112,6 +113,14 @@ def __init__(self, self.checkpoint_modules["policy"] = self.policy self.checkpoint_modules["value"] = self.value + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.value is not None: + self.value.broadcast_parameters() + # configuration self._learning_epochs = self.cfg["learning_epochs"] self._mini_batches = self.cfg["mini_batches"] @@ -521,6 +530,9 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor if restore_policy_flag: self.policy.update_parameters(self.backup_policy) + if config.torch.is_distributed: + self.policy.reduce_parameters() + # sample mini-batches from memory sampled_batches = self.memory.sample_all(names=self._tensors_names_value, mini_batches=self._mini_batches) @@ -542,6 +554,8 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor # optimization step (value) self.value_optimizer.zero_grad() value_loss.backward() + if config.torch.is_distributed: + self.value.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(self.value.parameters(), self._grad_norm_clip) self.value_optimizer.step() diff --git a/skrl/agents/torch/trpo/trpo_rnn.py b/skrl/agents/torch/trpo/trpo_rnn.py index 429bae78..4b3ad05a 100644 --- a/skrl/agents/torch/trpo/trpo_rnn.py +++ b/skrl/agents/torch/trpo/trpo_rnn.py @@ -9,6 +9,7 @@ import torch.nn.functional as F from torch.nn.utils.convert_parameters import parameters_to_vector, vector_to_parameters +from skrl import config, logger from skrl.agents.torch import Agent from skrl.memories.torch import Memory from skrl.models.torch import Model @@ -112,6 +113,14 @@ def __init__(self, self.checkpoint_modules["policy"] = self.policy self.checkpoint_modules["value"] = self.value + # broadcast models' parameters 
in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policy is not None: + self.policy.broadcast_parameters() + if self.value is not None: + self.value.broadcast_parameters() + # configuration self._learning_epochs = self.cfg["learning_epochs"] self._mini_batches = self.cfg["mini_batches"] @@ -591,6 +600,9 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor if restore_policy_flag: self.policy.update_parameters(self.backup_policy) + if config.torch.is_distributed: + self.policy.reduce_parameters() + # sample mini-batches from memory sampled_batches = self.memory.sample_all(names=self._tensors_names_value, mini_batches=self._mini_batches, sequence_length=self._rnn_sequence_length) @@ -622,6 +634,8 @@ def kl_divergence(policy_1: Model, policy_2: Model, states: torch.Tensor) -> tor # optimization step (value) self.value_optimizer.zero_grad() value_loss.backward() + if config.torch.is_distributed: + self.value.reduce_parameters() if self._grad_norm_clip > 0: nn.utils.clip_grad_norm_(self.value.parameters(), self._grad_norm_clip) self.value_optimizer.step() diff --git a/skrl/models/torch/base.py b/skrl/models/torch/base.py index 757a8ba2..bf90ba8a 100644 --- a/skrl/models/torch/base.py +++ b/skrl/models/torch/base.py @@ -7,7 +7,7 @@ import numpy as np import torch -from skrl import logger +from skrl import config, logger class Model(torch.nn.Module): @@ -743,3 +743,48 @@ def update_parameters(self, model: torch.nn.Module, polyak: float = 1) -> None: for parameters, model_parameters in zip(self.parameters(), model.parameters()): parameters.data.mul_(1 - polyak) parameters.data.add_(polyak * model_parameters.data) + + def broadcast_parameters(self, rank: int = 0): + """Broadcast model parameters to the whole group (e.g.: across all nodes) in distributed runs + + After calling this method, the distributed model will contain the broadcasted parameters from ``rank`` + + :param rank: Worker/process rank from which to broadcast model parameters (default: ``0``) + :type rank: int + + Example:: + + # broadcast model parameters from worker/process with rank 1 + >>> if config.torch.is_distributed: + ... model.broadcast_parameters(rank=1) + """ + object_list = [self.state_dict()] + torch.distributed.broadcast_object_list(object_list, rank) + self.load_state_dict(object_list[0]) + + def reduce_parameters(self): + """Reduce model parameters across all workers/processes in the whole group (e.g.: across all nodes) + + After calling this method, the distributed model parameters will be bitwise identical for all workers/processes + + Example:: + + # reduce model parameters across all workers/processes + >>> if config.torch.is_distributed: + ...
model.reduce_parameters() + """ + # batch all_reduce ops: https://github.com/entity-neural-network/incubator/pull/220 + gradients = [] + for parameters in self.parameters(): + if parameters.grad is not None: + gradients.append(parameters.grad.view(-1)) + gradients = torch.cat(gradients) + + torch.distributed.all_reduce(gradients, op=torch.distributed.ReduceOp.SUM) + + offset = 0 + for parameters in self.parameters(): + if parameters.grad is not None: + parameters.grad.data.copy_(gradients[offset:offset + parameters.numel()] \ + .view_as(parameters.grad.data) / config.torch.world_size) + offset += parameters.numel() diff --git a/skrl/multi_agents/torch/ippo/ippo.py b/skrl/multi_agents/torch/ippo/ippo.py index 45913edd..3a1f7932 100644 --- a/skrl/multi_agents/torch/ippo/ippo.py +++ b/skrl/multi_agents/torch/ippo/ippo.py @@ -9,6 +9,7 @@ import torch.nn as nn import torch.nn.functional as F +from skrl import config, logger from skrl.memories.torch import Memory from skrl.models.torch import Model from skrl.multi_agents.torch import MultiAgent @@ -109,9 +110,18 @@ def __init__(self, self.values = {uid: self.models[uid].get("value", None) for uid in self.possible_agents} for uid in self.possible_agents: + # checkpoint models self.checkpoint_modules[uid]["policy"] = self.policies[uid] self.checkpoint_modules[uid]["value"] = self.values[uid] + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policies[uid] is not None: + self.policies[uid].broadcast_parameters() + if self.values[uid] is not None and self.policies[uid] is not self.values[uid]: + self.values[uid].broadcast_parameters() + # configuration self._learning_epochs = self._as_dict(self.cfg["learning_epochs"]) self._mini_batches = self._as_dict(self.cfg["mini_batches"]) @@ -437,6 +447,10 @@ def compute_gae(rewards: torch.Tensor, # optimization step self.optimizers[uid].zero_grad() (policy_loss + entropy_loss + value_loss).backward() + if config.torch.is_distributed: + policy.reduce_parameters() + if policy is not value: + value.reduce_parameters() if self._grad_norm_clip[uid] > 0: if policy is value: nn.utils.clip_grad_norm_(policy.parameters(), self._grad_norm_clip[uid]) @@ -453,7 +467,7 @@ def compute_gae(rewards: torch.Tensor, # update learning rate if self._learning_rate_scheduler[uid]: if isinstance(self.schedulers[uid], KLAdaptiveLR): - self.schedulers[uid].step(torch.tensor(kl_divergences).mean()) + self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean()) else: self.schedulers[uid].step() diff --git a/skrl/multi_agents/torch/mappo/mappo.py b/skrl/multi_agents/torch/mappo/mappo.py index 98fff05c..ef0e7bc2 100644 --- a/skrl/multi_agents/torch/mappo/mappo.py +++ b/skrl/multi_agents/torch/mappo/mappo.py @@ -9,6 +9,7 @@ import torch.nn as nn import torch.nn.functional as F +from skrl import config, logger from skrl.memories.torch import Memory from skrl.models.torch import Model from skrl.multi_agents.torch import MultiAgent @@ -116,9 +117,18 @@ def __init__(self, self.values = {uid: self.models[uid].get("value", None) for uid in self.possible_agents} for uid in self.possible_agents: + # checkpoint models self.checkpoint_modules[uid]["policy"] = self.policies[uid] self.checkpoint_modules[uid]["value"] = self.values[uid] + # broadcast models' parameters in distributed runs + if config.torch.is_distributed: + logger.info(f"Broadcasting models' parameters") + if self.policies[uid] is not None: + 
self.policies[uid].broadcast_parameters() + if self.values[uid] is not None and self.policies[uid] is not self.values[uid]: + self.values[uid].broadcast_parameters() + # configuration self._learning_epochs = self._as_dict(self.cfg["learning_epochs"]) self._mini_batches = self._as_dict(self.cfg["mini_batches"]) @@ -457,6 +467,10 @@ def compute_gae(rewards: torch.Tensor, # optimization step self.optimizers[uid].zero_grad() (policy_loss + entropy_loss + value_loss).backward() + if config.torch.is_distributed: + policy.reduce_parameters() + if policy is not value: + value.reduce_parameters() if self._grad_norm_clip[uid] > 0: if policy is value: nn.utils.clip_grad_norm_(policy.parameters(), self._grad_norm_clip[uid]) @@ -473,7 +487,7 @@ def compute_gae(rewards: torch.Tensor, # update learning rate if self._learning_rate_scheduler[uid]: if isinstance(self.schedulers[uid], KLAdaptiveLR): - self.schedulers[uid].step(torch.tensor(kl_divergences).mean()) + self.schedulers[uid].step(torch.tensor(kl_divergences, device=self.device).mean()) else: self.schedulers[uid].step() diff --git a/skrl/resources/schedulers/torch/kl_adaptive.py b/skrl/resources/schedulers/torch/kl_adaptive.py index 0c97f50f..4d7d5771 100644 --- a/skrl/resources/schedulers/torch/kl_adaptive.py +++ b/skrl/resources/schedulers/torch/kl_adaptive.py @@ -1,8 +1,12 @@ from typing import Optional, Union +from packaging import version + import torch from torch.optim.lr_scheduler import _LRScheduler +from skrl import config + class KLAdaptiveLR(_LRScheduler): def __init__(self, @@ -25,6 +29,10 @@ def __init__(self, This scheduler is only available for PPO at the moment. Applying it to other agents will not change the learning rate + .. note:: + + In distributed runs, the learning rate will be reduced and broadcasted across all workers/processes + Example:: >>> scheduler = KLAdaptiveLR(optimizer, kl_threshold=0.01) @@ -50,6 +58,8 @@ def __init__(self, :param verbose: Verbose mode (default: ``False``) :type verbose: bool, optional """ + if version.parse(torch.__version__) >= version.parse("2.2"): + verbose = "deprecated" super().__init__(optimizer, last_epoch, verbose) self.kl_threshold = kl_threshold @@ -82,10 +92,25 @@ def step(self, kl: Optional[Union[torch.Tensor, float]] = None, epoch: Optional[ :type epoch: int, optional """ if kl is not None: - for group in self.optimizer.param_groups: + # reduce (collect from all workers/processes) learning rate in distributed runs + if config.torch.is_distributed: + torch.distributed.all_reduce(kl, op=torch.distributed.ReduceOp.SUM) + kl /= config.torch.world_size + + for i, group in enumerate(self.optimizer.param_groups): + # adjust the learning rate + lr = group['lr'] if kl > self.kl_threshold * self._kl_factor: - group['lr'] = max(group['lr'] / self._lr_factor, self.min_lr) + lr = max(lr / self._lr_factor, self.min_lr) elif kl < self.kl_threshold / self._kl_factor: - group['lr'] = min(group['lr'] * self._lr_factor, self.max_lr) + lr = min(lr * self._lr_factor, self.max_lr) + + # broadcast learning rate in distributed runs + if config.torch.is_distributed: + lr_tensor = torch.tensor([lr], device=config.torch.device) + torch.distributed.broadcast(lr_tensor, 0) + lr = lr_tensor.item() - self._last_lr = [group['lr'] for group in self.optimizer.param_groups] + # update value + group['lr'] = lr + self._last_lr[i] = lr diff --git a/skrl/trainers/torch/base.py b/skrl/trainers/torch/base.py index 232e01ee..8784185b 100644 --- a/skrl/trainers/torch/base.py +++ b/skrl/trainers/torch/base.py @@ -6,7 +6,7 @@ 
import torch -from skrl import logger +from skrl import config, logger from skrl.agents.torch import Agent from skrl.envs.wrappers.torch import Wrapper @@ -75,6 +75,11 @@ def close_env(): self.env.close() logger.info("Environment closed") + # update trainer configuration to avoid duplicated info/data in distributed runs + if config.torch.is_distributed: + if config.torch.rank: + self.disable_progressbar = True + def __str__(self) -> str: """Generate a string representation of the trainer diff --git a/skrl/utils/__init__.py b/skrl/utils/__init__.py index e67e66cf..6ebdccf1 100644 --- a/skrl/utils/__init__.py +++ b/skrl/utils/__init__.py @@ -14,8 +14,14 @@ def set_seed(seed: Optional[int] = None, deterministic: bool = False) -> int: """ Set the seed for the random number generators - Due to NumPy's legacy seeding constraint the seed must be between 0 and 2**32 - 1. - Otherwise a NumPy exception (``ValueError: Seed must be between 0 and 2**32 - 1``) will be raised + .. note:: + + In distributed runs, the worker/process seed will be incremented (counting from the defined value) according to its rank + + .. warning:: + + Due to NumPy's legacy seeding constraint the seed must be between 0 and 2**32 - 1. + Otherwise a NumPy exception (``ValueError: Seed must be between 0 and 2**32 - 1``) will be raised Modified packages: @@ -65,8 +71,12 @@ def set_seed(seed: Optional[int] = None, deterministic: bool = False) -> int: except NotImplementedError: seed = int(time.time() * 1000) seed %= 2 ** 31 # NumPy's legacy seeding seed must be between 0 and 2**32 - 1 - seed = int(seed) + + # set different seeds in distributed runs + if config.torch.is_distributed: + seed += config.torch.rank + logger.info(f"Seed: {seed}") # numpy