6 changes: 3 additions & 3 deletions _posts/2024-08-07-flexattention.md
@@ -218,12 +218,12 @@ def sliding_window_causal(b, h, q_idx, kv_idx):
     return causal_mask & window_mask
 
 # If you want to be cute...
-from torch.nn.attention import or_masks
+from torch.nn.attention import and_masks
 
 def sliding_window(b, h, q_idx, kv_idx)
     return q_idx - kv_idx <= SLIDING_WINDOW
 
-sliding_window_causal = or_masks(causal_mask, sliding_window)
+sliding_window_causal = and_masks(causal_mask, sliding_window)
 ```
 
 We benchmark it against `F.scaled_dot_product_attention` with a sliding window mask as well as FA2 with a causal mask (as a reference point for performance). Not only are we significantly faster than `F.scaled_dot_product_attention`, we’re *also* significantly faster than FA2 with a causal mask as this mask has significantly more sparsity.
@@ -479,4 +479,4 @@ We want to highlight some prior work (and people) that have inspired FlexAttention
 - The Jax team's work on SplashAttention
 - Philippe Tillet and Keren Zhou for helping us with Triton
 - Ali Hassani for discussions on neighborhood attention
-- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
+- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
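
For reference, the fix in the first hunk matters because `or_masks` keeps a `(q_idx, kv_idx)` pair if *either* mask_mod allows it, so composing the causal mask with the window constraint via `or_masks` would still admit future tokens that happen to fall inside the window; `and_masks` requires both conditions, which is what sliding window causal attention needs. Below is a minimal sketch, not part of the post or this PR, of how the corrected composition plugs into the rest of the FlexAttention API, assuming a recent PyTorch (2.5+) where `and_masks`, `create_block_mask`, and `flex_attention` are exposed from `torch.nn.attention.flex_attention`; the shapes, `SLIDING_WINDOW` value, and CUDA device are placeholders.

```python
import torch
from torch.nn.attention.flex_attention import (
    and_masks,
    create_block_mask,
    flex_attention,
)

SLIDING_WINDOW = 1024  # placeholder window size

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx <= SLIDING_WINDOW

# Logical AND of the two mask_mods: a (q_idx, kv_idx) pair is kept only if it
# is both causal and within the window, i.e. the intended sliding window causal mask.
sliding_window_causal = and_masks(causal_mask, sliding_window)

B, H, S, D = 1, 8, 4096, 64  # placeholder shapes; assumes a CUDA device
q, k, v = (
    torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3)
)

# The block mask is precomputed once from the mask_mod and reused across calls;
# its sparsity is what lets this beat a dense causal mask.
block_mask = create_block_mask(
    sliding_window_causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda"
)
out = flex_attention(q, k, v, block_mask=block_mask)
```

FlexAttention is meant to be run under `torch.compile` to get the fused kernel, so a real benchmark like the one described in the post would compile the `flex_attention` call rather than run it eagerly as the sketch does.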