
Commit

Distributed multi-GPU and multi-node learning (PyTorch implementation) (#162)

* Add PyTorch framework configuration

* Initialize distributed process in trainer base class

* Improve property annotation and docstrings

* Add PyTorch framework config to docs

* Increase seed according to worker rank in distributed runs

* Update agent and trainer configuration to avoid duplicated data in distributed runs

* Add method to broadcast and reduce distributed model parameters

* Setup distributed runs

* Add distributed implementation to PPO agent

* Fix torch deprecated warning

* Reduce and broadcast learning rate across all workers/processes

* Update CHANGELOG

* Implement distributed runs for on-policy agents

* Add distributed implementation to agent features

* Implement distributed runs for off-policy agents

* Update off-policy agents features table in docs

* Unify code style for distributed implementation

* Implement distributed runs for multi-agents

* Update multi-agents features table in docs
Toni-SM committed Jun 24, 2024
1 parent 50f4a96 commit c3a23a1
Showing 41 changed files with 557 additions and 27 deletions.
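
A minimal launch-and-usage sketch for such a distributed run (not part of the changed files; ``train.py`` stands in for an ordinary skrl training script): ``torchrun --nnodes=1 --nproc_per_node=2 train.py`` spawns one process per GPU and exports the ``LOCAL_RANK``, ``RANK`` and ``WORLD_SIZE`` environment variables that the new configuration reads. Inside the script, each process can query the layout; the explicit seed offset mirrors the "Increase seed according to worker rank" item above and is illustrative only, since skrl may already apply it internally.

    from skrl import config
    from skrl.utils import set_seed

    # per-worker seed offset (illustrative -- skrl may already offset the seed internally)
    set_seed(42 + config.torch.rank)

    # each process drives its own GPU; gradients are averaged across processes during learning
    print(f"worker {config.torch.rank + 1}/{config.torch.world_size} -> {config.torch.device}")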
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Define the `environment_info` trainer config to log environment info (PyTorch implementation)
- Add support to automatically compute the write and checkpoint intervals and make it the default option
- Single forward-pass in shared models
- Distributed multi-GPU and multi-node learning (PyTorch implementation)

### Changed
- Update Orbit-related source code and docs to Isaac Lab
4 changes: 4 additions & 0 deletions docs/source/api/agents/a2c.rst
@@ -232,6 +232,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/amp.rst
@@ -237,6 +237,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/cem.rst
@@ -175,6 +175,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
* - Distributed
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/ddpg.rst
@@ -236,6 +236,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/ddqn.rst
@@ -184,6 +184,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/dqn.rst
@@ -184,6 +184,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/ppo.rst
@@ -248,6 +248,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/rpo.rst
@@ -285,6 +285,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/sac.rst
@@ -244,6 +244,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/td3.rst
@@ -258,6 +258,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/trpo.rst
@@ -282,6 +282,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

59 changes: 59 additions & 0 deletions docs/source/api/config/frameworks.rst
@@ -7,6 +7,65 @@ Configurations for behavior modification of Machine Learning (ML) frameworks.

<br><hr>

PyTorch
-------

PyTorch-specific configuration

.. raw:: html

<br>

API
^^^

.. py:data:: skrl.config.torch.device
:type: torch.device
:value: "cuda:${LOCAL_RANK}" | "cpu"

Default device

The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise

.. py:data:: skrl.config.torch.local_rank
:type: int
:value: 0

The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)

This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details

.. py:data:: skrl.config.torch.rank
:type: int
:value: 0

The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)

This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details

.. py:data:: skrl.config.torch.world_size
:type: int
:value: 1

The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes)

This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details

.. py:data:: skrl.config.torch.is_distributed
:type: bool
:value: False

Whether running in a distributed environment

This property is ``True`` when the PyTorch distributed environment variable ``WORLD_SIZE`` is greater than ``1``
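
A minimal usage sketch of this configuration (non-default values only appear when the process was launched by a distributed launcher such as ``torchrun``)

.. code-block:: python

    from skrl import config

    # default device: "cuda:LOCAL_RANK" if CUDA is available, "cpu" otherwise
    device = config.torch.device

    # distributed layout read from the launcher's environment variables
    if config.torch.is_distributed:
        print(f"rank {config.torch.rank}/{config.torch.world_size} (local rank {config.torch.local_rank})")

    # the default device can be overridden explicitly
    config.torch.device = "cpu"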

.. raw:: html

<br>

JAX
---

4 changes: 4 additions & 0 deletions docs/source/api/multi_agents/ippo.rst
@@ -239,6 +239,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/multi_agents/mappo.rst
@@ -240,6 +240,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

2 changes: 1 addition & 1 deletion docs/source/api/utils.rst
@@ -25,7 +25,7 @@ A set of utilities and configurations for managing an RL setup is provided as pa
- .. centered:: |_4| |pytorch| |_4|
- .. centered:: |_4| |jax| |_4|
* - :doc:`ML frameworks <config/frameworks>` configuration |_5| |_5| |_5| |_5| |_5| |_2|
- .. centered:: :math:`\square`
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\blacksquare`

.. list-table::
65 changes: 65 additions & 0 deletions skrl/__init__.py
@@ -1,6 +1,7 @@
from typing import Union

import logging
import os
import sys

import numpy as np
@@ -43,6 +44,69 @@ class _Config(object):
def __init__(self) -> None:
"""Machine learning framework specific configuration
"""

class PyTorch(object):
def __init__(self) -> None:
"""PyTorch configuration
"""
self._device = None
# torch.distributed config
self._local_rank = int(os.getenv("LOCAL_RANK", "0"))
self._rank = int(os.getenv("RANK", "0"))
self._world_size = int(os.getenv("WORLD_SIZE", "1"))
self._is_distributed = self._world_size > 1

@property
def local_rank(self) -> int:
"""The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist)
"""
return self._local_rank

@property
def rank(self) -> int:
"""The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)
This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist)
"""
return self._rank

@property
def world_size(self) -> int:
"""The total number of workers/process (e.g.: GPUs) in a worker group (e.g.: across all nodes)
This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist)
"""
return self._world_size

@property
def is_distributed(self) -> bool:
"""Whether if running in a distributed environment
This property is ``True`` when the PyTorch's distributed environment variable ``WORLD_SIZE > 1``
"""
return self._is_distributed

@property
def device(self) -> "torch.device":
"""Default device
The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment)
if CUDA is available, ``cpu`` otherwise
"""
try:
import torch
if self._device is None:
return torch.device(f"cuda:{self._local_rank}" if torch.cuda.is_available() else "cpu")
return torch.device(self._device)
except ImportError:
return self._device

@device.setter
def device(self, device: Union[str, "torch.device"]) -> None:
self._device = device

class JAX(object):
def __init__(self) -> None:
"""JAX configuration
@@ -91,5 +155,6 @@ def key(self, value: Union[int, "jax.Array"]) -> None:
self._key = value

self.jax = JAX()
self.torch = PyTorch()

config = _Config()
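
Since these values are captured once at import time, the environment variables must be set before ``skrl`` is imported. A small sanity-check sketch (the explicit ``os.environ`` assignments only emulate what ``torchrun`` exports for each spawned process):

    import os

    # emulate a 2-process, single-node launch (torchrun normally sets these variables)
    os.environ["LOCAL_RANK"] = "0"
    os.environ["RANK"] = "0"
    os.environ["WORLD_SIZE"] = "2"

    from skrl import config

    assert config.torch.is_distributed      # WORLD_SIZE > 1
    print(config.torch.device)              # cuda:0 if CUDA is available, otherwise cpu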
15 changes: 14 additions & 1 deletion skrl/agents/torch/a2c/a2c.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -104,6 +105,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.value is not None and self.policy is not self.value:
self.value.broadcast_parameters()

# configuration
self._mini_batches = self.cfg["mini_batches"]
self._rollouts = self.cfg["rollouts"]
@@ -391,6 +400,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
if config.torch.is_distributed:
self.policy.reduce_parameters()
if self.policy is not self.value:
self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -407,7 +420,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
self.scheduler.step(torch.tensor(kl_divergences).mean())
self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()

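
The ``broadcast_parameters`` and ``reduce_parameters`` model methods called above are added elsewhere in this commit (their hunks are not shown here). Note that the reduction is applied after ``backward()`` and before gradient clipping, so every worker clips and steps with the same averaged gradients. Conceptually the two methods wrap the standard ``torch.distributed`` collectives, roughly as in the following sketch (an approximation, not the committed implementation):

    import torch
    import torch.distributed as dist

    def broadcast_parameters(model: torch.nn.Module, src: int = 0) -> None:
        # copy the parameters of the source worker (rank 0) to all other workers
        for parameter in model.parameters():
            dist.broadcast(parameter.data, src=src)

    def reduce_parameters(model: torch.nn.Module) -> None:
        # sum the gradients across all workers and average them before the optimizer step
        world_size = dist.get_world_size()
        for parameter in model.parameters():
            if parameter.grad is not None:
                dist.all_reduce(parameter.grad, op=dist.ReduceOp.SUM)
                parameter.grad /= world_size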
15 changes: 14 additions & 1 deletion skrl/agents/torch/a2c/a2c_rnn.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -104,6 +105,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info(f"Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.value is not None and self.policy is not self.value:
self.value.broadcast_parameters()

# configuration
self._mini_batches = self.cfg["mini_batches"]
self._rollouts = self.cfg["rollouts"]
@@ -462,6 +471,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
if config.torch.is_distributed:
self.policy.reduce_parameters()
if self.policy is not self.value:
self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -478,7 +491,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
self.scheduler.step(torch.tensor(kl_divergences).mean())
self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()

