From a462c36a4b247076f171883d1517debb38780d10 Mon Sep 17 00:00:00 2001
From: Horace He
Date: Mon, 12 Aug 2024 22:45:33 -0700
Subject: [PATCH] Update 2024-08-07-flexattention.md

---
 _posts/2024-08-07-flexattention.md | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/_posts/2024-08-07-flexattention.md b/_posts/2024-08-07-flexattention.md
index 4c34879d33b6..d601bc085c58 100644
--- a/_posts/2024-08-07-flexattention.md
+++ b/_posts/2024-08-07-flexattention.md
@@ -218,12 +218,12 @@ def sliding_window_causal(b, h, q_idx, kv_idx):
     return causal_mask & window_mask
 
 # If you want to be cute...
-from torch.nn.attention import or_masks
+from torch.nn.attention import and_masks
 
 def sliding_window(b, h, q_idx, kv_idx)
     return q_idx - kv_idx <= SLIDING_WINDOW
 
-sliding_window_causal = or_masks(causal_mask, sliding_window)
+sliding_window_causal = and_masks(causal_mask, sliding_window)
 ```
 
 We benchmark it against `F.scaled_dot_product_attention` with a sliding window mask as well as FA2 with a causal mask (as a reference point for performance). Not only are we significantly faster than `F.scaled_dot_product_attention`, we’re *also* significantly faster than FA2 with a causal mask as this mask has significantly more sparsity.
@@ -479,4 +479,4 @@ We want to highlight some prior work (and people) that have inspired FlexAttenti
 - The Jax team's work on SplashAttention
 - Philippe Tillet and Keren Zhou for helping us with Triton
 - Ali Hassani for discussions on neighborhood attention
-- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
\ No newline at end of file
+- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
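
For context, here is a minimal, self-contained sketch of what the corrected snippet computes, assuming the FlexAttention entry points (`flex_attention`, `create_block_mask`, `and_masks`) are importable from `torch.nn.attention.flex_attention` (the import path in released PyTorch versions, which may differ from the path shown in the post). The `SLIDING_WINDOW` value, tensor shapes, and device below are illustrative and not part of the patch.

```python
# Illustrative sketch only (not part of the patch): composing the causal and
# sliding-window mask_mods with and_masks, which intersects the two masks.
import torch
from torch.nn.attention.flex_attention import (
    flex_attention,
    create_block_mask,
    and_masks,
)

SLIDING_WINDOW = 1024  # assumed window size for this example

def causal_mask(b, h, q_idx, kv_idx):
    # A query may only attend to keys at or before its own position.
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    # A query may only attend to keys within the last SLIDING_WINDOW positions.
    return q_idx - kv_idx <= SLIDING_WINDOW

# and_masks allows a position only if *both* mask_mods allow it, which is why
# the patch swaps or_masks (union) for and_masks (intersection).
sliding_window_causal = and_masks(causal_mask, sliding_window)

B, H, S, D = 2, 8, 2048, 64  # example shapes; requires a CUDA device
q, k, v = (
    torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3)
)

block_mask = create_block_mask(sliding_window_causal, B=None, H=None, Q_LEN=S, KV_LEN=S)
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([2, 8, 2048, 64])
```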