@@ -70,26 +70,38 @@ def __init__(self) -> None:
70
70
torch .cuda .set_device (self ._local_rank )
71
71
72
72
@staticmethod
73
- def parse_device (device : Union [str , "torch.device" , None ]) -> "torch.device" :
73
+ def parse_device (device : Union [str , "torch.device" , None ], validate : bool = True ) -> "torch.device" :
74
74
"""Parse the input device and return a :py:class:`~torch.device` instance.
75
75
76
76
:param device: Device specification. If the specified device is ``None`` or it cannot be resolved,
77
77
the default available device will be returned instead.
78
+ :param validate: Whether to check that the specified device is valid. Since PyTorch does not check if
79
+ the specified device index is valid, a tensor is created for the verification.
78
80
79
81
:return: PyTorch device.
80
82
"""
81
83
import torch
82
84
85
+ _device = None
83
86
if isinstance (device , torch .device ):
84
- return device
87
+ _device = device
85
88
elif isinstance (device , str ):
86
89
try :
87
- return torch .device (device )
90
+ _device = torch .device (device )
88
91
except RuntimeError as e :
89
92
logger .warning (f"Invalid device specification ({ device } ): { e } " )
90
- return torch .device (
91
- "cuda:0" if torch .cuda .is_available () else "cpu"
92
- ) # torch.get_default_device() was introduced in version 2.3.0
93
+ if _device is None :
94
+ _device = torch .device (
95
+ "cuda:0" if torch .cuda .is_available () else "cpu"
96
+ ) # torch.get_default_device() was introduced in version 2.3.0
97
+ # validate device
98
+ if validate :
99
+ try :
100
+ torch .zeros ((1 ,), device = _device )
101
+ except Exception as e :
102
+ logger .warning (f"Invalid device specification ({ device } ): { e } " )
103
+ _device = PyTorch .parse_device (None )
104
+ return _device
93
105
94
106
@property
95
107
def device (self ) -> "torch.device" :
0 commit comments