
Conversation

AdvancedCompiler (Contributor)

PR Category

Operator

Type of Change

New Feature

Description

C++ wrapper packaging for flash_varlen_fwd_kernel

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance


@gemini-code-assist (bot) left a comment


Summary of Changes

Hello @AdvancedCompiler, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant new feature by integrating a C++ wrapper for the flash_varlen_fwd_kernel, a highly optimized Flash Attention implementation designed for variable-length sequences. This enhancement provides a robust and efficient attention mechanism, crucial for modern deep learning models, and is accompanied by thorough testing to ensure its functional correctness and performance parity with reference implementations.

Highlights

  • New Flash Attention Operator: Introduces a C++ wrapper for flash_varlen_fwd_kernel, a Flash Attention implementation for variable-length sequences, enhancing efficiency for deep learning models.
  • Comprehensive Testing: Adds a new C++ test suite (test_triton_flash_attn_varlen.cpp) with reference implementations to validate the new operator across diverse configurations, including different head counts, head sizes, data types, ALiBi, softcap, and block sizes.
  • Parameter Struct for Forward Pass: Defines a FlashFwdParams struct in operators.h to centralize and manage parameters for the Flash Attention forward pass, supporting features like paged KV cache, softcap, dropout, causal attention, sliding window, and ALiBi.
  • Triton JIT Integration: The new flash_attn_varlen_func.cpp implements the operator by preparing parameters and invoking the flash_varlen_fwd_kernel via Triton JIT, leveraging optimized kernels for GPU execution.
  • PyTorch ATen Registration: The flash_attn_varlen_func is registered with the PyTorch ATen library in cstub.cpp, making it seamlessly callable from Python within the PyTorch ecosystem.
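
For readers new to this pattern, here is a minimal, hedged sketch of how a parameter struct and an ATen registration of the kind described above typically fit together. The field names, the flag_gems namespace, and the operator schema below are illustrative assumptions, not the PR's exact code.

// Illustrative sketch only: field names, namespace, and schema are assumed, not taken from the PR.
#include <ATen/ATen.h>
#include <torch/library.h>
#include <optional>

// A FlashFwdParams-style struct centralizing the forward-pass arguments
// (dropout, causal masking, sliding window, softcap, ALiBi, paged KV cache).
struct FlashFwdParamsSketch {
  at::Tensor q, k, v, out;                  // packed variable-length Q/K/V and output
  at::Tensor cu_seqlens_q, cu_seqlens_k;    // cumulative sequence lengths per batch
  std::optional<at::Tensor> page_table;     // paged KV cache mapping (optional)
  std::optional<at::Tensor> alibi_slopes;   // ALiBi slopes (optional)
  double softmax_scale = 0.0;
  double p_dropout = 0.0;
  double softcap = 0.0;
  bool is_causal = false;
  int64_t window_size_left = -1;            // -1 disables the sliding window on that side
  int64_t window_size_right = -1;
};

// Registration pattern comparable in spirit to what cstub.cpp does, making the
// operator callable from Python as torch.ops.<namespace>.flash_attn_varlen_func(...).
// The namespace and the reduced argument list here are hypothetical.
TORCH_LIBRARY_FRAGMENT(flag_gems, m) {
  m.def("flash_attn_varlen_func(Tensor q, Tensor k, Tensor v, "
        "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
        "int max_seqlen_q, int max_seqlen_k) -> Tensor");
}

An m.def alone only declares the schema; the PR additionally binds the C++ implementation so that the Python call dispatches to the Triton-backed kernel.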


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a C++ wrapper for the flash_varlen_fwd_kernel Triton kernel, enabling paged attention functionality. The changes are comprehensive, including the core operator implementation, PyTorch bindings, and thorough unit tests. The overall implementation is solid, but I've identified a few areas for improvement related to code clarity, consistency, and robustness. My review includes suggestions to remove unused code, clarify variable names, ensure variables are always initialized, and improve naming consistency. I've also noted the high tolerance in tests, which might be worth investigating to ensure numerical correctness.

at::Tensor lse;
at::Tensor philox_args;
at::Tensor p;
at::Tensor unused; // optional, may remain undefined when not used

Severity: high

The unused tensor is returned as part of the tuple from mha_varlan_fwd_internal, but it is only initialized if seqlenq_ngroups_swapped is true. If it's false, an uninitialized tensor will be returned, which can lead to undefined behavior. It should be initialized to a default-constructed (undefined) tensor.

  at::Tensor unused = at::Tensor();  // optional, may remain undefined when not used

sliding_window_opt,
soft_cap_opt);

EXPECT_TRUE(torch::allclose(op_output, ref_output, /*rtol=*/1e-2, /*atol=*/2e-2));

Severity: medium

The relative and absolute tolerances (1e-2) used for comparing the operator output with the reference implementation are quite high. This might hide potential numerical discrepancies. Could you investigate if these tolerances can be tightened? If these values are necessary due to the nature of the computation (e.g., using half-precision floats), it would be beneficial to add a comment explaining why such high tolerances are required.
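
One possible way to act on this suggestion is sketched below. It assumes the op_output and ref_output tensors and the gtest/libtorch headers already used in test_triton_flash_attn_varlen.cpp, and is not the PR's actual test code.

// Sketch: fp16/bf16 inputs accumulate rounding error against the fp32 reference,
// so tolerances around 1e-2 are expected there; fp32 runs can be checked more tightly.
const bool reduced_precision =
    op_output.scalar_type() == at::kHalf || op_output.scalar_type() == at::kBFloat16;
const double rtol = reduced_precision ? 1e-2 : 1e-4;
const double atol = reduced_precision ? 2e-2 : 1e-5;
EXPECT_TRUE(torch::allclose(op_output, ref_output, rtol, atol));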

double dropout_p = 0.0,
const std::optional<double> &softmax_scale = std::nullopt,
bool causal = false,
// CHANGE: window_size is now two separate integers

Severity: medium

This comment appears to be a leftover from development and does not serve as documentation. It should be removed to keep the header file clean.

Suggested change
- // CHANGE: window_size is now two separate integers
+ // window_size is now two separate integers

const std::optional<at::Tensor> &alibi_slopes = std::nullopt,
bool deterministic = false,
bool return_attn_probs = false,
const std::optional<at::Tensor> &block_table = std::nullopt,

Severity: medium

The parameter is named block_table here and in the PyTorch library definition, but it's referred to as page_table in the internal implementation (mha_varlan_fwd_internal and FlashFwdParams). For consistency and clarity, it would be better to use the same name throughout. page_table is arguably more descriptive in the context of paged attention. I recommend renaming block_table to page_table in the public API. This change should be propagated to cstub.cpp, lib/flash_attn_varlen_func.cpp, and ctests/test_triton_flash_attn_varlen.cpp for consistency.

Suggested change
- const std::optional<at::Tensor> &block_table = std::nullopt,
+ const std::optional<at::Tensor> &page_table = std::nullopt,

Comment on lines +13 to +33
mha_varlan_fwd_internal(const at::Tensor& q,
const at::Tensor& k,
const at::Tensor& v,
const at::Tensor& out,
const at::Tensor& cu_seqlens_q,
const at::Tensor& cu_seqlens_k,
const at::Tensor& seqused_k,
const at::Tensor& leftpad_k,
const at::Tensor& page_table,
const at::Tensor& alibi_slopes,
int64_t max_seqlen_q,
int64_t max_seqlen_k,
double p_dropout,
double softmax_scale,
bool zero_tensors,
bool is_causal,
int64_t window_size_left,
int64_t window_size_right,
double softcap,
bool return_softmax,
const at::Tensor& gen) {

Severity: medium

The parameters leftpad_k and gen in mha_varlan_fwd_internal are not used. leftpad_k is only checked to be undefined, and gen is ignored as philox_backend_seed_offset is called with c10::nullopt. These parameters should be removed from the function signature and call sites to simplify the code.

mha_varlan_fwd_internal(const at::Tensor& q,
                        const at::Tensor& k,
                        const at::Tensor& v,
                        const at::Tensor& out,
                        const at::Tensor& cu_seqlens_q,
                        const at::Tensor& cu_seqlens_k,
                        const at::Tensor& seqused_k,
                        const at::Tensor& page_table,
                        const at::Tensor& alibi_slopes,
                        int64_t max_seqlen_q,
                        int64_t max_seqlen_k,
                        double p_dropout,
                        double softmax_scale,
                        bool zero_tensors,
                        bool is_causal,
                        int64_t window_size_left,
                        int64_t window_size_right,
                        double softcap,
                        bool return_softmax) {

Comment on lines +234 to +236
p_dropout = 1.0 - p_dropout;
int64_t p_dropout_in_uint8_t = static_cast<int64_t>(std::floor(p_dropout * 255.0));
double rp_dropout = 1.0 / p_dropout;

Severity: medium

The variable p_dropout is reassigned to represent the keep probability (1.0 - p_dropout). This is confusing and makes the code harder to follow. It would be clearer to introduce a new variable, for example p_keep, for the keep probability, and use it consistently. You will also need to update params.p_dropout to be assigned from this new variable.

    double p_keep = 1.0 - p_dropout;
    int64_t p_dropout_in_uint8_t = static_cast<int64_t>(std::floor(p_keep * 255.0));
    double rp_dropout = 1.0 / p_keep;
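
For completeness, the downstream assignment mentioned above would then read roughly as follows, assuming the FlashFwdParams field is indeed named p_dropout:

    params.p_dropout = p_keep;  // assign the keep probability from the new variable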

@AdvancedCompiler changed the title from "[AdvancedCompiler]New Feature(cpp wrapper)" to "[AdvancedCompiler]flash_varlen_fwd_kernel(cpp wrapper)" on Sep 19, 2025
@kiddyjinjin (Collaborator)

Please update flash_attn_varlen_func in flag_gems/ops/attention.py to redirect the operation to the C++ wrapper when use_c_extension is set to true.
You can refer to the usage in the reshape_and_cache_flash function in flag_gems/fused/reshape_and_flash.py.
Also, please add some tests in modules_test to verify that it works correctly with vLLM.
