5 changes: 5 additions & 0 deletions ctests/CMakeLists.txt
@@ -76,3 +76,8 @@ add_executable(test_triton_reshape_and_cache_flash test_triton_reshape_and_cache
target_link_libraries(test_triton_reshape_and_cache_flash
PRIVATE Torch::Torch operators GTest::gtest GTest::gtest_main)
add_test(NAME test_triton_reshape_and_cache_flash COMMAND test_triton_reshape_and_cache_flash)

add_executable(test_triton_flash_attn_varlen test_triton_flash_attn_varlen.cpp)
target_link_libraries(test_triton_flash_attn_varlen
PRIVATE Torch::Torch operators GTest::gtest GTest::gtest_main)
add_test(NAME test_triton_flash_attn_varlen COMMAND test_triton_flash_attn_varlen)
388 changes: 388 additions & 0 deletions ctests/test_triton_flash_attn_varlen.cpp

Large diffs are not rendered by default.

100 changes: 100 additions & 0 deletions include/flag_gems/operators.h
@@ -87,4 +87,104 @@ void reshape_and_cache_flash(const at::Tensor &key,
const std::string &kv_cache_dtype,
const std::optional<at::Tensor> &k_scale,
const std::optional<at::Tensor> &v_scale);

std::tuple<at::Tensor, at::Tensor> flash_attn_varlen_func(
const at::Tensor &q,
const at::Tensor &k,
const at::Tensor &v,
int64_t max_seqlen_q,
const at::Tensor &cu_seqlens_q,
int64_t max_seqlen_k,
const std::optional<at::Tensor> &cu_seqlens_k = std::nullopt,
const std::optional<at::Tensor> &seqused_k = std::nullopt,
const std::optional<at::Tensor> &q_v = std::nullopt,
double dropout_p = 0.0,
const std::optional<double> &softmax_scale = std::nullopt,
bool causal = false,
// CHANGE: window_size is now two separate integers
Contributor review comment (medium):

This comment appears to be a leftover from development and does not serve as documentation. It should be removed to keep the header file clean.

Suggested change
// CHANGE: window_size is now two separate integers
// window_size is now two separate integers

int64_t window_size_left = -1,
int64_t window_size_right = -1,
double softcap = 0.0,
const std::optional<at::Tensor> &alibi_slopes = std::nullopt,
bool deterministic = false,
bool return_attn_probs = false,
const std::optional<at::Tensor> &block_table = std::nullopt,
Contributor review comment (medium):

The parameter is named block_table here and in the PyTorch library definition, but it's referred to as page_table in the internal implementation (mha_varlan_fwd_internal and FlashFwdParams). For consistency and clarity, it would be better to use the same name throughout. page_table is arguably more descriptive in the context of paged attention. I recommend renaming block_table to page_table in the public API. This change should be propagated to cstub.cpp, lib/flash_attn_varlen_func.cpp, and ctests/test_triton_flash_attn_varlen.cpp for consistency.

Suggested change
const std::optional<at::Tensor> &block_table = std::nullopt,
const std::optional<at::Tensor> &page_table = std::nullopt,

bool return_softmax_lse = false,
const std::optional<at::Tensor> &out = std::nullopt,
const std::optional<at::Tensor> &scheduler_metadata = std::nullopt,
const std::optional<double> &q_descale = std::nullopt,
const std::optional<double> &k_descale = std::nullopt,
const std::optional<double> &v_descale = std::nullopt,
int64_t num_splits = 0,
int64_t fa_version = 2);
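
For orientation, a minimal call sketch follows. It assumes the usual varlen flash-attention packing convention (q, k, v packed as [total_tokens, num_heads, head_dim], with cu_seqlens_q / cu_seqlens_k as int32 prefix sums starting at 0); the exact layout and dtype requirements of this implementation are not shown in the diff, so treat it as illustrative rather than authoritative.

// Illustrative only: two variable-length sequences packed into one batch,
// called with the required arguments plus causal masking.
#include <torch/torch.h>
#include "flag_gems/operators.h"

std::tuple<at::Tensor, at::Tensor> varlen_call_example() {
  const int64_t num_heads = 8, head_dim = 64;
  const int64_t total_tokens = 5 + 11;  // two sequences of length 5 and 11
  auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);

  at::Tensor q = torch::randn({total_tokens, num_heads, head_dim}, opts);
  at::Tensor k = torch::randn({total_tokens, num_heads, head_dim}, opts);
  at::Tensor v = torch::randn({total_tokens, num_heads, head_dim}, opts);

  // Prefix sums of per-sequence lengths: {0, 5, 16}.
  at::Tensor cu_seqlens = torch::tensor({0, 5, 16},
      torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA));

  // Every other parameter keeps its declared default (softmax_scale unset
  // presumably means 1/sqrt(head_dim)).
  return flag_gems::flash_attn_varlen_func(
      q, k, v,
      /*max_seqlen_q=*/11, cu_seqlens,
      /*max_seqlen_k=*/11, cu_seqlens,
      /*seqused_k=*/std::nullopt, /*q_v=*/std::nullopt,
      /*dropout_p=*/0.0, /*softmax_scale=*/std::nullopt,
      /*causal=*/true);
}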

struct FlashFwdParams {
// tensor pointers
at::Tensor q;
at::Tensor k;
at::Tensor v;
at::Tensor out;
at::Tensor p;
at::Tensor lse;
// strides
int64_t q_row_stride;
int64_t k_row_stride;
int64_t v_row_stride;
int64_t q_head_stride;
int64_t k_head_stride;
int64_t v_head_stride;
int64_t o_row_stride;
int64_t o_head_stride;
// batch strides
int64_t q_batch_stride;
int64_t k_batch_stride;
int64_t v_batch_stride;
int64_t o_batch_stride;
// cu_seqlens / seqused_k flags & tensors
bool is_cu_seqlens_q;
at::Tensor cu_seqlens_q;
bool is_cu_seqlens_k;
at::Tensor cu_seqlens_k;
bool is_seqused_k;
at::Tensor seqused_k;
// sizes
int64_t batch_size;
int64_t k_batch_size;
int64_t num_heads;
int64_t num_heads_k;
int64_t h_hk_ratio;
int64_t seqlen_q;
int64_t seqlen_k;
int64_t seqlen_q_rounded;
int64_t seqlen_k_rounded;
int64_t head_size;
int64_t head_size_rounded;
// scaling factors
bool is_softcap;
double softcap;
double scale_softmax;
double scale_softmax_log2e;
// dropout
bool is_dropout;
double p_dropout;
double rp_dropout;
int64_t p_dropout_in_uint8_t;
at::Tensor philox_args;
bool return_softmax;
// causal & sliding window attention
bool is_causal;
bool is_local;
int64_t window_size_left;
int64_t window_size_right;
bool seqlenq_ngroups_swapped;
// alibi
bool is_alibi;
at::Tensor alibi_slopes;
int64_t alibi_slopes_batch_stride;
// block table params
int64_t total_q;
at::Tensor page_table;
int64_t page_table_batch_stride;
int64_t block_size;
};
} // namespace flag_gems
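
FlashFwdParams mirrors the arguments of the public entry point. The least obvious fields are the paged-KV ones (page_table maps each sequence's logical KV block index to a physical block of block_size tokens in the cache, which is the naming question raised in the block_table comment above) and the precomputed scaling factors. The sketch below shows how a few of these fields could plausibly be derived; the actual population logic lives in lib/flash_attn_varlen_func.cpp and is not reproduced here, so the derivations are assumptions and only the field names come from the struct.

// Hypothetical helper, not the library's actual code: filling a handful of
// FlashFwdParams fields from the packed varlen inputs.
#include <cmath>
#include <optional>
#include "flag_gems/operators.h"

flag_gems::FlashFwdParams make_params_sketch(const at::Tensor &q,
                                             const at::Tensor &k,
                                             const at::Tensor &v,
                                             std::optional<double> softmax_scale) {
  flag_gems::FlashFwdParams p{};
  p.q = q;
  p.k = k;
  p.v = v;

  // Packed varlen layout assumed: [total_tokens, num_heads, head_dim].
  p.total_q = q.size(0);
  p.num_heads = q.size(1);
  p.num_heads_k = k.size(1);
  p.h_hk_ratio = p.num_heads / p.num_heads_k;  // grouped-query head ratio
  p.head_size = q.size(2);

  // Strides come straight from the tensors.
  p.q_row_stride = q.stride(0);
  p.q_head_stride = q.stride(1);
  p.k_row_stride = k.stride(0);
  p.k_head_stride = k.stride(1);
  p.v_row_stride = v.stride(0);
  p.v_head_stride = v.stride(1);

  // Softmax scale defaults to 1/sqrt(head_dim); the log2(e)-scaled copy lets
  // kernels use exp2 instead of exp.
  p.scale_softmax =
      softmax_scale.value_or(1.0 / std::sqrt(static_cast<double>(p.head_size)));
  p.scale_softmax_log2e = p.scale_softmax * 1.4426950408889634;  // log2(e)
  return p;
}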
9 changes: 9 additions & 0 deletions include/flag_gems/utils.h
@@ -1,4 +1,6 @@
#pragma once
#include <ATen/Generator.h>
#include <c10/util/Optional.h>
#include <dlfcn.h> // dladdr
#include <array>
#include <cstdio>
@@ -8,6 +10,7 @@
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include "torch/torch.h"

namespace flag_gems::utils {
@@ -24,3 +27,9 @@ std::tuple<at::Tensor, int64_t, int64_t> permute_reduction_axes_right(
std::tuple<int64_t, int64_t, int64_t> parse_reduction_axes(const at::Tensor &tensor, int reduction_axis);
int cdiv(int a, int b);
} // namespace flag_gems::utils

// Public declaration for reuse (implementation in lib/exponential_.cpp)
namespace flag_gems {
std::pair<uint64_t, uint64_t> philox_backend_seed_offset(int64_t increment,
c10::optional<at::Generator> generator_opt);
}
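
philox_backend_seed_offset is the hook a dropout-capable kernel uses to reserve a slice of the Philox counter space from the backend generator (implemented in lib/exponential_.cpp and reused here for attention dropout). A small usage sketch follows; the increment-sizing heuristic and the helper name are illustrative assumptions, not taken from this PR.

// Sketch: reserve enough Philox counter values for a dropout pass and hand the
// resulting (seed, offset) pair to a kernel. The per-element consumption rate
// is an assumption for illustration.
#include <c10/util/Optional.h>
#include "flag_gems/utils.h"

std::pair<uint64_t, uint64_t> seed_for_dropout(int64_t total_elements) {
  const int64_t values_per_element = 1;  // assumed RNG consumption per element
  const int64_t increment = total_elements * values_per_element;
  // c10::nullopt selects the default generator for the current device.
  return flag_gems::philox_backend_seed_offset(increment, c10::nullopt);
}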
3 changes: 2 additions & 1 deletion lib/CMakeLists.txt
@@ -21,7 +21,8 @@ add_library(operators
softmax.cpp
exponential_.cpp
fill.cpp
reshape_and_cache_flash.cpp)
reshape_and_cache_flash.cpp
flash_attn_varlen_func.cpp)
target_include_directories(operators
PUBLIC
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/include>