
Commit

Define default jax device
Toni-SM committed Jul 7, 2024
1 parent 1b056f1 commit 8f88ecf
Showing 1 changed file with 47 additions and 20 deletions.
skrl/__init__.py (67 changes: 47 additions & 20 deletions)
@@ -56,6 +56,25 @@ def __init__(self) -> None:
                 self._world_size = int(os.getenv("WORLD_SIZE", "1"))
                 self._is_distributed = self._world_size > 1
 
+            @property
+            def device(self) -> "torch.device":
+                """Default device
+
+                The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment)
+                if CUDA is available, ``cpu`` otherwise
+                """
+                try:
+                    import torch
+                    if self._device is None:
+                        return torch.device(f"cuda:{self._local_rank}" if torch.cuda.is_available() else "cpu")
+                    return torch.device(self._device)
+                except ImportError:
+                    return self._device
+
+            @device.setter
+            def device(self, device: Union[str, "torch.device"]) -> None:
+                self._device = device
+
             @property
             def local_rank(self) -> int:
                 """The rank of the worker/process (e.g.: GPU) within a local worker group (e.g.: node)
@@ -88,25 +107,6 @@ def is_distributed(self) -> bool:
                 """
                 return self._is_distributed
 
-            @property
-            def device(self) -> "torch.device":
-                """Default device
-
-                The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment)
-                if CUDA is available, ``cpu`` otherwise
-                """
-                try:
-                    import torch
-                    if self._device is None:
-                        return torch.device(f"cuda:{self._local_rank}" if torch.cuda.is_available() else "cpu")
-                    return torch.device(self._device)
-                except ImportError:
-                    return self._device
-
-            @device.setter
-            def device(self, device: Union[str, "torch.device"]) -> None:
-                self._device = device
-
         class JAX(object):
             def __init__(self) -> None:
                 """JAX configuration
@@ -121,6 +121,8 @@ def __init__(self) -> None:
                 self._world_size = int(os.getenv("WORLD_SIZE", "1"))
                 self._coordinator_address = os.getenv("MASTER_ADDR", "127.0.0.1") + ":" + os.getenv("MASTER_PORT", "1234")
                 self._is_distributed = self._world_size > 1
+                # device
+                self._device = f"cuda:{self._local_rank}"
 
                 # TODO: find a better place for it
                 # set up distributed runs
@@ -132,6 +134,31 @@ def __init__(self) -> None:
                                                process_id=self._rank,
                                                local_device_ids=self._local_rank)
 
+            @property
+            def device(self) -> "jax.Device":
+                """Default device
+
+                The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment)
+                if CUDA is available, ``cpu`` otherwise
+                """
+                try:
+                    import jax
+                    if type(self._device) == str:
+                        device_type, device_index = f"{self._device}:0".split(':')[:2]
+                        try:
+                            self._device = jax.devices(device_type)[int(device_index)]
+                        except (RuntimeError, IndexError):
+                            self._device = None
+                    if self._device is None:
+                        self._device = jax.devices()[0]
+                except ImportError:
+                    pass
+                return self._device
+
+            @device.setter
+            def device(self, device: Union[str, "jax.Device"]) -> None:
+                self._device = device
+
             @property
             def backend(self) -> str:
                 """Backend used by the different components to operate and generate arrays
@@ -154,7 +181,7 @@ def key(self) -> "jax.Array":
                 if isinstance(self._key, np.ndarray):
                     try:
                         import jax
-                        with jax.default_device(jax.devices("cpu")[0]):
+                        with jax.default_device(self.device):
                             self._key = jax.random.PRNGKey(self._key[1])
                     except ImportError:
                         pass
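
For reference, a minimal standalone sketch of the string-to-device resolution the new JAX ``device`` property performs, assuming only that JAX is installed; ``resolve_device`` is a hypothetical helper written for illustration, not part of skrl:

import jax

def resolve_device(device):
    """Resolve a 'platform:index' string (e.g. 'cuda:0') to a jax.Device."""
    if isinstance(device, str):
        device_type, device_index = f"{device}:0".split(":")[:2]
        try:
            return jax.devices(device_type)[int(device_index)]
        except (RuntimeError, IndexError):
            device = None  # requested platform or index unavailable: fall back
    if device is None:
        return jax.devices()[0]  # JAX's default device (CPU if no accelerator)
    return device

print(resolve_device("cuda:0"))  # CudaDevice(id=0), or CpuDevice(id=0) on CPU-only hosts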
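The relocated PyTorch property keeps its previous semantics: an explicit setting wins, otherwise ``cuda:LOCAL_RANK`` when CUDA is available, else ``cpu``. A minimal sketch of that behavior, assuming only that PyTorch is installed (``default_torch_device`` is a hypothetical stand-in, not skrl API):

import torch

def default_torch_device(explicit=None, local_rank=0):
    """Mirror the fallback order of the Torch ``device`` property above."""
    if explicit is None:
        return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    return torch.device(explicit)

print(default_torch_device())          # cuda:0 on a CUDA machine, cpu otherwise
print(default_torch_device("cuda:1"))  # an explicit device is honored as-is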

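Assuming these inner classes are exposed through skrl's module-level ``config`` object (as in the released package), usage might look like the sketch below; the exact attribute paths are an assumption, not verified against this revision:

from skrl import config

# assumed attribute paths: config.torch / config.jax
config.jax.device = "cuda:0"   # stored as a string by the setter...
device = config.jax.device     # ...and resolved to a jax.Device on first read

config.torch.device = "cpu"    # the Torch setter accepts a string or torch.device

key = config.jax.key           # the PRNG key is now created on the default device
print(device, key)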