Update docs
Toni-SM committed Dec 6, 2024
1 parent 14a2252 commit 6530bb9
Showing 1 changed file with 43 additions and 23 deletions.
66 changes: 43 additions & 23 deletions docs/source/api/config/frameworks.rst
@@ -25,44 +25,52 @@ API
:type: torch.device
:value: "cuda:${LOCAL_RANK}" | "cpu"

Default device
Default device.

The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise
The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise.
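
A minimal sketch of querying the resolved device (assuming a standard skrl installation; the printed value is illustrative):

.. code-block:: python

    from skrl import config

    # resolved once by skrl: cuda:0 (or cuda:LOCAL_RANK when distributed)
    # if CUDA is available, cpu otherwise
    device = config.torch.device
    print(device)  # e.g. cuda:0 on a single-GPU machine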

.. py:data:: skrl.config.torch.local_rank
:type: int
:value: 0

The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node).

This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details.

Read-only attribute.

.. py:data:: skrl.config.torch.rank
:type: int
:value: 0

The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)
The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes).

This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details.

Read-only attribute.

.. py:data:: skrl.config.torch.world_size
:type: int
:value: 1

The total number of workers/process (e.g.: GPUs) in a worker group (e.g.: across all nodes)
The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes).

This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details.

Read-only attribute.
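
As a sketch of how these three attributes relate, a launcher such as ``torchrun`` exports the corresponding environment variables per process (the values below are illustrative for a single-node, 2-GPU run):

.. code-block:: python

    # launched e.g. with: torchrun --nnodes=1 --nproc_per_node=2 main.py
    from skrl import config

    print(config.torch.local_rank)  # 0 or 1: rank within this node (LOCAL_RANK)
    print(config.torch.rank)        # 0 or 1: rank across all nodes (RANK)
    print(config.torch.world_size)  # 2: total number of processes (WORLD_SIZE)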

.. py:data:: skrl.config.torch.is_distributed
:type: bool
:value: False

Whether if running in a distributed environment
Whether running in a distributed environment.

This property is ``True`` when the PyTorch's distributed environment variable ``WORLD_SIZE > 1``
This property is ``True`` when PyTorch's distributed environment variable ``WORLD_SIZE`` is greater than ``1``.

Read-only attribute.
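
A hedged sketch of gating per-process behavior on this flag (the logging policy is illustrative, not prescribed by skrl):

.. code-block:: python

    from skrl import config

    if config.torch.is_distributed:
        # e.g. let only the global rank-0 process write logs/checkpoints
        write_logs = config.torch.rank == 0
    else:
        write_logs = True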

.. raw:: html

@@ -86,61 +94,73 @@ API
:type: jax.Device
:value: "cuda:${LOCAL_RANK}" | "cpu"

Default device
Default device.

The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise
The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise.
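
A minimal sketch of querying the resolved JAX device (the printed form varies by JAX version and platform):

.. code-block:: python

    from skrl import config

    device = config.jax.device  # a jax.Device instance
    print(device)               # e.g. cuda:0, or TFRT_CPU_0 on CPU-only hosts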

.. py:data:: skrl.config.jax.backend
:type: str
:value: "numpy"

Backend used by the different components to operate and generate arrays
Backend used by the different components to operate and generate arrays.

This configuration excludes models and optimizers.
Supported backend are: ``"numpy"`` and ``"jax"``
Supported backends are: ``"numpy"`` and ``"jax"``.
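
A minimal sketch, assuming the attribute is writable and is set before the components are created:

.. code-block:: python

    from skrl import config

    # switch skrl's internal array operations from numpy (the default) to jax
    config.jax.backend = "jax"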

.. py:data:: skrl.config.jax.key
:type: jnp.ndarray
:type: jax.Array
:value: [0, 0]

Pseudo-random number generator (PRNG) key
Pseudo-random number generator (PRNG) key.

The key is formatted as a 32-bit unsigned integer, and it is placed on the default device.
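
A short sketch of reading the key and deriving a subkey with JAX's functional PRNG (splitting before use is standard JAX practice):

.. code-block:: python

    import jax

    from skrl import config

    key = config.jax.key                 # 32-bit unsigned integer key
    key, subkey = jax.random.split(key)  # derive a fresh subkey for sampling
    sample = jax.random.normal(subkey, (3,))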

.. py:data:: skrl.config.jax.local_rank
:type: int
:value: 0

The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node).

This property reads from the ``JAX_LOCAL_RANK`` environment variable (``0`` if it doesn't exist).

Read-only attribute.

.. py:data:: skrl.config.jax.rank
:type: int
:value: 0

The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)
The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes).

This property reads from the ``JAX_RANK`` environment variable (``0`` if it doesn't exist).

Read-only attribute.

.. py:data:: skrl.config.jax.world_size
:type: int
:value: 1

The total number of workers/process (e.g.: GPUs) in a worker group (e.g.: across all nodes)
The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes).

This property reads from the ``JAX_WORLD_SIZE`` environment variable (``1`` if it doesn't exist).

Read-only attribute.
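
A sketch of how these attributes mirror the environment. The variables are exported by hand here for illustration; in practice a launcher would set them, and they must be set before skrl reads its configuration:

.. code-block:: python

    import os

    # illustrative values for the second process of a 2-process run
    os.environ["JAX_LOCAL_RANK"] = "1"
    os.environ["JAX_RANK"] = "1"
    os.environ["JAX_WORLD_SIZE"] = "2"

    from skrl import config

    print(config.jax.local_rank, config.jax.rank, config.jax.world_size)  # 1 1 2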

.. py:data:: skrl.config.jax.coordinator_address
:type: str
:value: "127.0.0.1:1234"

IP address and port where process 0 will start a JAX service
IP address and port where process 0 will start a JAX service.

This property reads from the ``JAX_COORDINATOR_ADDR:JAX_COORDINATOR_PORT`` environment variables (``127.0.0.1:1234`` if they don't exist)
This property reads from the ``JAX_COORDINATOR_ADDR:JAX_COORDINATOR_PORT`` environment variables (``127.0.0.1:1234`` if they don't exist).

Read-only attribute.
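
A sketch with placeholder values (the address and port below are illustrative, not defaults):

.. code-block:: python

    import os

    # process 0 will start the JAX coordination service at this address
    os.environ["JAX_COORDINATOR_ADDR"] = "192.168.0.10"  # illustrative IP
    os.environ["JAX_COORDINATOR_PORT"] = "5000"          # illustrative port

    from skrl import config

    print(config.jax.coordinator_address)  # 192.168.0.10:5000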

.. py:data:: skrl.config.jax.is_distributed
:type: bool
:value: False

Whether if running in a distributed environment
Whether running in a distributed environment.

This property is ``True`` when the JAX's distributed environment variable ``WORLD_SIZE > 1``
This property is ``True`` when JAX's distributed environment variable ``JAX_WORLD_SIZE`` is greater than ``1``.

Read-only attribute.
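
A hedged sketch mirroring the PyTorch example above (the verbosity policy is illustrative):

.. code-block:: python

    from skrl import config

    if config.jax.is_distributed:
        # e.g. let only the coordinating process (rank 0) print progress
        verbose = config.jax.rank == 0
    else:
        verbose = True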
