[AdvancedCompiler]flash_varlen_fwd_kernel(cpp wrapper) #973
base: master
@@ -87,4 +87,104 @@ void reshape_and_cache_flash(const at::Tensor &key,
                             const std::string &kv_cache_dtype,
                             const std::optional<at::Tensor> &k_scale,
                             const std::optional<at::Tensor> &v_scale);

std::tuple<at::Tensor, at::Tensor> flash_attn_varlen_func(
    const at::Tensor &q,
    const at::Tensor &k,
    const at::Tensor &v,
    int64_t max_seqlen_q,
    const at::Tensor &cu_seqlens_q,
    int64_t max_seqlen_k,
    const std::optional<at::Tensor> &cu_seqlens_k = std::nullopt,
    const std::optional<at::Tensor> &seqused_k = std::nullopt,
    const std::optional<at::Tensor> &q_v = std::nullopt,
    double dropout_p = 0.0,
    const std::optional<double> &softmax_scale = std::nullopt,
    bool causal = false,
    // CHANGE: window_size is now two separate integers
    int64_t window_size_left = -1,
    int64_t window_size_right = -1,
    double softcap = 0.0,
    const std::optional<at::Tensor> &alibi_slopes = std::nullopt,
    bool deterministic = false,
    bool return_attn_probs = false,
    const std::optional<at::Tensor> &block_table = std::nullopt,
Review comment: The parameter is named … (a suggested change was attached).
    bool return_softmax_lse = false,
    const std::optional<at::Tensor> &out = std::nullopt,
    const std::optional<at::Tensor> &scheduler_metadata = std::nullopt,
    const std::optional<double> &q_descale = std::nullopt,
    const std::optional<double> &k_descale = std::nullopt,
    const std::optional<double> &v_descale = std::nullopt,
    int64_t num_splits = 0,
    int64_t fa_version = 2);

struct FlashFwdParams {
  // tensor pointers
  at::Tensor q;
  at::Tensor k;
  at::Tensor v;
  at::Tensor out;
  at::Tensor p;
  at::Tensor lse;
  // strides
  int64_t q_row_stride;
  int64_t k_row_stride;
  int64_t v_row_stride;
  int64_t q_head_stride;
  int64_t k_head_stride;
  int64_t v_head_stride;
  int64_t o_row_stride;
  int64_t o_head_stride;
  // batch strides
  int64_t q_batch_stride;
  int64_t k_batch_stride;
  int64_t v_batch_stride;
  int64_t o_batch_stride;
  // cu_seqlens / seqused_k flags & tensors
  bool is_cu_seqlens_q;
  at::Tensor cu_seqlens_q;
  bool is_cu_seqlens_k;
  at::Tensor cu_seqlens_k;
  bool is_seqused_k;
  at::Tensor seqused_k;
  // sizes
  int64_t batch_size;
  int64_t k_batch_size;
  int64_t num_heads;
  int64_t num_heads_k;
  int64_t h_hk_ratio;
  int64_t seqlen_q;
  int64_t seqlen_k;
  int64_t seqlen_q_rounded;
  int64_t seqlen_k_rounded;
  int64_t head_size;
  int64_t head_size_rounded;
  // scaling factors
  bool is_softcap;
  double softcap;
  double scale_softmax;
  double scale_softmax_log2e;
  // dropout
  bool is_dropout;
  double p_dropout;
  double rp_dropout;
  int64_t p_dropout_in_uint8_t;
  at::Tensor philox_args;
  bool return_softmax;
  // causal & sliding window attention
  bool is_causal;
  bool is_local;
  int64_t window_size_left;
  int64_t window_size_right;
  bool seqlenq_ngroups_swapped;
  // alibi
  bool is_alibi;
  at::Tensor alibi_slopes;
  int64_t alibi_slopes_batch_stride;
  // block table params
  int64_t total_q;
  at::Tensor page_table;
  int64_t page_table_batch_stride;
  int64_t block_size;
};
}  // namespace flag_gems
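
As context for the `FlashFwdParams` struct above, here is a minimal sketch of how its derived scalar fields could be filled in on the wrapper side. The helper name `set_fwd_params`, the include path, the 32/128 rounding granularities, and the local-window rule are illustrative assumptions following common FlashAttention conventions; they are not part of this diff.

```cpp
#include <cmath>
#include <optional>
#include <ATen/ATen.h>
#include "flag_gems/ops.h"  // illustrative path for the header shown in this diff

namespace flag_gems {

// Hypothetical helper (not part of this PR): derive the scalar fields of
// FlashFwdParams from the arguments of flash_attn_varlen_func.
inline void set_fwd_params(FlashFwdParams &params,
                           const at::Tensor &q,   // [total_q, num_heads, head_size]
                           const at::Tensor &k,   // [total_k, num_heads_k, head_size]
                           int64_t max_seqlen_q,
                           int64_t max_seqlen_k,
                           const std::optional<double> &softmax_scale,
                           int64_t window_size_left,
                           int64_t window_size_right,
                           bool causal) {
  params.total_q     = q.size(0);
  params.num_heads   = q.size(1);
  params.num_heads_k = k.size(1);
  params.h_hk_ratio  = params.num_heads / params.num_heads_k;  // GQA/MQA group size
  params.head_size   = q.size(2);
  params.seqlen_q    = max_seqlen_q;
  params.seqlen_k    = max_seqlen_k;

  // Assumed rounding convention: pad head size and sequence lengths to tile multiples.
  auto round_to = [](int64_t x, int64_t m) { return (x + m - 1) / m * m; };
  params.head_size_rounded = round_to(params.head_size, 32);
  params.seqlen_q_rounded  = round_to(params.seqlen_q, 128);
  params.seqlen_k_rounded  = round_to(params.seqlen_k, 128);

  // Softmax scaling: fall back to 1/sqrt(head_size) when no scale is supplied.
  params.scale_softmax = softmax_scale.has_value()
                             ? *softmax_scale
                             : 1.0 / std::sqrt(static_cast<double>(params.head_size));
  // Pre-multiplied by log2(e) so a kernel can use exp2 instead of exp.
  params.scale_softmax_log2e = params.scale_softmax * 1.4426950408889634;

  // Sliding-window ("local") attention is flagged only when a window is given.
  params.is_causal = causal;
  params.is_local  = (window_size_left >= 0 || window_size_right >= 0) && !causal;
  params.window_size_left  = window_size_left;
  params.window_size_right = window_size_right;
}

}  // namespace flag_gems
```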
Review comment: This comment appears to be a leftover from development and does not serve as documentation. It should be removed to keep the header file clean.
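
Finally, a hedged caller-side example of the new entry point with packed variable-length sequences. The include path, tensor shapes, and the cumulative-sequence-length layout (`cu_seqlens` as an exclusive prefix sum in int32) follow the usual FlashAttention varlen convention and are assumptions here, not something this header itself enforces.

```cpp
#include <optional>
#include <vector>
#include <ATen/ATen.h>
#include "flag_gems/ops.h"  // illustrative path for the header shown in this diff

// Two sequences of lengths 3 and 5, packed along dim 0 (total tokens = 8).
at::Tensor run_packed_attention() {
  const int64_t num_heads = 8, head_size = 64, total_tokens = 8;

  auto opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  at::Tensor q = at::randn({total_tokens, num_heads, head_size}, opts);
  at::Tensor k = at::randn({total_tokens, num_heads, head_size}, opts);
  at::Tensor v = at::randn({total_tokens, num_heads, head_size}, opts);

  // cu_seqlens holds the exclusive prefix sum of per-sequence lengths: [0, 3, 8].
  std::vector<int32_t> lens = {0, 3, 8};
  at::Tensor cu_seqlens = at::from_blob(lens.data(), {3},
                                        at::TensorOptions().dtype(at::kInt))
                              .clone()
                              .to(at::kCUDA);

  auto [out, lse] = flag_gems::flash_attn_varlen_func(
      q, k, v,
      /*max_seqlen_q=*/5, cu_seqlens,
      /*max_seqlen_k=*/5, /*cu_seqlens_k=*/cu_seqlens,
      /*seqused_k=*/std::nullopt,
      /*q_v=*/std::nullopt,
      /*dropout_p=*/0.0,
      /*softmax_scale=*/std::nullopt,
      /*causal=*/true);
  (void)lse;  // second tensor is typically the softmax log-sum-exp (see return_softmax_lse)
  return out;
}
```

Defaults cover the remaining parameters (softcap, ALiBi slopes, block table, `fa_version`, and so on), so a basic varlen call only needs the packed tensors, the cumulative sequence lengths, and the maximum sequence lengths.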