Modify the environment variable names for distributed runs to differentiate them from torch
Toni-SM committed Jul 11, 2024
1 parent 9915630 commit 828355d
Showing 2 changed files with 16 additions and 15 deletions.
10 changes: 5 additions & 5 deletions docs/source/api/config/frameworks.rst
@@ -84,7 +84,7 @@ API

 Default device
 
-The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise
+The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise
 
 .. py:data:: skrl.config.jax.backend
    :type: str
@@ -107,31 +107,31 @@ API

 The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
 
-This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist).
+This property reads from the ``JAX_LOCAL_RANK`` environment variable (``0`` if it doesn't exist).
 
 .. py:data:: skrl.config.jax.rank
    :type: int
    :value: 0
 
 The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)
 
-This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist).
+This property reads from the ``JAX_RANK`` environment variable (``0`` if it doesn't exist).
 
 .. py:data:: skrl.config.jax.world_size
    :type: int
    :value: 1
 
 The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes)
 
-This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist).
+This property reads from the ``JAX_WORLD_SIZE`` environment variable (``1`` if it doesn't exist).
 
 .. py:data:: skrl.config.jax.coordinator_address
    :type: str
    :value: "127.0.0.1:1234"
 
 IP address and port where process 0 will start a JAX service
 
-This property reads from the ``MASTER_ADDR:MASTER_PORT`` environment variables (``127.0.0.1:1234`` if they don't exist)
+This property reads from the ``JAX_COORDINATOR_ADDR:JAX_COORDINATOR_PORT`` environment variables (``127.0.0.1:1234`` if they don't exist)
 
 .. py:data:: skrl.config.jax.is_distributed
    :type: bool
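Note: the renamed variables keep the semantics of torch's ``LOCAL_RANK``, ``RANK``, ``WORLD_SIZE`` and ``MASTER_ADDR``/``MASTER_PORT``, but the ``JAX_`` prefix lets torch and JAX settings coexist in the same shell. Because JAX does not start multiple processes from a single program invocation (see the link in the diff below), a launcher has to export these variables for each worker it spawns. A minimal single-node launcher sketch, where the ``train.py`` entry script and the two-worker count are assumptions for illustration:

```python
# Minimal sketch of a single-node launcher for the renamed JAX_* variables.
# "train.py" and WORLD_SIZE = 2 are hypothetical; adapt them to the actual
# entry script and the number of local devices.
import os
import subprocess

WORLD_SIZE = 2  # e.g. two GPUs on this node

common = {
    "JAX_WORLD_SIZE": str(WORLD_SIZE),
    "JAX_COORDINATOR_ADDR": "127.0.0.1",  # process 0 starts the JAX service here
    "JAX_COORDINATOR_PORT": "1234",
}

workers = []
for rank in range(WORLD_SIZE):
    env = {**os.environ, **common,
           "JAX_RANK": str(rank),        # rank across all nodes
           "JAX_LOCAL_RANK": str(rank)}  # rank within this node
    workers.append(subprocess.Popen(["python", "train.py"], env=env))

for worker in workers:
    worker.wait()
```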
21 changes: 11 additions & 10 deletions skrl/__init__.py
@@ -116,10 +116,10 @@ def __init__(self) -> None:
         # distributed config (based on torch.distributed, since JAX doesn't implement it)
         # JAX doesn't automatically start multiple processes from a single program invocation
         # https://jax.readthedocs.io/en/latest/multi_process.html#launching-jax-processes
-        self._local_rank = int(os.getenv("LOCAL_RANK", "0"))
-        self._rank = int(os.getenv("RANK", "0"))
-        self._world_size = int(os.getenv("WORLD_SIZE", "1"))
-        self._coordinator_address = os.getenv("MASTER_ADDR", "127.0.0.1") + ":" + os.getenv("MASTER_PORT", "1234")
+        self._local_rank = int(os.getenv("JAX_LOCAL_RANK", "0"))
+        self._rank = int(os.getenv("JAX_RANK", "0"))
+        self._world_size = int(os.getenv("JAX_WORLD_SIZE", "1"))
+        self._coordinator_address = os.getenv("JAX_COORDINATOR_ADDR", "127.0.0.1") + ":" + os.getenv("JAX_COORDINATOR_PORT", "1234")
         self._is_distributed = self._world_size > 1
         # device
         self._device = f"cuda:{self._local_rank}"
@@ -138,7 +138,7 @@ def __init__(self) -> None:
     def device(self) -> "jax.Device":
         """Default device
 
-        The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment)
+        The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment)
         if CUDA is available, ``cpu`` otherwise
         """
         try:
@@ -197,39 +197,40 @@ def key(self, value: Union[int, "jax.Array"]) -> None:
     def local_rank(self) -> int:
         """The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
 
-        This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist)
+        This property reads from the ``JAX_LOCAL_RANK`` environment variable (``0`` if it doesn't exist)
         """
         return self._local_rank
 
     @property
     def rank(self) -> int:
         """The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)
 
-        This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist)
+        This property reads from the ``JAX_RANK`` environment variable (``0`` if it doesn't exist)
         """
         return self._rank
 
     @property
     def world_size(self) -> int:
         """The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes)
 
-        This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist)
+        This property reads from the ``JAX_WORLD_SIZE`` environment variable (``1`` if it doesn't exist)
         """
         return self._world_size
 
     @property
     def coordinator_address(self) -> str:
         """IP address and port where process 0 will start a JAX service
 
-        This property reads from the ``MASTER_ADDR:MASTER_PORT`` environment variables (``127.0.0.1:1234`` if they don't exist)
+        This property reads from the ``JAX_COORDINATOR_ADDR:JAX_COORDINATOR_PORT`` environment variables
+        (``127.0.0.1:1234`` if they don't exist)
         """
         return self._coordinator_address
 
     @property
     def is_distributed(self) -> bool:
         """Whether running in a distributed environment
 
-        This property is ``True`` when JAX's distributed environment variable ``WORLD_SIZE > 1``
+        This property is ``True`` when JAX's distributed environment variable ``JAX_WORLD_SIZE > 1``
         """
         return self._is_distributed
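Note: each worker consumes these variables through ``skrl.config.jax``. They are read in the constructor above, which runs when ``skrl`` is imported, so they must be set before the import. A sketch of what a rank-1 worker of a two-process run would observe; the variables are set inline only to keep the example self-contained (in a real run the launcher sets them before the interpreter starts):

```python
# Sketch: values a rank-1 worker of a 2-process run would see via skrl's config.
# Setting the variables inline is for illustration only; a launcher normally
# exports them before Python starts.
import os

os.environ["JAX_LOCAL_RANK"] = "1"
os.environ["JAX_RANK"] = "1"
os.environ["JAX_WORLD_SIZE"] = "2"
os.environ["JAX_COORDINATOR_ADDR"] = "127.0.0.1"
os.environ["JAX_COORDINATOR_PORT"] = "1234"

from skrl import config  # must come after the variables are set

assert config.jax.local_rank == 1
assert config.jax.rank == 1
assert config.jax.world_size == 2
assert config.jax.coordinator_address == "127.0.0.1:1234"
assert config.jax.is_distributed  # world_size > 1
```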
