Use ML framework specific device parsing in source code (#234)
* Add method to parse device in torch

* Use ML framework configuration device parsing method to parse devices

* Add option to validate parsed torch device

* Add ML framework testing in jax

* Add torch parse_device method to ML framework docs

* Update docstrings and test content

* Update docs

* Disable device parsing validation when PyTorch config device
Toni-SM authored Dec 6, 2024
1 parent 57f60df commit 87250fa
Showing 18 changed files with 331 additions and 153 deletions.
72 changes: 47 additions & 25 deletions docs/source/api/config/frameworks.rst
@@ -19,48 +19,58 @@ PyTorch specific configuration
API
^^^

.. autofunction:: skrl.config.torch.parse_device

.. py:data:: skrl.config.torch.device
:type: torch.device
:value: "cuda:${LOCAL_RANK}" | "cpu"

Default device
Default device.

The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise
The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise.

.. py:data:: skrl.config.torch.local_rank
:type: int
:value: 0

The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node).

This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details.

Read-only attribute.

.. py:data:: skrl.config.torch.rank
:type: int
:value: 0

The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)
The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes).

This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details.

Read-only attribute.

.. py:data:: skrl.config.torch.world_size
:type: int
:value: 1

The total number of workers/process (e.g.: GPUs) in a worker group (e.g.: across all nodes)
The total number of workers/process (e.g.: GPUs) in a worker group (e.g.: across all nodes).

This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist).
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details
See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details.

Read-only attribute.

.. py:data:: skrl.config.torch.is_distributed
:type: bool
:value: False

Whether if running in a distributed environment
Whether if running in a distributed environment.

This property is ``True`` when the PyTorch's distributed environment variable ``WORLD_SIZE > 1``
This property is ``True`` when the PyTorch's distributed environment variable ``WORLD_SIZE > 1``.

Read-only attribute.

.. raw:: html

@@ -78,67 +88,79 @@ JAX specific configuration
API
^^^

.. autofunction:: skrl.config.jax.parse_device

.. py:data:: skrl.config.jax.device
:type: jax.Device
:value: "cuda:${LOCAL_RANK}" | "cpu"

Default device
Default device.

The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise

.. autofunction:: skrl.config.jax.parse_device
The default device, unless specified, is ``cuda:0`` (or ``cuda:JAX_LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise.

.. py:data:: skrl.config.jax.backend
:type: str
:value: "numpy"

Backend used by the different components to operate and generate arrays
Backend used by the different components to operate and generate arrays.

This configuration excludes models and optimizers.
Supported backend are: ``"numpy"`` and ``"jax"``
Supported backend are: ``"numpy"`` and ``"jax"``.

.. py:data:: skrl.config.jax.key
:type: jnp.ndarray
:type: jax.Array
:value: [0, 0]

Pseudo-random number generator (PRNG) key
Pseudo-random number generator (PRNG) key.

Key is formatted as 32-bit unsigned integer and the default device is used.

.. py:data:: skrl.config.jax.local_rank
:type: int
:value: 0

The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node).

This property reads from the ``JAX_LOCAL_RANK`` environment variable (``0`` if it doesn't exist).

Read-only attribute.

.. py:data:: skrl.config.jax.rank
:type: int
:value: 0

The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)
The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes).

This property reads from the ``JAX_RANK`` environment variable (``0`` if it doesn't exist).

Read-only attribute.

.. py:data:: skrl.config.jax.world_size
:type: int
:value: 1

The total number of workers/process (e.g.: GPUs) in a worker group (e.g.: across all nodes)
The total number of workers/process (e.g.: GPUs) in a worker group (e.g.: across all nodes).

This property reads from the ``JAX_WORLD_SIZE`` environment variable (``1`` if it doesn't exist).

Read-only attribute.

.. py:data:: skrl.config.jax.coordinator_address
:type: str
:value: "127.0.0.1:1234"

IP address and port where process 0 will start a JAX service
IP address and port where process 0 will start a JAX service.

This property reads from the ``JAX_COORDINATOR_ADDR:JAX_COORDINATOR_PORT`` environment variables (``127.0.0.1:1234`` if they don't exist).

This property reads from the ``JAX_COORDINATOR_ADDR:JAX_COORDINATOR_PORT`` environment variables (``127.0.0.1:1234`` if they don't exist)
Read-only attribute.

.. py:data:: skrl.config.jax.is_distributed
:type: bool
:value: False

Whether if running in a distributed environment
Whether if running in a distributed environment.

This property is ``True`` when the JAX's distributed environment variable ``WORLD_SIZE > 1``.

This property is ``True`` when the JAX's distributed environment variable ``WORLD_SIZE > 1``
Read-only attribute.
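
The read-only attributes documented in this diff all follow one pattern: read an environment variable and fall back to a default if it is not set. The following is a minimal sketch of that pattern, not skrl's actual implementation. The helper names `read_distributed_env` and `read_coordinator_address` are hypothetical, and the sketch assumes the JAX distributed flag is derived from `JAX_WORLD_SIZE`, the variable documented for `skrl.config.jax.world_size`.

```python
import os
from typing import Mapping, Optional


def read_distributed_env(prefix: str = "", env: Optional[Mapping[str, str]] = None) -> dict:
    """Hypothetical helper: read rank information from environment variables.

    Use prefix "" for the PyTorch variables (LOCAL_RANK, RANK, WORLD_SIZE)
    and prefix "JAX_" for the JAX ones (JAX_LOCAL_RANK, JAX_RANK, JAX_WORLD_SIZE).
    """
    env = os.environ if env is None else env
    # Defaults follow the documented values: 0, 0 and 1.
    local_rank = int(env.get(f"{prefix}LOCAL_RANK", "0"))
    rank = int(env.get(f"{prefix}RANK", "0"))
    world_size = int(env.get(f"{prefix}WORLD_SIZE", "1"))
    return {
        "local_rank": local_rank,
        "rank": rank,
        "world_size": world_size,
        # Assumption: "distributed" means more than one worker in the group.
        "is_distributed": world_size > 1,
    }


def read_coordinator_address(env: Optional[Mapping[str, str]] = None) -> str:
    """Hypothetical helper: build "address:port" from the two JAX variables."""
    env = os.environ if env is None else env
    address = env.get("JAX_COORDINATOR_ADDR", "127.0.0.1")
    port = env.get("JAX_COORDINATOR_PORT", "1234")
    return f"{address}:{port}"


if __name__ == "__main__":
    # Values as seen by the current process.
    print(read_distributed_env())
    print(read_distributed_env("JAX_"))
    print(read_coordinator_address())
```

Passing an explicit mapping instead of relying on `os.environ` keeps the helpers easy to exercise without modifying the process environment.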