diff --git a/docs/source/api/config/frameworks.rst b/docs/source/api/config/frameworks.rst
index f25e492f..a4c974b8 100644
--- a/docs/source/api/config/frameworks.rst
+++ b/docs/source/api/config/frameworks.rst
@@ -27,7 +27,7 @@ API
 
     The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise
 
-.. py:data:: skrl.config.local_rank
+.. py:data:: skrl.config.torch.local_rank
     :type: int
     :value: 0
 
@@ -36,7 +36,7 @@ API
     The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
 
     This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist). See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details
 
-.. py:data:: skrl.config.rank
+.. py:data:: skrl.config.torch.rank
     :type: int
     :value: 0
@@ -45,7 +45,7 @@ API
     The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)
 
     This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist). See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details
 
-.. py:data:: skrl.config.world_size
+.. py:data:: skrl.config.torch.world_size
     :type: int
     :value: 1
@@ -54,7 +54,7 @@ API
     The total number of workers/process (e.g.: GPUs) in a worker group (e.g.: across all nodes)
 
    This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist). See `torch.distributed <https://pytorch.org/docs/stable/distributed.html>`_ for more details
 
-.. py:data:: skrl.config.is_distributed
+.. py:data:: skrl.config.torch.is_distributed
     :type: bool
     :value: False
@@ -78,6 +78,14 @@ JAX specific configuration
 API
 ^^^
 
+.. py:data:: skrl.config.jax.device
+    :type: jax.Device
+    :value: "cuda:${LOCAL_RANK}" | "cpu"
+
+    Default device
+
+    The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment) if CUDA is available, ``cpu`` otherwise
+
 .. py:data:: skrl.config.jax.backend
     :type: str
     :value: "numpy"
@@ -92,3 +100,43 @@ API
     :value: [0, 0]
 
     Pseudo-random number generator (PRNG) key
+
+.. py:data:: skrl.config.jax.local_rank
+    :type: int
+    :value: 0
+
+    The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
+
+    This property reads from the ``LOCAL_RANK`` environment variable (``0`` if it doesn't exist).
+
+.. py:data:: skrl.config.jax.rank
+    :type: int
+    :value: 0
+
+    The rank of the worker/process (e.g.: GPU) within a worker group (e.g.: across all nodes)
+
+    This property reads from the ``RANK`` environment variable (``0`` if it doesn't exist).
+
+.. py:data:: skrl.config.jax.world_size
+    :type: int
+    :value: 1
+
+    The total number of workers/processes (e.g.: GPUs) in a worker group (e.g.: across all nodes)
+
+    This property reads from the ``WORLD_SIZE`` environment variable (``1`` if it doesn't exist).
+
+.. py:data:: skrl.config.jax.coordinator_address
+    :type: str
+    :value: "127.0.0.1:1234"
+
+    IP address and port where process 0 will start a JAX service
+
+    This property reads from the ``MASTER_ADDR`` and ``MASTER_PORT`` environment variables (``127.0.0.1:1234`` if they don't exist)
+
+.. py:data:: skrl.config.jax.is_distributed
+    :type: bool
+    :value: False
+
+    Whether running in a distributed environment
+
+    This property is ``True`` when the ``WORLD_SIZE`` environment variable is greater than ``1``
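
For reference, a minimal usage sketch (not part of the patch) of the attributes documented above. It assumes the top-level ``config`` object is importable as ``from skrl import config``, as the ``skrl.config.*`` paths in the docs suggest; everything else is illustrative only.

    # Minimal sketch, assuming the attributes documented in this patch are
    # exposed on the top-level config object. Not part of the patch.
    from skrl import config

    # PyTorch side: per-process placement in a (possibly) distributed run.
    if config.torch.is_distributed:
        print(f"worker {config.torch.rank}/{config.torch.world_size} "
              f"(local rank {config.torch.local_rank})")
    device = config.torch.device  # "cuda:LOCAL_RANK" if CUDA is available, else "cpu"

    # JAX side: the same rank/world-size properties, plus the coordinator
    # address that process 0 uses to start the JAX distributed service.
    if config.jax.is_distributed:
        print(f"coordinator at {config.jax.coordinator_address}")
    print(config.jax.backend)  # "numpy" by default
    print(config.jax.key)      # PRNG key, [0, 0] by default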