-
Notifications
You must be signed in to change notification settings - Fork 195
Expand file tree
/
Copy pathdatatypes.py
More file actions
170 lines (132 loc) · 6.29 KB
/
Copy pathdatatypes.py
File metadata and controls
170 lines (132 loc) · 6.29 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import sys
import importlib
def is_windows():
    """Report whether the interpreter is running on a Windows platform.

    Returns:
        True when ``sys.platform`` begins with ``"win"``, otherwise False.
    """
    platform_name = sys.platform
    return platform_name.startswith("win")
# Relative name of the compiled extension module inside the "cudnn" package.
# On Windows the name carries an extra ".Release" component
# (presumably the build output layout - TODO confirm).
module_name = ".Release._compiled_module" if is_windows() else "._compiled_module"
# Import the compiled module relative to the "cudnn" package.
_pybind_module = importlib.import_module(module_name, package="cudnn")
# Bind the compiled module's `data_type` attribute to the module-level name
# `cudnn_data_type`, which the functions in this file refer to.
globals()["cudnn_data_type"] = getattr(_pybind_module, "data_type")
# Tri-state flag: None until is_torch_available() has run, then True or False.
torch_available = None
# torch dtype -> cudnn_data_type lookup table; filled in by is_torch_available().
_torch_to_cudnn_data_type_dict = None
# Optional CUTLASS integration
# Tri-state flag: None until is_cutlass_available() has run, then True or False.
cutlass_available = None
# torch dtype -> CUTLASS type lookup table; filled in by is_cutlass_available().
_torch_to_cutlass_data_type_dict = None
def is_torch_available():
    """Check whether PyTorch can be imported, building the dtype table once.

    The first call attempts ``import torch``. On success it fills the
    module-level ``_torch_to_cudnn_data_type_dict`` with the torch dtype to
    ``cudnn_data_type`` mapping; on ImportError it stores an empty mapping.
    The outcome is cached in ``torch_available`` so later calls only return
    the stored flag.

    Returns:
        True if torch was imported successfully, otherwise False.
    """
    global torch_available, _torch_to_cudnn_data_type_dict

    # A non-None flag means the probe already ran; the mapping is built once.
    if torch_available is not None:
        return torch_available

    try:
        import torch

        torch_available = True
        _torch_to_cudnn_data_type_dict = {
            torch.half: cudnn_data_type.HALF,
            torch.float16: cudnn_data_type.HALF,
            torch.bfloat16: cudnn_data_type.BFLOAT16,
            torch.float: cudnn_data_type.FLOAT,
            torch.float32: cudnn_data_type.FLOAT,
            torch.double: cudnn_data_type.DOUBLE,
            torch.float64: cudnn_data_type.DOUBLE,
            torch.int8: cudnn_data_type.INT8,
            torch.int32: cudnn_data_type.INT32,
            torch.int64: cudnn_data_type.INT64,
            torch.uint8: cudnn_data_type.UINT8,
            torch.bool: cudnn_data_type.BOOLEAN,
        }

        # (torch attribute name, cudnn_data_type member name) pairs for dtypes
        # that only exist in some torch versions.
        optional_pairs = (
            ("float8_e4m3fn", "FP8_E4M3"),
            ("float8_e5m2", "FP8_E5M2"),
            ("float8_e8m0fnu", "FP8_E8M0"),
            ("float4_e2m1fn_x2", "FP4_E2M1"),
        )
        for torch_name, cudnn_name in optional_pairs:
            # The cuDNN member is resolved unconditionally, before the torch check.
            cudnn_type = getattr(cudnn_data_type, cudnn_name)
            # Only add the type if the version of torch being used supports it.
            if hasattr(torch, torch_name):
                _torch_to_cudnn_data_type_dict[getattr(torch, torch_name)] = cudnn_type
    except ImportError:
        torch_available = False
        _torch_to_cudnn_data_type_dict = {}

    return torch_available
def is_cutlass_available():
    """Check whether both torch and cutlass can be imported, caching the result.

    The first call attempts to import ``torch`` and then ``cutlass``. On
    success it fills the module-level ``_torch_to_cutlass_data_type_dict``
    with every torch dtype / CUTLASS type pair for which both sides exist;
    on ImportError it stores an empty mapping. The outcome is cached in
    ``cutlass_available``.

    Returns:
        True if both imports succeeded, otherwise False.
    """
    global cutlass_available, _torch_to_cutlass_data_type_dict

    # A non-None flag means the probe already ran.
    if cutlass_available is not None:
        return cutlass_available

    try:
        import torch
        import cutlass

        cutlass_available = True

        # (torch dtype or None, CUTLASS attribute name), in lookup-table order.
        candidates = (
            (torch.half, "Float16"),
            (getattr(torch, "float16", torch.half), "Float16"),
            (getattr(torch, "bfloat16", None), "BFloat16"),
            (torch.float, "Float32"),
            (getattr(torch, "float32", torch.float), "Float32"),
            (torch.double, "Float64"),
            (getattr(torch, "float64", torch.double), "Float64"),
            (getattr(torch, "int8", None), "Int8"),
            (getattr(torch, "int32", None), "Int32"),
            (getattr(torch, "int64", None), "Int64"),
            (getattr(torch, "uint8", None), "Uint8"),
            (getattr(torch, "bool", None), "Boolean"),
            (getattr(torch, "float8_e4m3fn", None), "Float8E4M3FN"),
            (getattr(torch, "float8_e5m2", None), "Float8E5M2"),
            (getattr(torch, "float8_e8m0fnu", None), "Float8E8M0FNU"),
            (getattr(torch, "float4_e2m1fn_x2", None), "Float4E2M1FN"),
        )

        table = {}
        for torch_type, cutlass_name in candidates:
            cutlass_type = getattr(cutlass, cutlass_name, None)
            # Drop pairs where either library lacks the type.
            if torch_type is None or cutlass_type is None:
                continue
            table[torch_type] = cutlass_type
        _torch_to_cutlass_data_type_dict = table
    except ImportError:
        cutlass_available = False
        _torch_to_cutlass_data_type_dict = {}

    return cutlass_available
