Skip to content

Commit

Permalink
Add method to parse device in torch
Browse files Browse the repository at this point in the history
  • Loading branch information
Toni-SM committed Dec 5, 2024
1 parent 57f60df commit a08bf91
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions skrl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,13 @@ def __init__(self) -> None:
class PyTorch(object):
def __init__(self) -> None:
"""PyTorch configuration"""
self._device = None
# torch.distributed config
self._local_rank = int(os.getenv("LOCAL_RANK", "0"))
self._rank = int(os.getenv("RANK", "0"))
self._world_size = int(os.getenv("WORLD_SIZE", "1"))
self._is_distributed = self._world_size > 1
# device
self._device = f"cuda:{self._local_rank}"

# set up distributed runs
if self._is_distributed:
Expand All @@ -68,21 +69,37 @@ def __init__(self) -> None:
torch.distributed.init_process_group("nccl", rank=self._rank, world_size=self._world_size)
torch.cuda.set_device(self._local_rank)

@staticmethod
def parse_device(device: Union[str, "torch.device", None]) -> "torch.device":
"""Parse the input device and return a :py:class:`~torch.device` instance.
:param device: Device specification. If the specified device is ``None`` or it cannot be resolved,
the default available device will be returned instead.
:return: PyTorch device.
"""
import torch

if isinstance(device, torch.device):
return device
elif isinstance(device, str):
try:
return torch.device(device)
except RuntimeError as e:
logger.warning(f"Invalid device specification ({device}): {e}")
return torch.device(
"cuda:0" if torch.cuda.is_available() else "cpu"
) # torch.get_default_device() was introduced in version 2.3.0

@property
def device(self) -> "torch.device":
"""Default device
The default device, unless specified, is ``cuda:0`` (or ``cuda:LOCAL_RANK`` in a distributed environment)
if CUDA is available, ``cpu`` otherwise
"""
try:
import torch

if self._device is None:
return torch.device(f"cuda:{self._local_rank}" if torch.cuda.is_available() else "cpu")
return torch.device(self._device)
except ImportError:
return self._device
self._device = self.parse_device(self._device)
return self._device

@device.setter
def device(self, device: Union[str, "torch.device"]) -> None:
Expand Down

0 comments on commit a08bf91

Please sign in to comment.