diff --git a/_posts/2024-08-07-flexattention.md b/_posts/2024-08-07-flexattention.md index 4c34879d33b6..d601bc085c58 100644 --- a/_posts/2024-08-07-flexattention.md +++ b/_posts/2024-08-07-flexattention.md @@ -218,12 +218,12 @@ def sliding_window_causal(b, h, q_idx, kv_idx): return causal_mask & window_mask # If you want to be cute... -from torch.nn.attention import or_masks +from torch.nn.attention import and_masks def sliding_window(b, h, q_idx, kv_idx) return q_idx - kv_idx <= SLIDING_WINDOW -sliding_window_causal = or_masks(causal_mask, sliding_window) +sliding_window_causal = and_masks(causal_mask, sliding_window) ``` We benchmark it against `F.scaled_dot_product_attention` with a sliding window mask as well as FA2 with a causal mask (as a reference point for performance). Not only are we significantly faster than `F.scaled_dot_product_attention`, we’re *also* significantly faster than FA2 with a causal mask as this mask has significantly more sparsity. @@ -479,4 +479,4 @@ We want to highlight some prior work (and people) that have inspired FlexAttenti - The Jax team's work on SplashAttention - Philippe Tillet and Keren Zhou for helping us with Triton - Ali Hassani for discussions on neighborhood attention -- Everybody who's complained about attention kernels not supporting their favorite attention variant :) \ No newline at end of file +- Everybody who's complained about attention kernels not supporting their favorite attention variant :)