# Returns None in case mapping is not available
def _torch_to_cudnn_data_type(torch_data_type) -> cudnn_data_type:
    """Look up the cuDNN data type registered for a PyTorch dtype.

    Args:
        torch_data_type: The PyTorch dtype to translate.

    Returns:
        The matching ``cudnn_data_type`` member, or None when torch is not
        available or the dtype has no entry in the lookup table.
    """
    if not is_torch_available():
        return None
    return _torch_to_cudnn_data_type_dict.get(torch_data_type)
def _torch_to_cutlass_data_type(data_type, interpret_uint8_as_fp4x2: bool = False):
    """Map a PyTorch dtype to the corresponding CUTLASS type.

    Args:
        data_type: The PyTorch dtype to look up.
        interpret_uint8_as_fp4x2: When True, ``torch.uint8`` resolves to
            ``cutlass.Float4E2M1FN`` (None if cutlass lacks that attribute)
            instead of going through the lookup table.

    Returns:
        The CUTLASS type, or None when CUTLASS or torch is unavailable or the
        dtype has no entry in the lookup table.
    """
    # The CUTLASS probe runs first; the torch probe runs only if it succeeded.
    if not (is_cutlass_available() and is_torch_available()):
        return None

    import torch

    treat_as_fp4 = interpret_uint8_as_fp4x2 and data_type == torch.uint8
    if not treat_as_fp4:
        return _torch_to_cutlass_data_type_dict.get(data_type)

    import cutlass

    return getattr(cutlass, "Float4E2M1FN", None)
def _convert_to_cutlass_data_type(data_type, interpret_uint8_as_fp4x2: bool = False):
    """Normalize a data type to a CUTLASS numeric type.

    Args:
        data_type: Either a subclass of ``cutlass.Numeric`` (returned as is)
            or a value that ``_torch_to_cutlass_data_type`` can translate.
        interpret_uint8_as_fp4x2: Forwarded to ``_torch_to_cutlass_data_type``.

    Returns:
        The CUTLASS type, or None when CUTLASS is not available.

    Raises:
        ValueError: If ``data_type`` is None or has no CUTLASS equivalent.
    """
    if not is_cutlass_available():
        return None

    import cutlass

    # Already a CUTLASS numeric class: nothing to convert.
    if isinstance(data_type, type) and issubclass(data_type, cutlass.Numeric):
        return data_type

    if data_type is None:
        raise ValueError("None is not a valid tensor data type.")

    converted = _torch_to_cutlass_data_type(data_type, interpret_uint8_as_fp4x2=interpret_uint8_as_fp4x2)
    if converted is None:
        raise ValueError("Unsupported tensor data type.")
    return converted
def _cudnn_to_torch_data_type(cudnn_data_type):
    """Translate a cuDNN data type back into a PyTorch dtype.

    The torch-to-cuDNN lookup table is scanned in insertion order and the
    first torch dtype whose mapped value equals ``cudnn_data_type`` wins.

    Args:
        cudnn_data_type: The cuDNN data type to convert.

    Returns:
        The PyTorch data type, or None if torch is unavailable or no entry
        maps to the given cuDNN type.
    """
    if not is_torch_available():
        return None

    matches = (
        torch_type
        for torch_type, mapped_type in _torch_to_cudnn_data_type_dict.items()
        if mapped_type == cudnn_data_type
    )
    return next(matches, None)
def _library_type(input_type):
    """Resolve a value to a ``cudnn_data_type``.

    Args:
        input_type: Either a ``cudnn_data_type`` instance (returned as is) or
            a framework dtype that one of the registered converters handles.

    Returns:
        The corresponding ``cudnn_data_type`` value.

    Raises:
        Exception: If no converter produces a result for ``input_type``.
    """
    if type(input_type) is cudnn_data_type:
        return input_type

    converters = (
        _torch_to_cudnn_data_type,
        # Add more DL libraries to support here
    )
    # Lazy pipeline: converters run in order and stop at the first non-None result.
    attempts = (convert(input_type) for convert in converters)
    resolved = next((result for result in attempts if result is not None), None)
    if resolved is None:
        raise Exception(f"No available conversion from type {input_type} to a library type.")
    return resolved
def _is_torch_tensor(input_tensor) -> bool:
    """Tell whether an object is a ``torch.Tensor``.

    Args:
        input_tensor: Any object.

    Returns:
        True if torch is available and ``input_tensor`` is a ``torch.Tensor``
        instance; False otherwise, including when torch cannot be imported.
    """
    if not is_torch_available():
        return False

    import torch

    return isinstance(input_tensor, torch.Tensor)