```diff
@@ -28,7 +28,7 @@
 ]
 
 aten = torch.ops.aten
-quantize_dtypes = [torch.uint8, torch.float8_e4m3fn]
+quantize_dtypes = [torch.uint8]
 
 
 def _is_valid_qsdpa_pattern():
@@ -121,53 +121,31 @@ def qsdpa(match: Match, *args, **kwargs):
 def _generate_dequant_pattern(
     input_pattern, qtype, is_reduced_type, scale: str, zp: str = None
 ):
-    if qtype == torch.uint8:
-        assert zp is not None, "Zero point must be provided for uint8 dequantization"
-        return CallFunction(
-            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
-            input_pattern,
-            KeywordArg(scale),
-            KeywordArg(zp),
-            Arg(),
-            Arg(),
-            Arg(),
-        )
-    else:
-        assert zp is None, "Fp8 dequantization does not support zero point"
-        if is_reduced_type:
-            return CallFunction(
-                torch.ops.torchao.dequantize_affine_float8.default,
-                input_pattern,
-                KeywordArg(scale),
-                Arg(),
-            )
-        else:
-            return CallFunction(
-                torch.ops.torchao.dequantize_affine_float8.default,
-                input_pattern,
-                KeywordArg(scale),
-            )
+    assert qtype is torch.uint8, "QSDPA expects type to be uint8"
+    assert zp is not None, "Zero point must be provided for uint8 dequantization"
+    return CallFunction(
+        torch.ops.quantized_decomposed.dequantize_per_tensor.default,
+        input_pattern,
+        KeywordArg(scale),
+        KeywordArg(zp),
+        Arg(),
+        Arg(),
+        Arg(),
+    )
 
 
 def _generate_quant_pattern(input_pattern, qtype, scale: str, zp: str = None):
-    if qtype == torch.uint8:
-        assert zp is not None, "Zero point must be provided for uint8 quantization"
-        return CallFunction(
-            torch.ops.quantized_decomposed.quantize_per_tensor.default,
-            input_pattern,
-            KeywordArg(scale),
-            KeywordArg(zp),
-            Arg(),
-            Arg(),
-            Arg(),
-        )
-    else:
-        assert zp is None, "Fp8 quantization does not support zero point"
-        return CallFunction(
-            torch.ops.torchao.quantize_affine_float8.default,
-            input_pattern,
-            KeywordArg(scale),
-        )
+    assert qtype is torch.uint8, "QSDPA expects type to be uint8"
+    assert zp is not None, "Zero point must be provided for uint8 quantization"
+    return CallFunction(
+        torch.ops.quantized_decomposed.quantize_per_tensor.default,
+        input_pattern,
+        KeywordArg(scale),
+        KeywordArg(zp),
+        Arg(),
+        Arg(),
+        Arg(),
+    )
 
 
 def _get_qsdpa_qkv_pattern(
```
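With float8 dropped, both helpers now emit patterns only for the decomposed uint8 quantization ops. The `KeywordArg(scale)`/`KeywordArg(zp)` entries bind the scale and zero point by name, and the three trailing `Arg()` wildcards should correspond to the remaining positional parameters of those ops (`quant_min`, `quant_max`, and the target `dtype`). Below is a minimal round-trip sketch of the calls these patterns are meant to match; it assumes a PyTorch build where the `quantized_decomposed` ops are available, and importing the private registration module is one way to make sure they are.

```python
# Minimal sketch of the uint8 quantize/dequantize calls the patterns above match.
# Assumption: a PyTorch build where the quantized_decomposed ops are registered;
# importing the private module below triggers that registration.
import torch
import torch.ao.quantization.fx._decomposed  # noqa: F401

x = torch.randn(2, 8)
scale, zero_point = 0.05, 128
quant_min, quant_max = 0, 255

# quantize_per_tensor(input, scale, zero_point, quant_min, quant_max, dtype):
# scale and zero_point are the values captured by KeywordArg in the pattern;
# the last three positional arguments fill the wildcard Arg() slots.
xq = torch.ops.quantized_decomposed.quantize_per_tensor.default(
    x, scale, zero_point, quant_min, quant_max, torch.uint8
)

# dequantize_per_tensor mirrors the same argument layout and is the call that
# the dequant pattern targets.
xdq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
    xq, scale, zero_point, quant_min, quant_max, torch.uint8
)
print(xq.dtype, xdq.dtype)  # torch.uint8, torch.float32
```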
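For context on how these helpers plug into the inductor pattern matcher, here is one plausible composition sketch. It is hypothetical: the keyword names `"query"`, `"q_scale"`, and `"q_zp"` are illustrative placeholders rather than identifiers from this file, and it assumes `_generate_dequant_pattern` is in scope alongside the `torch._inductor.pattern_matcher` primitives used above.

```python
# Hypothetical composition sketch: build the dequantized-query subpattern that a
# larger QSDPA pattern could wrap around. The names "query", "q_scale", and
# "q_zp" are illustrative placeholders, not identifiers from this file.
import torch
from torch._inductor.pattern_matcher import KeywordArg

dequant_query = _generate_dequant_pattern(
    KeywordArg("query"),    # incoming quantized query tensor, captured by name
    torch.uint8,            # only uint8 remains supported after this change
    is_reduced_type=False,
    scale="q_scale",        # name bound to the dequant op's scale argument
    zp="q_zp",              # name bound to the dequant op's zero point argument
)
```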