Distributed multi-GPU and multi-node learning (PyTorch implementation) #162

Merged: 21 commits (Jun 24, 2024)

Commits
cea621f
Add PyTorch framework configuration
Toni-SM Jun 19, 2024
b9c7645
Initialize distributed process in trainer base class
Toni-SM Jun 19, 2024
b0deb24
Improve property annotation and docstrings
Toni-SM Jun 20, 2024
42f9f57
Add PyTorch framework config to docs
Toni-SM Jun 20, 2024
e06dcfd
Increase seed according to worker rank in distributed runs
Toni-SM Jun 20, 2024
5abed23
Update agent and trainer configuration to avoid duplicated data in di…
Toni-SM Jun 20, 2024
4f8e3d8
Add method to broadcast and reduce distributed model parameters
Toni-SM Jun 20, 2024
104313d
Setup distributed runs
Toni-SM Jun 20, 2024
36d4a57
Add distributed implementation to PPO agent
Toni-SM Jun 21, 2024
9f265af
Fix torch deprecated warning
Toni-SM Jun 21, 2024
391506d
Reduce and broadcast learning rate across all workers/processes
Toni-SM Jun 21, 2024
f2aca29
Merge branch 'develop' into toni/distributed_torch
Toni-SM Jun 21, 2024
734221e
Update CHANGELOG
Toni-SM Jun 21, 2024
ffd6503
Merge branch 'develop' into toni/distributed_torch
Toni-SM Jun 23, 2024
a336bbd
Implement distributed runs for on-policy agents
Toni-SM Jun 23, 2024
1954ab3
Add distributed implementation to agent features
Toni-SM Jun 23, 2024
b821879
Implement distributed runs for off-policy agents
Toni-SM Jun 24, 2024
de86f51
Update off-policy agents features table in docs
Toni-SM Jun 24, 2024
c739b75
Unify code style for distributed implementation
Toni-SM Jun 24, 2024
c6301fb
Implement distributed runs for multi-agents
Toni-SM Jun 24, 2024
e09d7fd
Update multi-agents features table in docs
Toni-SM Jun 24, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
- Define the `environment_info` trainer config to log environment info (PyTorch implementation)
- Add support to automatically compute the write and checkpoint intervals and make it the default option
- Single forward-pass in shared models
- Distributed multi-GPU and multi-node learning (PyTorch implementation)

### Changed
- Update Orbit-related source code and docs to Isaac Lab
4 changes: 4 additions & 0 deletions docs/source/api/agents/a2c.rst
@@ -232,6 +232,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/amp.rst
@@ -237,6 +237,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/cem.rst
@@ -175,6 +175,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
* - Distributed
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/ddpg.rst
@@ -236,6 +236,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/ddqn.rst
@@ -184,6 +184,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/dqn.rst
@@ -184,6 +184,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/ppo.rst
@@ -248,6 +248,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/rpo.rst
@@ -285,6 +285,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/sac.rst
@@ -244,6 +244,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/td3.rst
@@ -258,6 +258,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/agents/trpo.rst
@@ -282,6 +282,10 @@ Support for advanced features is described in the next table
- RNN, LSTM, GRU and any other variant
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

59 changes: 59 additions & 0 deletions docs/source/api/config/frameworks.rst
@@ -7,6 +7,65 @@ Configurations for behavior modification of Machine Learning (ML) frameworks.

<br><hr>

PyTorch
-------

PyTorch specific configuration

.. raw:: html

<br>

API
^^^

.. py:data:: skrl.config.torch.device
:type: torch.device
:value: "cuda:${LOCAL_RANK}" | "cpu"

Default device

The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise

.. py:data:: skrl.config.torch.local_rank
:type: int
:value: 0

The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)

This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details

.. py:data:: skrl.config.torch.rank
:type: int
:value: 0

The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)

This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details

.. py:data:: skrl.config.torch.world_size
:type: int
:value: 1

The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes)

This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details

.. py:data:: skrl.config.torch.is_distributed
:type: bool
:value: False

Whether running in a distributed environment

This property is ``True`` when the PyTorch distributed environment variable ``WORLD_SIZE`` is greater than ``1``

.. raw:: html

<br>
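For reference, a minimal sketch of how the PyTorch configuration values documented above could be read from a user script. The launch command and script name are illustrative assumptions, not taken from this PR:

```python
# Assumes the script is launched through a distributed launcher that sets
# LOCAL_RANK, RANK and WORLD_SIZE, for example:
#   torchrun --nnodes=1 --nproc_per_node=2 train.py
from skrl import config

print(config.torch.local_rank)      # rank within the local node (LOCAL_RANK, default 0)
print(config.torch.rank)            # global rank across all nodes (RANK, default 0)
print(config.torch.world_size)      # total number of workers (WORLD_SIZE, default 1)
print(config.torch.is_distributed)  # True only when WORLD_SIZE > 1

# the default device resolves to cuda:LOCAL_RANK when CUDA is available, cpu otherwise
device = config.torch.device
```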

JAX
---

4 changes: 4 additions & 0 deletions docs/source/api/multi_agents/ippo.rst
@@ -239,6 +239,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

4 changes: 4 additions & 0 deletions docs/source/api/multi_agents/mappo.rst
@@ -240,6 +240,10 @@ Support for advanced features is described in the next table
- \-
- .. centered:: :math:`\square`
- .. centered:: :math:`\square`
* - Distributed
- Single Program Multi Data (SPMD) multi-GPU
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\square`

.. raw:: html

2 changes: 1 addition & 1 deletion docs/source/api/utils.rst
@@ -25,7 +25,7 @@ A set of utilities and configurations for managing an RL setup is provided as pa
- .. centered:: |_4| |pytorch| |_4|
- .. centered:: |_4| |jax| |_4|
* - :doc:`ML frameworks <config/frameworks>` configuration |_5| |_5| |_5| |_5| |_5| |_2|
- .. centered:: :math:`\square`
- .. centered:: :math:`\blacksquare`
- .. centered:: :math:`\blacksquare`

.. list-table::
65 changes: 65 additions & 0 deletions skrl/__init__.py
@@ -1,6 +1,7 @@
from typing import Union

import logging
import os
import sys

import numpy as np
@@ -43,6 +44,69 @@ class _Config(object):
def __init__(self) -> None:
"""Machine learning framework specific configuration
"""

class PyTorch(object):
def __init__(self) -> None:
"""PyTorch configuration
"""
self._device = None
# torch.distributed config
self._local_rank = int(os.getenv("LOCAL_RANK", "0"))
self._rank = int(os.getenv("RANK", "0"))
self._world_size = int(os.getenv("WORLD_SIZE", "1"))
self._is_distributed = self._world_size > 1

@property
def local_rank(self) -> int:
"""The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)

This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist)
"""
return self._local_rank

@property
def rank(self) -> int:
"""The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)

This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist)
"""
return self._rank

@property
def world_size(self) -> int:
The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes)

This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist)
"""
return self._world_size

@property
def is_distributed(self) -> bool:
Whether running in a distributed environment

This property is ``True`` when the PyTorch distributed environment variable ``WORLD_SIZE`` is greater than ``1``
"""
return self._is_distributed

@property
def device(self) -> "torch.device":
"""Default device

The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment)
if CUDA is available, ``cpu`` otherwise
"""
try:
import torch
if self._device is None:
return torch.device(f"cuda:{self._local_rank}" if torch.cuda.is_available() else "cpu")
return torch.device(self._device)
except ImportError:
return self._device

@device.setter
def device(self, device: Union[str, "torch.device"]) -> None:
self._device = device

class JAX(object):
def __init__(self) -> None:
"""JAX configuration
@@ -91,5 +155,6 @@ def key(self, value: Union[int, "jax.Array"]) -> None:
self._key = value

self.jax = JAX()
self.torch = PyTorch()

config = _Config()
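The commit history notes that the distributed process is initialized in the trainer base class, but the trainer changes are not included in this section. A plausible sketch of that step, built on the configuration object above (the backend choice and device pinning are assumptions, not code from this PR):

```python
# Plausible sketch only -- the actual trainer-side initialization is not shown
# in this section's diff. Relies on the env:// rendezvous that torchrun provides
# through MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE.
import torch
import torch.distributed as dist

from skrl import config

if config.torch.is_distributed:
    # pin this process to its own GPU before creating the process group
    torch.cuda.set_device(config.torch.local_rank)
    dist.init_process_group("nccl")
```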
15 changes: 14 additions & 1 deletion skrl/agents/torch/a2c/a2c.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -104,6 +105,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info("Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.value is not None and self.policy is not self.value:
self.value.broadcast_parameters()

# configuration
self._mini_batches = self.cfg["mini_batches"]
self._rollouts = self.cfg["rollouts"]
@@ -391,6 +400,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
if config.torch.is_distributed:
self.policy.reduce_parameters()
if self.policy is not self.value:
self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -407,7 +420,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
self.scheduler.step(torch.tensor(kl_divergences).mean())
self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()
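The optimization step above calls ``broadcast_parameters`` and ``reduce_parameters`` on the models (introduced by commit 4f8e3d8 in the list above), but their implementation is not part of this section's diff. A minimal sketch of what such helpers could look like, assuming parameters are broadcast from rank 0 and gradients are averaged with an all-reduce; the real skrl code may differ (for example, it could flatten all parameters into a single tensor before communicating):

```python
# Hypothetical sketch of the model-side helpers used in the agent update;
# not the actual skrl implementation.
import torch
import torch.distributed as dist


def broadcast_parameters(model: torch.nn.Module, src: int = 0) -> None:
    # copy the source rank's parameters to every worker so all replicas match
    for parameter in model.parameters():
        dist.broadcast(parameter.data, src=src)


def reduce_parameters(model: torch.nn.Module) -> None:
    # sum gradients across workers and average them before the optimizer step
    world_size = dist.get_world_size()
    for parameter in model.parameters():
        if parameter.grad is not None:
            dist.all_reduce(parameter.grad, op=dist.ReduceOp.SUM)
            parameter.grad /= world_size
```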

15 changes: 14 additions & 1 deletion skrl/agents/torch/a2c/a2c_rnn.py
@@ -9,6 +9,7 @@
import torch.nn as nn
import torch.nn.functional as F

from skrl import config, logger
from skrl.agents.torch import Agent
from skrl.memories.torch import Memory
from skrl.models.torch import Model
@@ -104,6 +105,14 @@ def __init__(self,
self.checkpoint_modules["policy"] = self.policy
self.checkpoint_modules["value"] = self.value

# broadcast models' parameters in distributed runs
if config.torch.is_distributed:
logger.info("Broadcasting models' parameters")
if self.policy is not None:
self.policy.broadcast_parameters()
if self.value is not None and self.policy is not self.value:
self.value.broadcast_parameters()

# configuration
self._mini_batches = self.cfg["mini_batches"]
self._rollouts = self.cfg["rollouts"]
@@ -462,6 +471,10 @@ def compute_gae(rewards: torch.Tensor,
# optimization step
self.optimizer.zero_grad()
(policy_loss + entropy_loss + value_loss).backward()
if config.torch.is_distributed:
self.policy.reduce_parameters()
if self.policy is not self.value:
self.value.reduce_parameters()
if self._grad_norm_clip > 0:
if self.policy is self.value:
nn.utils.clip_grad_norm_(self.policy.parameters(), self._grad_norm_clip)
@@ -478,7 +491,7 @@ def compute_gae(rewards: torch.Tensor,
# update learning rate
if self._learning_rate_scheduler:
if isinstance(self.scheduler, KLAdaptiveLR):
self.scheduler.step(torch.tensor(kl_divergences).mean())
self.scheduler.step(torch.tensor(kl_divergences, device=self.device).mean())
else:
self.scheduler.step()
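Commit 391506d mentions reducing and broadcasting the learning rate across all workers, but that code does not appear in this section. With a KL-adaptive scheduler each worker would otherwise adapt its learning rate from its own KL estimate, so the values can drift apart. A possible sketch of keeping them consistent across processes (the function name and placement are assumptions):

```python
# Hypothetical sketch -- not the PR's actual implementation. Averages the
# per-worker learning rates so every optimizer keeps stepping with the same value.
import torch
import torch.distributed as dist


def sync_learning_rate(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    world_size = dist.get_world_size()
    for param_group in optimizer.param_groups:
        lr = torch.tensor(param_group["lr"], device=device)
        dist.all_reduce(lr, op=dist.ReduceOp.SUM)   # sum the learning rates from all workers
        param_group["lr"] = lr.item() / world_size  # and keep the average
```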
