From 71ee24d947984600c3a27109a02adc68631ca540 Mon Sep 17 00:00:00 2001 From: Ea760 <15236119052@163.com> Date: Tue, 16 Sep 2025 20:12:46 +0800 Subject: [PATCH] add cpp wrapper for flash_attn_varlen_func --- ctests/CMakeLists.txt | 5 + ctests/test_triton_flash_attn_varlen.cpp | 388 ++++++++++++++++ include/flag_gems/operators.h | 100 +++++ include/flag_gems/utils.h | 9 + lib/CMakeLists.txt | 3 +- lib/flash_attn_varlen_func.cpp | 544 +++++++++++++++++++++++ src/flag_gems/csrc/cstub.cpp | 13 + 7 files changed, 1061 insertions(+), 1 deletion(-) create mode 100644 ctests/test_triton_flash_attn_varlen.cpp create mode 100644 lib/flash_attn_varlen_func.cpp diff --git a/ctests/CMakeLists.txt b/ctests/CMakeLists.txt index 29522f620..f0e3fdaa3 100644 --- a/ctests/CMakeLists.txt +++ b/ctests/CMakeLists.txt @@ -76,3 +76,8 @@ add_executable(test_triton_reshape_and_cache_flash test_triton_reshape_and_cache target_link_libraries(test_triton_reshape_and_cache_flash PRIVATE Torch::Torch operators GTest::gtest GTest::gtest_main) add_test(NAME test_triton_reshape_and_cache_flash COMMAND test_triton_reshape_and_cache_flash) + +add_executable(test_triton_flash_attn_varlen test_triton_flash_attn_varlen.cpp) +target_link_libraries(test_triton_flash_attn_varlen + PRIVATE Torch::Torch operators GTest::gtest GTest::gtest_main) +add_test(NAME test_triton_flash_attn_varlen COMMAND test_triton_flash_attn_varlen) diff --git a/ctests/test_triton_flash_attn_varlen.cpp b/ctests/test_triton_flash_attn_varlen.cpp new file mode 100644 index 000000000..1e283a373 --- /dev/null +++ b/ctests/test_triton_flash_attn_varlen.cpp @@ -0,0 +1,388 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "flag_gems/operators.h" +#include "torch/torch.h" + +at::Tensor ref_paged_attn_cpp(const at::Tensor& query, + const at::Tensor& key_cache, + const at::Tensor& value_cache, + const std::vector& query_lens, + const std::vector& kv_lens, + const at::Tensor& block_tables, + double scale, + std::optional attn_bias_opt, + std::optional sliding_window_opt, + std::optional soft_cap_opt) { + // Basic checks (device / dims) + TORCH_CHECK(query.device() == key_cache.device() && query.device() == value_cache.device(), + "All tensors must be on the same device"); + TORCH_CHECK(block_tables.dim() == 2, "block_tables must be 2-D"); + + const auto device = query.device(); + const auto dtype = query.dtype(); // query dtype (float16/bfloat16) + const auto v_dtype = value_cache.dtype(); + + int64_t num_seqs = static_cast(query_lens.size()); + int64_t start_idx = 0; + + // Extract shapes + int64_t num_blocks = key_cache.size(0); + int64_t block_size = key_cache.size(1); + int64_t num_kv_heads = key_cache.size(2); + int64_t head_size = key_cache.size(3); + + std::vector outputs; + outputs.reserve(num_seqs); + + for (int64_t i = 0; i < num_seqs; ++i) { + int64_t query_len = query_lens[i]; + int64_t kv_len = kv_lens[i]; + + at::Tensor q = query.narrow(0, start_idx, query_len).clone(); + // q *= scale + q.mul_(scale); + + // number of kv blocks and indices + int64_t num_kv_blocks = (kv_len + block_size - 1) / block_size; + + at::Tensor block_idx_row = block_tables.index({i}); // shape (max_num_blocks_per_seq,) + at::Tensor block_indices = block_idx_row.narrow(0, 0, num_kv_blocks).to(torch::kLong); + + at::Tensor k_sel = key_cache.index_select(0, block_indices).contiguous(); + k_sel = k_sel.view({-1, num_kv_heads, head_size}); + if (k_sel.size(0) > kv_len) { + k_sel = k_sel.narrow(0, 0, kv_len); + 
} + + at::Tensor v_sel = value_cache.index_select(0, block_indices).contiguous(); + v_sel = v_sel.view({-1, num_kv_heads, head_size}); + if (v_sel.size(0) > kv_len) { + v_sel = v_sel.narrow(0, 0, kv_len); + } + + int64_t q_heads = q.size(1); + int64_t k_heads = k_sel.size(1); + if (q_heads != k_heads) { + TORCH_CHECK(q_heads % k_heads == 0, "Number of query heads must be a multiple of number of kv heads"); + int64_t repeats = q_heads / k_heads; + k_sel = at::repeat_interleave(k_sel, repeats, /*dim=*/1); + v_sel = at::repeat_interleave(v_sel, repeats, /*dim=*/1); + k_heads = k_sel.size(1); + } + + at::Tensor q_per = q.permute({1, 0, 2}).contiguous(); + at::Tensor k_per = k_sel.permute({1, 0, 2}).contiguous(); + + at::Tensor attn = at::bmm(q_per, k_per.transpose(1, 2)); + + at::Tensor empty_mask = at::ones({query_len, kv_len}, query.options().dtype(torch::kUInt8)).to(device); + int64_t diag = kv_len - query_len + 1; + at::Tensor mask = at::triu(empty_mask, diag).to(at::kBool); + + // sliding window: compute sliding mask and OR it + if (sliding_window_opt.has_value()) { + int64_t sliding = sliding_window_opt.value(); + int64_t diag_sw = kv_len - (query_len + sliding) + 1; + at::Tensor sw_mask = at::triu(empty_mask, diag_sw).to(at::kBool); + // invert (logical_not) + sw_mask = sw_mask.logical_not(); + mask = at::logical_or(mask, sw_mask); + } + + if (soft_cap_opt.has_value()) { + double soft_cap = soft_cap_opt.value(); + at::Tensor attn_fp32 = attn.to(at::kFloat); + attn_fp32 = soft_cap * at::tanh(attn_fp32 / static_cast(soft_cap)); + attn = attn_fp32.to(attn.dtype()); + } + + at::Tensor mask_b = mask.unsqueeze(0); + const float neg_inf = -std::numeric_limits::infinity(); + attn.masked_fill_(mask_b, neg_inf); + + if (attn_bias_opt.has_value()) { + at::Tensor bias_i = attn_bias_opt.value().index({i}); + if (bias_i.size(2) > kv_len) { + bias_i = bias_i.narrow(2, 0, kv_len); + } + // Adjust q dim: broadcast if qb == 1 and query_len > 1 + if (bias_i.size(1) == 1 && query_len > 1) { + bias_i = bias_i.expand({bias_i.size(0), query_len, bias_i.size(2)}); + } else if (bias_i.size(1) > query_len) { + bias_i = bias_i.narrow(1, 0, query_len); + } + attn = attn + bias_i; + } + + attn = at::softmax(attn, -1); + attn = attn.to(v_sel.dtype()); + + at::Tensor v_per = v_sel.permute({1, 0, 2}).contiguous(); + at::Tensor out_per = at::bmm(attn, v_per); + at::Tensor out = out_per.permute({1, 0, 2}).contiguous(); + + outputs.push_back(out); + start_idx += query_len; + } + at::Tensor result = at::cat(outputs, 0); + return result; +} + +// Helper: build attention bias from alibi slopes (parity with Python tests) +at::Tensor attn_bias_from_alibi_slopes_cpp(const at::Tensor& slopes, // (batch, nheads), float32 + int64_t seqlen_q, + int64_t seqlen_k, + bool causal) { + at::Tensor s = slopes.unsqueeze(-1).unsqueeze(-1); + if (causal) { + at::Tensor ar = at::arange(-seqlen_k + 1, 1, slopes.options()); + ar = ar.view({1, 1, 1, seqlen_k}); + return s * ar; + } else { + at::Tensor row_idx = at::arange(seqlen_q, slopes.options().dtype(at::kLong)).unsqueeze(-1); + at::Tensor col_idx = at::arange(seqlen_k, slopes.options().dtype(at::kLong)); + at::Tensor relative_pos = (row_idx + (seqlen_k - seqlen_q) - col_idx).abs().to(at::kFloat); + relative_pos = relative_pos.view({1, 1, seqlen_q, seqlen_k}); + return -s * relative_pos; + } +} + +using VarlenParams = std::tuple, // (num_query_heads, num_kv_heads) + int, // head_size + at::ScalarType, // dtype + bool, // alibi + int, // soft_cap_code: 0->None, 1->10.0, 2->50.0 + int // 
num_blocks + >; + +class FlashAttnVarlenParamTest : public ::testing::TestWithParam {}; + +TEST_P(FlashAttnVarlenParamTest, MatchesReference) { + torch::manual_seed(1234567890); + const torch::Device device(torch::kCUDA, 0); + + const std::vector> seq_lens = { + { 1, 1328}, + { 5, 18}, + {129, 463} + }; + const int64_t num_seqs = static_cast(seq_lens.size()); + std::vector query_lens, kv_lens; + query_lens.reserve(num_seqs); + kv_lens.reserve(num_seqs); + for (auto& p : seq_lens) { + query_lens.push_back(p.first); + kv_lens.push_back(p.second); + } + const int64_t max_query_len = *std::max_element(query_lens.begin(), query_lens.end()); + const int64_t max_kv_len = *std::max_element(kv_lens.begin(), kv_lens.end()); + + auto [heads, head_size, dtype, alibi, softcap_code, num_blocks] = GetParam(); + const int64_t num_query_heads = heads.first; + const int64_t num_kv_heads = heads.second; + ASSERT_TRUE(num_query_heads % num_kv_heads == 0); + + const int64_t block_size = 32; + const std::optional sliding_window_opt = std::nullopt; + + std::optional soft_cap_opt; + if (softcap_code == 1) + soft_cap_opt = 10.0; + else if (softcap_code == 2) + soft_cap_opt = 50.0; + // if alibi is True and soft_cap is not None + if (alibi && soft_cap_opt.has_value()) { + GTEST_SKIP() << "Skip (alibi + soft_cap)"; + } + + const double scale = 1.0 / std::sqrt(static_cast(head_size)); + + auto opts = torch::TensorOptions().dtype(dtype).device(device); + auto opts_int = torch::TensorOptions().dtype(torch::kInt32).device(device); + + const int64_t total_q_tokens = std::accumulate(query_lens.begin(), query_lens.end(), 0LL); + at::Tensor q = torch::randn({total_q_tokens, num_query_heads, head_size}, opts); + at::Tensor k_cache = torch::randn({num_blocks, block_size, num_kv_heads, head_size}, opts); + at::Tensor v_cache = torch::randn_like(k_cache); + + // cu_seqlens_q + std::vector cu_q_lens_vec; + cu_q_lens_vec.reserve(num_seqs + 1); + cu_q_lens_vec.push_back(0); + for (auto ql : query_lens) cu_q_lens_vec.push_back(static_cast(cu_q_lens_vec.back() + ql)); + at::Tensor cu_seqlens_q = torch::tensor(cu_q_lens_vec, opts_int); + + // seqused_k + at::Tensor seqused_k = torch::tensor(std::vector(kv_lens.begin(), kv_lens.end()), opts_int); + + // block table + const int64_t max_num_blocks_per_seq = (max_kv_len + block_size - 1) / block_size; + at::Tensor block_table = torch::randint(0, num_blocks, {num_seqs, max_num_blocks_per_seq}, opts_int); + + // alibi slopes / bias + std::optional alibi_slopes_opt = std::nullopt; + std::optional attn_bias_opt = std::nullopt; + if (alibi) { + at::Tensor alibi_slopes = torch::ones({num_seqs, num_query_heads}, + torch::TensorOptions().dtype(torch::kFloat32).device(device)) * + 0.3; + alibi_slopes_opt = alibi_slopes; + at::Tensor attn_bias = + attn_bias_from_alibi_slopes_cpp(alibi_slopes, max_query_len, max_kv_len, /*causal=*/true); + attn_bias_opt = attn_bias; + } + + auto out_lse = flag_gems::flash_attn_varlen_func(q, + k_cache, + v_cache, + max_query_len, + cu_seqlens_q, + max_kv_len, + /*cu_seqlens_k*/ std::nullopt, + /*seqused_k*/ seqused_k, + /*q_v*/ std::nullopt, + /*dropout_p*/ 0.0, + /*softmax_scale*/ scale, + /*causal*/ true, + /*window_size_left*/ -1, + /*window_size_right*/ -1, + /*softcap*/ soft_cap_opt.value_or(0.0), + /*alibi_slopes*/ alibi_slopes_opt, + /*deterministic*/ false, + /*return_attn_probs*/ false, + /*block_table*/ block_table, + /*return_softmax_lse*/ false); + at::Tensor op_output = std::get<0>(out_lse); + + // 参考实现 + at::Tensor ref_output = ref_paged_attn_cpp(q, + 
k_cache, + v_cache, + std::vector(query_lens.begin(), query_lens.end()), + std::vector(kv_lens.begin(), kv_lens.end()), + block_table, + scale, + attn_bias_opt, + sliding_window_opt, + soft_cap_opt); + + EXPECT_TRUE(torch::allclose(op_output, ref_output, /*rtol=*/1e-2, /*atol=*/2e-2)); +} + +INSTANTIATE_TEST_SUITE_P( + FlashAttnVarlenPytestParity_Param, + FlashAttnVarlenParamTest, + ::testing::Combine(::testing::Values(std::make_pair(4, 4), std::make_pair(8, 2), std::make_pair(16, 2)), + ::testing::Values(128, 192, 256), + ::testing::Values(at::kHalf, at::kBFloat16), + ::testing::Values(false, true), + ::testing::Values(0, 1, 2), // soft_cap: None / 10.0 / 50.0 + ::testing::Values(32768, 2048))); + +using SwapParams = std::tupleNone, 1->10.0 + >; + +class FlashAttnVarlenSwapQGParamTest : public ::testing::TestWithParam {}; + +TEST_P(FlashAttnVarlenSwapQGParamTest, MatchesReference) { + torch::manual_seed(1234567890); + const torch::Device device(torch::kCUDA, 0); + + const std::vector> seq_lens = { + {1, 1328}, + {1, 18}, + {1, 463} + }; + const int64_t num_seqs = static_cast(seq_lens.size()); + std::vector query_lens, kv_lens; + query_lens.reserve(num_seqs); + kv_lens.reserve(num_seqs); + for (auto& p : seq_lens) { + query_lens.push_back(p.first); + kv_lens.push_back(p.second); + } + const int64_t max_query_len = *std::max_element(query_lens.begin(), query_lens.end()); + const int64_t max_kv_len = *std::max_element(kv_lens.begin(), kv_lens.end()); + + const int64_t num_query_heads = 8; + const int64_t num_kv_heads = 2; + const int64_t head_size = 128; + const int64_t block_size = 32; + const int64_t num_blocks = 2048; + + auto [dtype, softcap_code] = GetParam(); + std::optional soft_cap_opt; + if (softcap_code == 1) soft_cap_opt = 10.0; + + const double scale = 1.0 / std::sqrt(static_cast(head_size)); + + auto opts = torch::TensorOptions().dtype(dtype).device(device); + auto opts_int = torch::TensorOptions().dtype(torch::kInt32).device(device); + + // 构造输入 + const int64_t total_q_tokens = std::accumulate(query_lens.begin(), query_lens.end(), 0LL); + at::Tensor q = torch::randn({total_q_tokens, num_query_heads, head_size}, opts); + at::Tensor k_cache = torch::randn({num_blocks, block_size, num_kv_heads, head_size}, opts); + at::Tensor v_cache = torch::randn_like(k_cache); + + std::vector cu_q_lens_vec; + cu_q_lens_vec.reserve(num_seqs + 1); + cu_q_lens_vec.push_back(0); + for (auto ql : query_lens) cu_q_lens_vec.push_back(static_cast(cu_q_lens_vec.back() + ql)); + at::Tensor cu_seqlens_q = torch::tensor(cu_q_lens_vec, opts_int); + + at::Tensor seqused_k = torch::tensor(std::vector(kv_lens.begin(), kv_lens.end()), opts_int); + const int64_t max_num_blocks_per_seq = (max_kv_len + block_size - 1) / block_size; + at::Tensor block_table = torch::randint(0, num_blocks, {num_seqs, max_num_blocks_per_seq}, opts_int); + + auto out_lse = flag_gems::flash_attn_varlen_func(q, + k_cache, + v_cache, + max_query_len, + cu_seqlens_q, + max_kv_len, + /*cu_seqlens_k*/ std::nullopt, + /*seqused_k*/ seqused_k, + /*q_v*/ std::nullopt, + /*dropout_p*/ 0.0, + /*softmax_scale*/ scale, + /*causal*/ true, + /*window_size_left*/ -1, + /*window_size_right*/ -1, + /*softcap*/ soft_cap_opt.value_or(0.0), + /*alibi_slopes*/ std::nullopt, + /*deterministic*/ false, + /*return_attn_probs*/ false, + /*block_table*/ block_table, + /*return_softmax_lse*/ false); + at::Tensor op_output = std::get<0>(out_lse); + + at::Tensor ref_output = ref_paged_attn_cpp(q, + k_cache, + v_cache, + std::vector(query_lens.begin(), 
query_lens.end()), + std::vector(kv_lens.begin(), kv_lens.end()), + block_table, + scale, + /*attn_bias*/ std::nullopt, + /*sliding_window*/ std::nullopt, + soft_cap_opt); + + EXPECT_TRUE(torch::allclose(op_output, ref_output, /*rtol=*/1e-2, /*atol=*/2e-2)); +} + +INSTANTIATE_TEST_SUITE_P(FlashAttnVarlenPytestParity_Param_Swap, + FlashAttnVarlenSwapQGParamTest, + ::testing::Combine(::testing::Values(at::kHalf, at::kBFloat16), + ::testing::Values(0, 1) // soft_cap: None / 10.0 + )); diff --git a/include/flag_gems/operators.h b/include/flag_gems/operators.h index 5d13d0c62..dc7bf8578 100644 --- a/include/flag_gems/operators.h +++ b/include/flag_gems/operators.h @@ -87,4 +87,104 @@ void reshape_and_cache_flash(const at::Tensor &key, const std::string &kv_cache_dtype, const std::optional &k_scale, const std::optional &v_scale); + +std::tuple flash_attn_varlen_func( + const at::Tensor &q, + const at::Tensor &k, + const at::Tensor &v, + int64_t max_seqlen_q, + const at::Tensor &cu_seqlens_q, + int64_t max_seqlen_k, + const std::optional &cu_seqlens_k = std::nullopt, + const std::optional &seqused_k = std::nullopt, + const std::optional &q_v = std::nullopt, + double dropout_p = 0.0, + const std::optional &softmax_scale = std::nullopt, + bool causal = false, + // CHANGE: window_size is now two separate integers + int64_t window_size_left = -1, + int64_t window_size_right = -1, + double softcap = 0.0, + const std::optional &alibi_slopes = std::nullopt, + bool deterministic = false, + bool return_attn_probs = false, + const std::optional &block_table = std::nullopt, + bool return_softmax_lse = false, + const std::optional &out = std::nullopt, + const std::optional &scheduler_metadata = std::nullopt, + const std::optional &q_descale = std::nullopt, + const std::optional &k_descale = std::nullopt, + const std::optional &v_descale = std::nullopt, + int64_t num_splits = 0, + int64_t fa_version = 2); + +struct FlashFwdParams { + // tensor pointers + at::Tensor q; + at::Tensor k; + at::Tensor v; + at::Tensor out; + at::Tensor p; + at::Tensor lse; + // strides + int64_t q_row_stride; + int64_t k_row_stride; + int64_t v_row_stride; + int64_t q_head_stride; + int64_t k_head_stride; + int64_t v_head_stride; + int64_t o_row_stride; + int64_t o_head_stride; + // batch strides + int64_t q_batch_stride; + int64_t k_batch_stride; + int64_t v_batch_stride; + int64_t o_batch_stride; + // cu_seqlens / seqused_k flags & tensors + bool is_cu_seqlens_q; + at::Tensor cu_seqlens_q; + bool is_cu_seqlens_k; + at::Tensor cu_seqlens_k; + bool is_seqused_k; + at::Tensor seqused_k; + // sizes + int64_t batch_size; + int64_t k_batch_size; + int64_t num_heads; + int64_t num_heads_k; + int64_t h_hk_ratio; + int64_t seqlen_q; + int64_t seqlen_k; + int64_t seqlen_q_rounded; + int64_t seqlen_k_rounded; + int64_t head_size; + int64_t head_size_rounded; + // scaling factors + bool is_softcap; + double softcap; + double scale_softmax; + double scale_softmax_log2e; + // dropout + bool is_dropout; + double p_dropout; + double rp_dropout; + int64_t p_dropout_in_uint8_t; + at::Tensor philox_args; + bool return_softmax; + // causal & sliding window attention + bool is_causal; + bool is_local; + int64_t window_size_left; + int64_t window_size_right; + bool seqlenq_ngroups_swapped; + // alibi + bool is_alibi; + at::Tensor alibi_slopes; + int64_t alibi_slopes_batch_stride; + // block table params + int64_t total_q; + at::Tensor page_table; + int64_t page_table_batch_stride; + int64_t block_size; +}; } // namespace flag_gems diff --git 
a/include/flag_gems/utils.h b/include/flag_gems/utils.h index baa6d62ce..0c819a950 100644 --- a/include/flag_gems/utils.h +++ b/include/flag_gems/utils.h @@ -1,4 +1,6 @@ #pragma once +#include +#include #include // dladdr #include #include @@ -8,6 +10,7 @@ #include #include #include +#include #include "torch/torch.h" namespace flag_gems::utils { @@ -24,3 +27,9 @@ std::tuple permute_reduction_axes_right( std::tuple parse_reduction_axes(const at::Tensor &tensor, int reduction_axis); int cdiv(int a, int b); } // namespace flag_gems::utils + +// Public declaration for reuse (implementation in lib/exponential_.cpp) +namespace flag_gems { +std::pair philox_backend_seed_offset(int64_t increment, + c10::optional generator_opt); +} diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 078613bdf..9b56bd5ce 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -21,7 +21,8 @@ add_library(operators softmax.cpp exponential_.cpp fill.cpp - reshape_and_cache_flash.cpp) + reshape_and_cache_flash.cpp + flash_attn_varlen_func.cpp) target_include_directories(operators PUBLIC $ diff --git a/lib/flash_attn_varlen_func.cpp b/lib/flash_attn_varlen_func.cpp new file mode 100644 index 000000000..50d697180 --- /dev/null +++ b/lib/flash_attn_varlen_func.cpp @@ -0,0 +1,544 @@ +#include +#include +#include +#include +#include "c10/cuda/CUDAStream.h" +#include "flag_gems/operators.h" +#include "flag_gems/utils.h" +#include "torch/torch.h" +#include "triton_jit/triton_jit_function.h" + +namespace { +std::tuple +mha_varlan_fwd_internal(const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + const at::Tensor& out, + const at::Tensor& cu_seqlens_q, + const at::Tensor& cu_seqlens_k, + const at::Tensor& seqused_k, + const at::Tensor& leftpad_k, + const at::Tensor& page_table, + const at::Tensor& alibi_slopes, + int64_t max_seqlen_q, + int64_t max_seqlen_k, + double p_dropout, + double softmax_scale, + bool zero_tensors, + bool is_causal, + int64_t window_size_left, + int64_t window_size_right, + double softcap, + bool return_softmax, + const at::Tensor& gen) { + // 253-301 + TORCH_CHECK(q.device() == k.device() && k.device() == v.device(), "q, k, v must be on the same device"); + auto q_device = q.device(); + auto q_dtype = q.scalar_type(); + TORCH_CHECK(q_dtype == at::kHalf || q_dtype == at::kBFloat16, + "FlashAttention only support fp16 and bf16 data type"); + TORCH_CHECK(q.scalar_type() == k.scalar_type() && q.scalar_type() == v.scalar_type(), + "q, k, v must have the same data type"); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + TORCH_CHECK(cu_seqlens_q.scalar_type() == at::kInt, "cu_seqlens_q must be int32"); + TORCH_CHECK(cu_seqlens_q.is_contiguous(), "cu_seqlens_q must be contiguous"); + + TORCH_CHECK(cu_seqlens_k.scalar_type() == at::kInt, "cu_seqlens_k must be int32"); + TORCH_CHECK(cu_seqlens_k.is_contiguous(), "cu_seqlens_k must be contiguous"); + + TORCH_CHECK(page_table.defined(), "page_table must be provided"); + // # q shape: [total_q_tokens, num_heads, head_size] + // # k shape: + // # paged_kv: [num_pages, block_size, num_heads_k, head_size] + // # batch_size, number of sentences + + const auto total_q = q.size(0); + const auto num_heads = q.size(1); + const auto head_size = q.size(2); + + const auto num_heads_k = k.size(2); + const auto batch_size = cu_seqlens_q.numel() 
- 1; + const auto block_size = k.size(1); + const auto num_pages = k.size(0); + const auto k_batch_size = num_pages; + + const auto page_table_batch_stride = page_table.stride(0); + TORCH_CHECK(k.sizes() == v.sizes(), "k and v must equal size"); + TORCH_CHECK(cu_seqlens_q.numel() == (batch_size + 1), "cu_seqlens_q must equal size batch_size + 1"); + TORCH_CHECK(cu_seqlens_k.numel() == (batch_size + 1), "cu_seqlens_k must equal size batch_size + 1"); + // Check output shape + if (out.defined()) { + TORCH_CHECK(out.stride(out.dim() - 1) == 1, "Output tensor must have contiguous last dimension"); + TORCH_CHECK(out.scalar_type() == q_dtype, "Output tensor must have the same dtype as input"); + TORCH_CHECK(out.size(0) == total_q && out.size(1) == num_heads && out.size(2) == head_size, + "Output tensor has incorrect shape"); + } + if (seqused_k.defined()) { + TORCH_CHECK(seqused_k.is_contiguous(), "seqused_k must be contiguous"); + TORCH_CHECK(seqused_k.numel() == batch_size, "seqused_k has incorrect size"); + } + + auto q_final = q; + auto cu_seqlens_q_final = cu_seqlens_q; + auto num_heads_final = num_heads; + auto total_q_final = total_q; + int64_t q_batch_stride = 0; + auto k_batch_stride = k.stride(0); + auto v_batch_stride = v.stride(0); + int64_t o_batch_stride = 0; + + bool is_local = false; + auto final_is_causal = is_causal; + if (max_seqlen_q == 1 && !alibi_slopes.defined()) { + final_is_causal = false; + } + if (final_is_causal) { + window_size_right = 0; + } + // check disable swa + if (window_size_left >= max_seqlen_k) { + window_size_left = -1; + } + if (window_size_right >= max_seqlen_k) { + window_size_right = -1; + } + is_local = window_size_left >= 0; + + // Optimize all single-query sequences by swapping the query-group and sequence dimensions + // Reshape tensor to align Q heads count with K heads count. 
+ auto seqlenq_ngroups_swapped = + (max_seqlen_q == 1 && !alibi_slopes.defined() && num_heads_final > num_heads_k && + window_size_left < 0 && window_size_right < 0 && p_dropout == 0); + auto q_groups = num_heads_final / num_heads_k; + if (seqlenq_ngroups_swapped) { + q_final = q.reshape({batch_size, num_heads_k, q_groups, head_size}) + .transpose(1, 2) + .reshape({batch_size * q_groups, num_heads_k, head_size}); + max_seqlen_q = q_groups; + num_heads_final = num_heads_k; + cu_seqlens_q_final = at::Tensor(); + + // q.stride(0) * max_seqlen_q = + // = (num_heads_k * head_size) * q_groups + // = (num_heads_k * q_groups) * head_size + // = num_heads * head_size + q_batch_stride = q_final.stride(0) * max_seqlen_q; + k_batch_stride = k.stride(0); + v_batch_stride = v.stride(0); + } else { + q_batch_stride = 0; + k_batch_stride = 0; + v_batch_stride = 0; + o_batch_stride = 0; + } + total_q_final = q_final.size(0); + TORCH_CHECK(!leftpad_k.defined(), "leftpad_k is not supported."); + TORCH_CHECK(head_size <= 256, "FlashAttention forward only supports head dimension at most 256"); + TORCH_CHECK(head_size % 8 == 0, "head_size must be a multiple of 8, this is ensured by padding!"); + TORCH_CHECK(num_heads_final % num_heads_k == 0, + "Number of heads in key/value must divide number of heads in query"); + TORCH_CHECK(q_final.sizes() == c10::IntArrayRef({total_q_final, num_heads_final, head_size}), + "q sizes check failed"); + TORCH_CHECK(k.sizes() == c10::IntArrayRef({num_pages, block_size, num_heads_k, head_size}), + "k sizes check failed"); + TORCH_CHECK(v.sizes() == c10::IntArrayRef({num_pages, block_size, num_heads_k, head_size}), + "v sizes check failed"); + TORCH_CHECK(k.strides() == v.strides(), "k and v must have the same stride"); + + if (softcap > 0.0) { + TORCH_CHECK(p_dropout == 0, "dropout is not supported if softcap is used."); + } + // data preprocess and alignment + auto round_multiple = [](int64_t x, int64_t m) { return (x + m - 1) / m * m; }; + auto head_size_rounded = head_size < 192 ? round_multiple(head_size, 32) : 256; + auto seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + auto seqlen_k_rounded = round_multiple(max_seqlen_k, 32); + + constexpr double LOG2E = 1.4426950408889634074; + bool is_softcap = false; + double adjusted_scale_softmax = 0.0; + double adjusted_softcap = 0.0; + double adjusted_scale_softmax_log2e = 0.0; + if (softcap > 0.0) { + is_softcap = true; + adjusted_scale_softmax = softcap; + adjusted_softcap = softmax_scale / softcap; + adjusted_scale_softmax_log2e = softcap * LOG2E; + } else { + is_softcap = false; + adjusted_softcap = 0.0; + adjusted_scale_softmax = softmax_scale; + adjusted_scale_softmax_log2e = softmax_scale * LOG2E; + } + // Set alibi params + bool is_alibi = false; + int64_t alibi_slopes_batch_stride = 0; + if (alibi_slopes.defined()) { + TORCH_CHECK(alibi_slopes.device() == q_device); + TORCH_CHECK(alibi_slopes.scalar_type() == at::kFloat); + TORCH_CHECK(alibi_slopes.stride(alibi_slopes.dim() - 1) == 1); + TORCH_CHECK(alibi_slopes.sizes() == c10::IntArrayRef({ + num_heads_final, + }) || + alibi_slopes.sizes() == c10::IntArrayRef({ + batch_size, + num_heads_final, + })); + alibi_slopes_batch_stride = alibi_slopes.dim() == 2 ? 
alibi_slopes.stride(0) : 0; + is_alibi = true; + } else { + alibi_slopes_batch_stride = 0; + is_alibi = false; + } + // Prepare params to kernel + at::Tensor out_final = out; + at::Tensor out_ = at::Tensor(); + at::Tensor lse; + at::Tensor philox_args; + at::Tensor p; + at::Tensor unused; // optional, may remain undefined when not used + { + const c10::DeviceGuard guard(q_device); + if (out.defined()) { + out_ = out; + if (seqlenq_ngroups_swapped) + out_final = at::empty_like(q_final, q_final.options().dtype(v.scalar_type())); + } else { + out_ = at::Tensor(); + out_final = at::empty_like(q_final, q_final.options().dtype(v.scalar_type())); + } + if (seqlenq_ngroups_swapped) o_batch_stride = out_final.stride(0) * max_seqlen_q; + lse = at::empty({num_heads_final, total_q_final}, at::TensorOptions().dtype(at::kFloat).device(q_device)); + + bool is_dropout = false; + int64_t increment = 0, philox_seed = 0, philox_offset = 0; + philox_args = at::Tensor(); + // Inference + if (p_dropout > 0) { + is_dropout = true; + increment = batch_size * num_heads_final * 32; + + auto [seed, offset] = flag_gems::philox_backend_seed_offset(increment, c10::nullopt); + philox_seed = seed; + philox_offset = offset; + philox_args = + at::tensor({philox_seed, philox_offset}, at::TensorOptions().dtype(at::kLong).device(q_device)); + } else { + is_dropout = false; + philox_args = at::empty({2}, at::TensorOptions().dtype(at::kLong).device(q_device)); + } + p_dropout = 1.0 - p_dropout; + int64_t p_dropout_in_uint8_t = static_cast(std::floor(p_dropout * 255.0)); + double rp_dropout = 1.0 / p_dropout; + if (return_softmax) { + TORCH_CHECK(is_dropout, "Only supported with non-zero dropout."); + p = at::empty({batch_size, num_heads_final, seqlen_q_rounded, seqlen_k_rounded}, + at::TensorOptions().device(q_device)); + } else { + p = at::empty({}, at::TensorOptions().device(q_device)); + } + if (zero_tensors) { + out_final.zero_(); + lse.fill_(-std::numeric_limits::infinity()); + } + + TORCH_CHECK(q_final.dim() >= 3, "q_final must be at least 3D", q_final.dim()); + TORCH_CHECK(k.dim() >= 3, "k must be at least 3D", k.dim()); + TORCH_CHECK(v.dim() >= 3, "v must be at least 3D", v.dim()); + TORCH_CHECK(out_final.dim() >= 3, "out_final must be at least 3D", out_final.dim()); + const int64_t q_row_stride = q_final.stride(q_final.dim() - 3); + const int64_t k_row_stride = k.stride(k.dim() - 3); + const int64_t v_row_stride = v.stride(v.dim() - 3); + const int64_t q_head_stride = q_final.stride(q_final.dim() - 2); + const int64_t k_head_stride = k.stride(k.dim() - 2); + const int64_t v_head_stride = v.stride(v.dim() - 2); + const int64_t o_row_stride = out_final.stride(out_final.dim() - 3); + const int64_t o_head_stride = out_final.stride(out_final.dim() - 2); + + // Prepare safe placeholders for optional tensors to ensure they have storage + const bool is_cu_seqlens_q_flag = cu_seqlens_q_final.defined(); + const bool is_seqused_k_flag = seqused_k.defined(); + const bool is_cu_seqlens_k_flag = !is_seqused_k_flag; + + at::Tensor cu_seqlens_q_safe = is_cu_seqlens_q_flag + ? cu_seqlens_q_final + : at::empty({1}, at::TensorOptions().dtype(at::kInt).device(q_device)); + at::Tensor cu_seqlens_k_safe = cu_seqlens_k; + at::Tensor seqused_k_safe = + is_seqused_k_flag ? seqused_k + : at::empty({batch_size}, at::TensorOptions().dtype(at::kInt).device(q_device)); + at::Tensor alibi_slopes_safe = + is_alibi ? 
alibi_slopes : at::empty({1}, at::TensorOptions().dtype(at::kFloat).device(q_device)); + + flag_gems::FlashFwdParams params; + params.q = q_final; + params.k = k; + params.v = v; + params.out = out_final; + params.p = p; + params.lse = lse; + // strides + params.q_row_stride = q_row_stride; + params.k_row_stride = k_row_stride; + params.v_row_stride = v_row_stride; + params.q_head_stride = q_head_stride; + params.k_head_stride = k_head_stride; + params.v_head_stride = v_head_stride; + params.o_row_stride = o_row_stride; + params.o_head_stride = o_head_stride; + params.q_batch_stride = q_batch_stride; + params.k_batch_stride = k_batch_stride; + params.v_batch_stride = v_batch_stride; + params.o_batch_stride = o_batch_stride; + params.is_cu_seqlens_q = is_cu_seqlens_q_flag; + params.cu_seqlens_q = cu_seqlens_q_safe; + params.is_cu_seqlens_k = is_cu_seqlens_k_flag; + params.cu_seqlens_k = cu_seqlens_k_safe; + params.is_seqused_k = is_seqused_k_flag; + params.seqused_k = seqused_k_safe; + params.batch_size = batch_size; + params.k_batch_size = k_batch_size; + params.num_heads = num_heads_final; + params.num_heads_k = num_heads_k; + params.h_hk_ratio = num_heads_final / num_heads_k; + params.seqlen_q = max_seqlen_q; + params.seqlen_k = max_seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.head_size = head_size; + params.head_size_rounded = head_size_rounded; + // scaling(softcap/softmax) + params.is_softcap = is_softcap; + params.softcap = adjusted_softcap; + params.scale_softmax = adjusted_scale_softmax; + params.scale_softmax_log2e = adjusted_scale_softmax_log2e; + // dropout + params.is_dropout = is_dropout; + params.p_dropout = p_dropout; + params.rp_dropout = rp_dropout; + params.p_dropout_in_uint8_t = p_dropout_in_uint8_t; + params.philox_args = philox_args; + params.return_softmax = return_softmax; + params.is_causal = final_is_causal; + params.is_local = is_local; + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped; + // alibi + params.is_alibi = is_alibi; + params.alibi_slopes = alibi_slopes_safe; + params.alibi_slopes_batch_stride = alibi_slopes_batch_stride; + // block table + params.total_q = total_q_final; + params.page_table = page_table; + params.page_table_batch_stride = page_table_batch_stride; + params.block_size = block_size; + + const double avg_seqlen_q = static_cast(total_q_final) / static_cast(batch_size); + int64_t BLOCK_M = (avg_seqlen_q >= 256) ? 
128 : 32; // prefill or decode + int64_t BLOCK_N = 32; + int64_t num_warps = 4; + int64_t num_stages = 3; + + const unsigned grid_x = static_cast(flag_gems::utils::cdiv(max_seqlen_q, BLOCK_M)); + const unsigned grid_y = static_cast(batch_size); + const unsigned grid_z = static_cast(num_heads_final); + + const triton_jit::TritonJITFunction& f = triton_jit::TritonJITFunction::get_instance( + (flag_gems::utils::get_flag_gems_src_path() / "ops" / "flash_kernel.py").string(), + "flash_varlen_fwd_kernel"); + c10::cuda::CUDAStream stream = c10::cuda::getCurrentCUDAStream(); + CUstream raw_stream = static_cast(stream.stream()); + + f(raw_stream, + grid_x, + grid_y, + grid_z, + static_cast(num_warps), + static_cast(num_stages), + params.q, + params.k, + params.v, + params.out, + params.p, + params.lse, + params.q_row_stride, + params.k_row_stride, + params.v_row_stride, + params.q_head_stride, + params.k_head_stride, + params.v_head_stride, + params.o_row_stride, + params.o_head_stride, + params.q_batch_stride, + params.k_batch_stride, + params.v_batch_stride, + params.o_batch_stride, + params.is_cu_seqlens_q, + params.cu_seqlens_q, + params.is_cu_seqlens_k, + params.cu_seqlens_k, + params.is_seqused_k, + params.seqused_k, + // sizes + params.batch_size, + params.k_batch_size, + params.num_heads, + params.num_heads_k, + params.h_hk_ratio, + params.seqlen_q, + params.seqlen_k, + params.seqlen_q_rounded, + params.seqlen_k_rounded, + params.head_size, + params.head_size_rounded, + // scaling + params.is_softcap, + params.softcap, + params.scale_softmax, + params.scale_softmax_log2e, + // dropout + params.is_dropout, + params.p_dropout, + params.rp_dropout, + params.p_dropout_in_uint8_t, + params.philox_args, + params.return_softmax, + // causal / local / windows / swap + params.is_causal, + params.is_local, + params.window_size_left, + params.window_size_right, + params.seqlenq_ngroups_swapped, + // alibi + params.is_alibi, + params.alibi_slopes, + params.alibi_slopes_batch_stride, + // block table + params.total_q, + params.page_table, + params.page_table_batch_stride, + params.block_size, + // kernel compile-time config + BLOCK_M, + BLOCK_N, + params.head_size_rounded, + num_warps, + num_stages); + + if (seqlenq_ngroups_swapped) { + at::Tensor out_swapped = + out_final.reshape({batch_size, max_seqlen_q, num_heads_k, head_size}).transpose(1, 2); + if (out_.defined()) { + at::Tensor out_view = out_.view({batch_size, num_heads_k, max_seqlen_q, head_size}); + out_view.copy_(out_swapped); + out_final = out_; + } else { + out_final = out_swapped.reshape({batch_size, num_heads_k * max_seqlen_q, head_size}); + } + lse = lse.reshape({num_heads_k, batch_size, max_seqlen_q}) + .reshape({num_heads_k * max_seqlen_q, batch_size}); + // mark unused only when swap path is taken (optional) + unused = at::empty({}, at::TensorOptions().dtype(at::kLong).device(q_device)); + } + } + return std::make_tuple(out_final, q_final, k, v, lse, philox_args, unused, p); +} +} // namespace + +namespace flag_gems { +using namespace triton_jit; + +std::tuple flash_attn_varlen_func(const at::Tensor& q, + const at::Tensor& k, + const at::Tensor& v, + int64_t max_seqlen_q, + const at::Tensor& cu_seqlens_q, + int64_t max_seqlen_k, + const std::optional& cu_seqlens_k, + const std::optional& seqused_k, + const std::optional& q_v, + double dropout_p, + const std::optional& softmax_scale, + bool causal, + int64_t window_size_left, + int64_t window_size_right, + double softcap, + const std::optional& alibi_slopes, + bool deterministic, + 
bool return_attn_probs,
+                                                     const std::optional<at::Tensor>& block_table,
+                                                     bool return_softmax_lse,
+                                                     const std::optional<at::Tensor>& out,
+                                                     const std::optional<at::Tensor>& scheduler_metadata,
+                                                     const std::optional<double>& q_descale,
+                                                     const std::optional<double>& k_descale,
+                                                     const std::optional<double>& v_descale,
+                                                     int64_t num_splits,
+                                                     int64_t fa_version) {
+  TORCH_CHECK(cu_seqlens_k.has_value() || seqused_k.has_value(),
+              "cu_seqlens_k or seqused_k must be provided");
+  TORCH_CHECK(!(cu_seqlens_k.has_value() && seqused_k.has_value()),
+              "cu_seqlens_k and seqused_k cannot be provided at the same time");
+  TORCH_CHECK(!block_table.has_value() || seqused_k.has_value(),
+              "seqused_k must be provided if block_table is provided");
+
+  double softmax_scale_val;
+  if (!softmax_scale.has_value()) {
+    softmax_scale_val = 1.0 / std::sqrt(q.size(q.dim() - 1));
+  } else {
+    softmax_scale_val = softmax_scale.value();
+  }
+  // window_size is handled via the two direct parameters (window_size_left / window_size_right)
+  auto q_cont = q.contiguous();
+  auto k_cont = k.contiguous();
+  auto v_cont = v.contiguous();
+
+  at::Tensor dummy_cu_seqlens_k;
+  if (!cu_seqlens_k.has_value()) {
+    dummy_cu_seqlens_k = at::empty_like(cu_seqlens_q);
+  }
+  const at::Tensor& cu_seqlens_k_ref = cu_seqlens_k.has_value() ? cu_seqlens_k.value() : dummy_cu_seqlens_k;
+
+  TORCH_CHECK(fa_version == 2, "Only FA2 is implemented");
+  TORCH_CHECK(num_splits == 0, "num_splits > 0 is not implemented in GEMS.");
+
+  const at::Tensor empty_undefined = at::Tensor();
+  const at::Tensor& seqused_k_ref = seqused_k.has_value() ? seqused_k.value() : empty_undefined;
+  const at::Tensor& block_table_ref = block_table.has_value() ? block_table.value() : empty_undefined;
+  const at::Tensor& alibi_slopes_ref = alibi_slopes.has_value() ? alibi_slopes.value() : empty_undefined;
+  const at::Tensor& out_ref = out.has_value() ? out.value() : empty_undefined;
+
+  auto outputs = mha_varlan_fwd_internal(q_cont,
+                                         k_cont,
+                                         v_cont,
+                                         out_ref,
+                                         cu_seqlens_q,
+                                         cu_seqlens_k_ref,
+                                         seqused_k_ref,
+                                         empty_undefined,
+                                         block_table_ref,
+                                         alibi_slopes_ref,
+                                         max_seqlen_q,
+                                         max_seqlen_k,
+                                         dropout_p,
+                                         softmax_scale_val,
+                                         false,
+                                         causal,
+                                         window_size_left,
+                                         window_size_right,
+                                         softcap,
+                                         return_softmax_lse && dropout_p > 0.0,
+                                         empty_undefined  // gen
+  );
+
+  auto out_tensor = std::get<0>(outputs);
+  auto softmax_lse = std::get<4>(outputs);
+  return std::make_tuple(out_tensor, softmax_lse);
+}
+
+}  // namespace flag_gems
diff --git a/src/flag_gems/csrc/cstub.cpp b/src/flag_gems/csrc/cstub.cpp
index 724778535..f38880e41 100644
--- a/src/flag_gems/csrc/cstub.cpp
+++ b/src/flag_gems/csrc/cstub.cpp
@@ -70,6 +70,18 @@ TORCH_LIBRARY(flag_gems, m) {
       "reshape_and_cache_flash(Tensor key, Tensor value, Tensor(a!) key_cache, Tensor(b!) value_cache, "
       "Tensor slot_mapping, str kv_cache_dtype, Tensor? k_scale=None, Tensor? v_scale=None) -> "
       "()");
+  m.def(
+      "flash_attn_varlen_func(Tensor q, Tensor k, Tensor v, SymInt max_seqlen_q, Tensor cu_seqlens_q, SymInt "
+      "max_seqlen_k, "
+      "Tensor? cu_seqlens_k=None, Tensor? seqused_k=None, Tensor? q_v=None, float dropout_p=0.0, float? "
+      "softmax_scale=None, "
+      "bool causal=False, SymInt window_size_left=-1, SymInt window_size_right=-1, float softcap=0.0, "
+      "Tensor? alibi_slopes=None, "
+      "bool deterministic=False, bool return_attn_probs=False, Tensor? block_table=None, bool "
+      "return_softmax_lse=False, "
+      "Tensor? out=None, Tensor? scheduler_metadata=None, float? q_descale=None, float? k_descale=None, "
+      "float? v_descale=None, "
+      "SymInt num_splits=0, SymInt fa_version=2) -> (Tensor, Tensor)");
 }
 
 TORCH_LIBRARY_IMPL(flag_gems, CUDA, m) {
@@ -107,5 +119,6 @@ TORCH_LIBRARY_IMPL(flag_gems, CUDA, m) {
   m.impl("softmax", TORCH_FN(softmax));
   m.impl("softmax_backward", TORCH_FN(softmax_backward));
   m.impl("reshape_and_cache_flash", TORCH_FN(reshape_and_cache_flash));
+  m.impl("flash_attn_varlen_func", TORCH_FN(flash_attn_varlen_func));
 }
 }  // namespace flag_gems
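
Reviewer note (not part of the patch): a minimal C++ usage sketch of the new wrapper, assuming the operators target from lib/CMakeLists.txt is linked, a CUDA device is present, and the Triton kernels are installed. Shapes and values are illustrative only and follow the varlen / paged-KV layout exercised by the tests above (q: [total_q_tokens, num_heads, head_size]; k/v caches: [num_blocks, block_size, num_kv_heads, head_size]).

// Sketch only: single sequence, fp16, causal attention over a paged KV cache.
#include <cmath>
#include <iostream>
#include <optional>
#include "flag_gems/operators.h"
#include "torch/torch.h"

int main() {
  const torch::Device device(torch::kCUDA, 0);
  const auto opts = torch::TensorOptions().dtype(torch::kHalf).device(device);
  const auto iopts = torch::TensorOptions().dtype(torch::kInt32).device(device);

  const int64_t num_heads = 8, num_kv_heads = 2, head_size = 128;
  const int64_t block_size = 32, num_blocks = 128;
  const int64_t q_len = 5, kv_len = 64;  // one sequence for brevity

  at::Tensor q = torch::randn({q_len, num_heads, head_size}, opts);
  at::Tensor k_cache = torch::randn({num_blocks, block_size, num_kv_heads, head_size}, opts);
  at::Tensor v_cache = torch::randn_like(k_cache);

  // cu_seqlens_q has batch_size + 1 entries; seqused_k holds per-sequence KV lengths.
  at::Tensor cu_seqlens_q = torch::tensor({0, static_cast<int>(q_len)}, iopts);
  at::Tensor seqused_k = torch::tensor({static_cast<int>(kv_len)}, iopts);
  const int64_t blocks_per_seq = (kv_len + block_size - 1) / block_size;
  at::Tensor block_table = torch::randint(0, num_blocks, {1, blocks_per_seq}, iopts);

  auto [out, lse] = flag_gems::flash_attn_varlen_func(
      q, k_cache, v_cache,
      /*max_seqlen_q=*/q_len, cu_seqlens_q, /*max_seqlen_k=*/kv_len,
      /*cu_seqlens_k=*/std::nullopt, seqused_k, /*q_v=*/std::nullopt,
      /*dropout_p=*/0.0, /*softmax_scale=*/1.0 / std::sqrt(static_cast<double>(head_size)),
      /*causal=*/true, /*window_size_left=*/-1, /*window_size_right=*/-1,
      /*softcap=*/0.0, /*alibi_slopes=*/std::nullopt,
      /*deterministic=*/false, /*return_attn_probs=*/false,
      block_table, /*return_softmax_lse=*/false);

  std::cout << out.sizes() << std::endl;  // expected: [5, 8, 128]
  return 0;
}

The call pattern mirrors the ctests: seqused_k is passed instead of cu_seqlens_k because a block_table is supplied, which is exactly the combination the wrapper's TORCH_CHECKs require.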