Skip to content

Commit 22819f4

Browse files
authored
[Bug fix][CPU] Fix fp8 sdpa compiling issue with latest PyTorch (#2991)
* [CPU] Fix fp8 sdpa compiling issue with latest PyTorch
* Disable fp8 fusion
1 parent d2fae7a commit 22819f4

File tree

2 files changed

+27
-45
lines changed

2 files changed

+27
-45
lines changed

torchao/csrc/cpu/aten_kernels/quantized_sdpa.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1775,6 +1775,7 @@ int8_sdpa_fused_kernel_impl(
17751775
at::native::cpublas::brgemm_release();
17761776
}
17771777

1778+
#if defined(CPUBLAS_BRGEMM_F8F8F32)
17781779
// FP8 - kernel with f8f8f8 GEMM
17791780
template <typename scalar_t, typename mask_t,
17801781
int64_t q_split_size, int64_t kv_split_size>
@@ -2136,6 +2137,7 @@ fp8_sdpa_fused_kernel_impl(
21362137
at::native::cpublas::brgemm_release();
21372138
});
21382139
}
2140+
#endif // CPUBLAS_BRGEMM_F8F8F32
21392141

21402142
template <typename scalar_t, typename mask_t, int64_t q_split_size, int64_t kv_split_size>
21412143
inline typename std::enable_if_t<std::is_same_v<scalar_t, unsigned char>, void>
@@ -2304,6 +2306,7 @@ void int8_sdpa_fused_kernel(
23042306
}
23052307
}
23062308

2309+
#if defined(CPUBLAS_BRGEMM_F8F8F32)
23072310
void fp8_sdpa_fused_kernel(
23082311
const at::Tensor& output,
23092312
const at::Tensor& query,
@@ -2380,6 +2383,7 @@ void fp8_sdpa_fused_kernel(
23802383
});
23812384
}
23822385
}
2386+
#endif // CPUBLAS_BRGEMM_F8F8F32
23832387
#endif // CPU_CAPABILITY_AVX512
23842388

23852389
at::Tensor int8_sdpa_math_kernel(

torchao/prototype/inductor/fx_passes/qsdpa_fusion.py

Lines changed: 23 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
]
2929

3030
aten = torch.ops.aten
31-
quantize_dtypes = [torch.uint8, torch.float8_e4m3fn]
31+
quantize_dtypes = [torch.uint8]
3232

3333

3434
def _is_valid_qsdpa_pattern():
@@ -121,53 +121,31 @@ def qsdpa(match: Match, *args, **kwargs):
121121
def _generate_dequant_pattern(
122122
input_pattern, qtype, is_reduced_type, scale: str, zp: str = None
123123
):
124-
if qtype == torch.uint8:
125-
assert zp is not None, "Zero point must be provided for uint8 dequantization"
126-
return CallFunction(
127-
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
128-
input_pattern,
129-
KeywordArg(scale),
130-
KeywordArg(zp),
131-
Arg(),
132-
Arg(),
133-
Arg(),
134-
)
135-
else:
136-
assert zp is None, "Fp8 dequantization does not support zero point"
137-
if is_reduced_type:
138-
return CallFunction(
139-
torch.ops.torchao.dequantize_affine_float8.default,
140-
input_pattern,
141-
KeywordArg(scale),
142-
Arg(),
143-
)
144-
else:
145-
return CallFunction(
146-
torch.ops.torchao.dequantize_affine_float8.default,
147-
input_pattern,
148-
KeywordArg(scale),
149-
)
124+
assert qtype is torch.uint8, "QSDPA expects type to be uint8"
125+
assert zp is not None, "Zero point must be provided for uint8 dequantization"
126+
return CallFunction(
127+
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
128+
input_pattern,
129+
KeywordArg(scale),
130+
KeywordArg(zp),
131+
Arg(),
132+
Arg(),
133+
Arg(),
134+
)
150135

151136

152137
def _generate_quant_pattern(input_pattern, qtype, scale: str, zp: str = None):
153-
if qtype == torch.uint8:
154-
assert zp is not None, "Zero point must be provided for uint8 quantization"
155-
return CallFunction(
156-
torch.ops.quantized_decomposed.quantize_per_tensor.default,
157-
input_pattern,
158-
KeywordArg(scale),
159-
KeywordArg(zp),
160-
Arg(),
161-
Arg(),
162-
Arg(),
163-
)
164-
else:
165-
assert zp is None, "Fp8 quantization does not support zero point"
166-
return CallFunction(
167-
torch.ops.torchao.quantize_affine_float8.default,
168-
input_pattern,
169-
KeywordArg(scale),
170-
)
138+
assert qtype is torch.uint8, "QSDPA expects type to be uint8"
139+
assert zp is not None, "Zero point must be provided for uint8 quantization"
140+
return CallFunction(
141+
torch.ops.quantized_decomposed.quantize_per_tensor.default,
142+
input_pattern,
143+
KeywordArg(scale),
144+
KeywordArg(zp),
145+
Arg(),
146+
Arg(),
147+
Arg(),
148+
)
171149

172150

173151
def _get_qsdpa_qkv_pattern(

0 commit comments

Comments (0)