Skip to content

Commit 4ae3f94

Browse files
committed
Add option to validate parsed torch device
1 parent cf05a50 commit 4ae3f94

File tree

2 files changed

+41
-6
lines changed

2 files changed

+41
-6
lines changed

skrl/__init__.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,26 +70,38 @@ def __init__(self) -> None:
7070
torch.cuda.set_device(self._local_rank)
7171

7272
@staticmethod
def parse_device(device: Union[str, "torch.device", None], validate: bool = True) -> "torch.device":
    """Parse the input device and return a :py:class:`~torch.device` instance.

    :param device: Device specification. If the specified device is ``None`` or it cannot be resolved,
                   the default available device will be returned instead.
    :param validate: Whether to check that the specified device is valid. Since PyTorch does not check if
                     the specified device index is valid, a tensor is created for the verification.
                     If the verification fails, the default available device is tried next and,
                     if that one cannot be used either, the CPU device is returned.

    :return: PyTorch device.
    """
    import torch

    def _default_device() -> "torch.device":
        # torch.get_default_device() was introduced in version 2.3.0
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def _is_usable(candidate: "torch.device", specification) -> bool:
        # allocating a tiny tensor forces PyTorch to actually touch the device
        try:
            torch.zeros((1,), device=candidate)
        except Exception as e:
            logger.warning(f"Invalid device specification ({specification}): {e}")
            return False
        return True

    _device = None
    if isinstance(device, torch.device):
        _device = device
    elif isinstance(device, str):
        try:
            _device = torch.device(device)
        except RuntimeError as e:
            logger.warning(f"Invalid device specification ({device}): {e}")
    # ``None``, unsupported types and unparsable strings all resolve to the default device
    if _device is None:
        _device = _default_device()
    # validate device
    if validate and not _is_usable(_device, device):
        rejected, _device = _device, _default_device()
        # Re-entering parse_device(None) here would never terminate when the default device
        # is itself the one that fails (e.g. CUDA is reported as available but unusable),
        # so the fallback chain is bounded: default device first, then the CPU.
        if _device == rejected or not _is_usable(_device, None):
            _device = torch.device("cpu")
    return _device
93105

94106
@property
95107
def device(self) -> "torch.device":

tests/torch/test_torch_config.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from typing import Union
2+
3+
import pytest
4+
5+
import torch
6+
7+
from skrl import config
8+
9+
10+
@pytest.mark.parametrize("device", [None, "cpu", "cuda", "cuda:0", "cuda:10", "edge-case"])
@pytest.mark.parametrize("validate", [True, False])
def test_parse_device(capsys, device: Union[str, None], validate: bool):
    """Check that ``parse_device`` resolves each specification to the expected device.

    ``None`` and unparsable strings are expected to resolve to the default device.
    A CUDA specification is expected to resolve to the default device only when
    validation is requested and its index is not below ``torch.cuda.device_count()``;
    every other specification is expected to be returned as given.
    """

    def default_device() -> torch.device:
        return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    expected = None
    if device is None or device == "edge-case":
        expected = default_device()
    elif device.startswith("cuda") and validate:
        # a bare "cuda" carries no index, so "0" is appended as the implicit one
        index = int((device.split(":") + ["0"])[1])
        if index >= torch.cuda.device_count():
            expected = default_device()
    if expected is None:
        expected = torch.device(device)

    parsed = config.torch.parse_device(device, validate=validate)
    assert parsed == expected

0 commit comments

Comments
 (0)