Add JAX config for distributed learning to docs and fix PyTorch variables

Toni-SM committed Jul 8, 2024
1 parent bd2391b commit d81f4ba
Showing 1 changed file with 52 additions and 4 deletions.
56 changes: 52 additions & 4 deletions docs/source/api/config/frameworks.rst
@@ -27,7 +27,7 @@ API

The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise

-.. py:data:: skrl.config.local_rank
+.. py:data:: skrl.config.torch.local_rank
:type: int
:value: 0

@@ -36,7 +36,7 @@ API
This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details

-.. py:data:: skrl.config.rank
+.. py:data:: skrl.config.torch.rank
:type: int
:value: 0

@@ -45,7 +45,7 @@ API
This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details

-.. py:data:: skrl.config.world_size
+.. py:data:: skrl.config.torch.world_size
:type: int
:value: 1

@@ -54,7 +54,7 @@ API
This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details
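
For orientation, an illustrative sketch (not part of the committed file) of how the three properties above could be used to bind each process to its own GPU after launching with ``torchrun``; the ``from skrl import config`` entry point is assumed from skrl's usual usage:

.. code-block:: python

   # Sketch: per-process device selection from the properties above.
   # Assumes torchrun (or a similar launcher) set LOCAL_RANK/RANK/WORLD_SIZE.
   import torch

   from skrl import config

   device = torch.device(f"cuda:{config.torch.local_rank}"
                         if torch.cuda.is_available() else "cpu")
   print(f"process {config.torch.rank}/{config.torch.world_size} -> {device}")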

-.. py:data:: skrl.config.is_distributed
+.. py:data:: skrl.config.torch.is_distributed
:type: bool
:value: False
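
A sketch (again assuming the ``from skrl import config`` access path) of gating main-process-only work on this flag:

.. code-block:: python

   # Sketch: restrict logging/checkpointing to the main process when the
   # run is distributed (WORLD_SIZE > 1), as is common with torch.distributed.
   from skrl import config

   if not config.torch.is_distributed or config.torch.rank == 0:
       print("main process: safe to log / write checkpoints")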

@@ -78,6 +78,14 @@ JAX specific configuration
API
^^^

.. py:data:: skrl.config.jax.device
:type: jax.Device
:value: "cuda:${LOCAL_RANK}" | "cpu"

Default device

The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise
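
A minimal sketch of placing data on this default device (``from skrl import config`` assumed):

.. code-block:: python

   # Sketch: place an array on skrl's default JAX device.
   import jax
   import jax.numpy as jnp

   from skrl import config

   x = jax.device_put(jnp.zeros((2, 3)), config.jax.device)
   print(x.devices())  # e.g. {cuda:0} when a GPU is visible, {TFRT_CPU_0} otherwise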

.. py:data:: skrl.config.jax.backend
:type: str
:value: "numpy"
@@ -92,3 +100,43 @@ API
:value: [0, 0]

Pseudo-random number generator (PRNG) key
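
A sketch of consuming this global key following standard JAX splitting practice; that ``config.jax.key`` behaves as an ordinary PRNG key (its default value matches ``jax.random.PRNGKey(0)``) is an assumption here:

.. code-block:: python

   # Sketch: draw a reproducible sample from skrl's global PRNG key.
   # Standard JAX practice: split first, consume only the subkey.
   import jax

   from skrl import config

   key, subkey = jax.random.split(config.jax.key)
   sample = jax.random.normal(subkey, shape=(3,))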

.. py:data:: skrl.config.jax.local_rank
:type: int
:value: 0

The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)

This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist).

.. py:data:: skrl.config.jax.rank
:type: int
:value: 0

The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)

This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist).

.. py:data:: skrl.config.jax.world_size
:type: int
:value: 1

The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes)

This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist).
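
A sketch of identifying the current process within the group, mirroring the PyTorch properties above (environment variables set by the launcher; ``from skrl import config`` assumed):

.. code-block:: python

   # Sketch: report this worker's position within the distributed group.
   from skrl import config

   print(f"worker {config.jax.rank}/{config.jax.world_size} "
         f"(local rank {config.jax.local_rank})")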

.. py:data:: skrl.config.jax.coordinator_address
:type: str
:value: "127.0.0.1:1234"

IP address and port where process 0 will start a JAX service

This property reads from the ``MASTER_ADDR:MASTER_PORT`` environment variables (``127.0.0.1:1234`` if they don't exist)
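
For reference, these values map onto JAX's multi-process runtime; a hand-rolled initialization (a sketch of the intent, not skrl's verbatim internals) might look like:

.. code-block:: python

   # Sketch: feed the coordinator address into JAX's distributed runtime.
   # skrl derives the address from the MASTER_ADDR:MASTER_PORT variables.
   import jax

   from skrl import config

   jax.distributed.initialize(coordinator_address=config.jax.coordinator_address,
                              num_processes=config.jax.world_size,
                              process_id=config.jax.rank)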

.. py:data:: skrl.config.jax.is_distributed
:type: bool
:value: False

Whether running in a distributed environment

This property is ``True`` when the ``WORLD_SIZE`` environment variable is set and greater than 1
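
Since these properties read the same environment variables as their PyTorch counterparts, a ``torchrun``-style launcher populates them too; a sketch of the matching main-process guard (``from skrl import config`` assumed):

.. code-block:: python

   # Sketch: main-process-only work in a distributed JAX run.
   from skrl import config

   if not config.jax.is_distributed or config.jax.rank == 0:
       print("main process")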
