[AdvancedCompiler]flash_varlen_fwd_kernel(cpp wrapper) #973
Conversation
Summary of Changes
Hello @AdvancedCompiler, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a C++ wrapper for `flash_varlen_fwd_kernel`, a Flash Attention implementation optimized for variable-length sequences. The wrapper provides an efficient attention mechanism for models that batch sequences of different lengths, and it is accompanied by tests that check functional correctness and parity with reference implementations.
Highlights
- New Flash Attention Operator: Introduces a C++ wrapper for `flash_varlen_fwd_kernel`, a Flash Attention implementation for variable-length sequences, enhancing efficiency for deep learning models.
- Comprehensive Testing: Adds a new C++ test suite (`test_triton_flash_attn_varlen.cpp`) with reference implementations to validate the new operator across diverse configurations, including different head counts, head sizes, data types, ALiBi, softcap, and block sizes.
- Parameter Struct for Forward Pass: Defines a `FlashFwdParams` struct in `operators.h` to centralize and manage parameters for the Flash Attention forward pass, supporting features like paged KV cache, softcap, dropout, causal attention, sliding window, and ALiBi.
- Triton JIT Integration: The new `flash_attn_varlen_func.cpp` implements the operator by preparing parameters and invoking `flash_varlen_fwd_kernel` via Triton JIT, leveraging optimized kernels for GPU execution.
- PyTorch ATen Registration: `flash_attn_varlen_func` is registered with the PyTorch ATen library in `cstub.cpp`, making it callable from Python within the PyTorch ecosystem (see the hedged registration sketch after this list).
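As a point of reference for that last bullet, here is a hedged sketch of what an ATen registration can look like. The namespace `my_ops`, the trimmed argument list, and the stub body are assumptions for illustration, not the PR's actual `cstub.cpp`; the real operator takes many more parameters (dropout, causal flag, window sizes, ALiBi slopes, `block_table`, and so on).

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Illustrative stub: the real implementation fills FlashFwdParams and launches
// flash_varlen_fwd_kernel via Triton JIT instead of returning an empty tensor.
at::Tensor flash_attn_varlen_func_stub(const at::Tensor& q,
                                       const at::Tensor& k,
                                       const at::Tensor& v,
                                       const at::Tensor& cu_seqlens_q,
                                       const at::Tensor& cu_seqlens_k,
                                       int64_t max_seqlen_q,
                                       int64_t max_seqlen_k) {
  return at::empty_like(q);
}

// Declare the schema once, then bind the CUDA implementation to it.
TORCH_LIBRARY_FRAGMENT(my_ops, m) {
  m.def(
      "flash_attn_varlen_func(Tensor q, Tensor k, Tensor v, "
      "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
      "int max_seqlen_q, int max_seqlen_k) -> Tensor");
}

TORCH_LIBRARY_IMPL(my_ops, CUDA, m) {
  m.impl("flash_attn_varlen_func", flash_attn_varlen_func_stub);
}
```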
Code Review
This pull request introduces a C++ wrapper for the `flash_varlen_fwd_kernel` Triton kernel, enabling paged attention functionality. The changes are comprehensive, including the core operator implementation, PyTorch bindings, and thorough unit tests. The overall implementation is solid, but I've identified a few areas for improvement related to code clarity, consistency, and robustness. My review includes suggestions to remove unused code, clarify variable names, ensure variables are always initialized, and improve naming consistency. I've also noted the high tolerances in the tests, which might be worth investigating to ensure numerical correctness.
```cpp
at::Tensor lse;
at::Tensor philox_args;
at::Tensor p;
at::Tensor unused; // optional, may remain undefined when not used
```
The `unused` tensor is returned as part of the tuple from `mha_varlan_fwd_internal`, but it is only initialized if `seqlenq_ngroups_swapped` is true. If it's false, an uninitialized tensor will be returned, which can lead to undefined behavior. It should be initialized to a default-constructed (undefined) tensor:

```cpp
at::Tensor unused = at::Tensor(); // optional, may remain undefined when not used
```
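For additional context, a hedged caller-side sketch (the names are illustrative, not the PR's actual code): a default-constructed `at::Tensor` is "undefined" but safe to inspect, so callers can branch on `.defined()` instead of touching uninitialized state.

```cpp
#include <ATen/ATen.h>

// Illustrative helper: consumers of the optional output can test .defined()
// and only use the tensor when the internal path actually produced it.
void consume_optional_output(const at::Tensor& unused) {
  if (unused.defined()) {
    // Only reached when the internal path (e.g. seqlenq_ngroups_swapped)
    // actually filled this tensor.
  }
}
```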
```cpp
sliding_window_opt,
soft_cap_opt);

EXPECT_TRUE(torch::allclose(op_output, ref_output, /*rtol=*/1e-2, /*atol=*/2e-2));
```
The relative and absolute tolerances used for comparing the operator output with the reference implementation (rtol=1e-2, atol=2e-2) are quite high. This might hide potential numerical discrepancies. Could you investigate whether these tolerances can be tightened? If these values are necessary due to the nature of the computation (e.g., using half-precision floats), it would be beneficial to add a comment explaining why such high tolerances are required.
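If the tolerances do turn out to be dtype-dependent, one hedged way to make that explicit in the test is a small helper like the sketch below. This is not the PR's test code: `dtype`, `op_output`, and `ref_output` stand in for the test's existing values, and the float32 thresholds are illustrative assumptions.

```cpp
#include <gtest/gtest.h>
#include <torch/torch.h>

// Half/bfloat16 accumulate noticeably more rounding error than float32, so the
// comparison is loosened only for reduced-precision runs, and the reason is
// documented where the thresholds are chosen.
void expect_close(const at::Tensor& op_output,
                  const at::Tensor& ref_output,
                  at::ScalarType dtype) {
  const double rtol = (dtype == at::kFloat) ? 1e-4 : 1e-2;
  const double atol = (dtype == at::kFloat) ? 1e-5 : 2e-2;
  EXPECT_TRUE(torch::allclose(op_output, ref_output, rtol, atol))
      << "flash_attn_varlen output diverges from the reference implementation";
}
```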
```cpp
double dropout_p = 0.0,
const std::optional<double> &softmax_scale = std::nullopt,
bool causal = false,
// CHANGE: window_size is now two separate integers
```
```cpp
const std::optional<at::Tensor> &alibi_slopes = std::nullopt,
bool deterministic = false,
bool return_attn_probs = false,
const std::optional<at::Tensor> &block_table = std::nullopt,
```
The parameter is named `block_table` here and in the PyTorch library definition, but it's referred to as `page_table` in the internal implementation (`mha_varlan_fwd_internal` and `FlashFwdParams`). For consistency and clarity, it would be better to use the same name throughout. `page_table` is arguably more descriptive in the context of paged attention. I recommend renaming `block_table` to `page_table` in the public API. This change should be propagated to `cstub.cpp`, `lib/flash_attn_varlen_func.cpp`, and `ctests/test_triton_flash_attn_varlen.cpp` for consistency.
```diff
-const std::optional<at::Tensor> &block_table = std::nullopt,
+const std::optional<at::Tensor> &page_table = std::nullopt,
```
```cpp
mha_varlan_fwd_internal(const at::Tensor& q,
                        const at::Tensor& k,
                        const at::Tensor& v,
                        const at::Tensor& out,
                        const at::Tensor& cu_seqlens_q,
                        const at::Tensor& cu_seqlens_k,
                        const at::Tensor& seqused_k,
                        const at::Tensor& leftpad_k,
                        const at::Tensor& page_table,
                        const at::Tensor& alibi_slopes,
                        int64_t max_seqlen_q,
                        int64_t max_seqlen_k,
                        double p_dropout,
                        double softmax_scale,
                        bool zero_tensors,
                        bool is_causal,
                        int64_t window_size_left,
                        int64_t window_size_right,
                        double softcap,
                        bool return_softmax,
                        const at::Tensor& gen) {
```
The parameters `leftpad_k` and `gen` in `mha_varlan_fwd_internal` are not used: `leftpad_k` is only checked to be undefined, and `gen` is ignored because `philox_backend_seed_offset` is called with `c10::nullopt`. These parameters should be removed from the function signature and call sites to simplify the code:
```cpp
mha_varlan_fwd_internal(const at::Tensor& q,
                        const at::Tensor& k,
                        const at::Tensor& v,
                        const at::Tensor& out,
                        const at::Tensor& cu_seqlens_q,
                        const at::Tensor& cu_seqlens_k,
                        const at::Tensor& seqused_k,
                        const at::Tensor& page_table,
                        const at::Tensor& alibi_slopes,
                        int64_t max_seqlen_q,
                        int64_t max_seqlen_k,
                        double p_dropout,
                        double softmax_scale,
                        bool zero_tensors,
                        bool is_causal,
                        int64_t window_size_left,
                        int64_t window_size_right,
                        double softcap,
                        bool return_softmax) {
```
```cpp
p_dropout = 1.0 - p_dropout;
int64_t p_dropout_in_uint8_t = static_cast<int64_t>(std::floor(p_dropout * 255.0));
double rp_dropout = 1.0 / p_dropout;
```
The variable `p_dropout` is reassigned to represent the keep probability (`1.0 - p_dropout`). This is confusing and makes the code harder to follow. It would be clearer to introduce a new variable, for example `p_keep`, for the keep probability, and use it consistently. You will also need to update `params.p_dropout` to be assigned from this new variable.

```cpp
double p_keep = 1.0 - p_dropout;
int64_t p_dropout_in_uint8_t = static_cast<int64_t>(std::floor(p_keep * 255.0));
double rp_dropout = 1.0 / p_keep;
```
Please update `params.p_dropout` accordingly.
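A minimal sketch of that follow-up assignment, assuming `params` is the `FlashFwdParams` instance populated in this function; only the `p_dropout` field is confirmed by the review, and the trimmed struct below is a stand-in for illustration:

```cpp
// Illustrative only: a trimmed stand-in for the PR's FlashFwdParams
// (the real struct has many more fields).
struct FlashFwdParamsSketch {
  double p_dropout;
};

void fill_dropout_fields(FlashFwdParamsSketch& params, double dropout_p) {
  const double p_keep = 1.0 - dropout_p;  // keep probability
  // Assign the keep probability, matching the renamed local suggested above.
  params.p_dropout = p_keep;
}
```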
PR Category: Operator
Type of Change: New Feature
Description: C++ wrapper packaging for `flash_varlen_fwd_kernel`
Issue
Progress
Performance