Description
https://github.com/pytorch-labs/segment-anything-fast/ uses custom Triton code to implement a variant of SDPA that supports the kind of additive attention required by the image_encoder.
In a nutshell, the code this custom Triton kernel implements is equivalent to:
```python
# In segment-anything-fast, B here already folds in the number of heads
# (B = batch_size * num_heads), while q_, k_, v_ have shape
# (batch_size, num_heads, q_h * q_w, head_dim)
rel_h = rel_h.reshape(B, q_h * q_w, k_h, 1)
rel_w = rel_w.reshape(B, q_h * q_w, 1, k_w)
# Broadcasting the sum yields the additive bias of shape
# (batch_size, num_heads, q_h * q_w, k_h * k_w) expected by SDPA's attn_mask
attn_bias = (rel_h + rel_w).view(q_.size(0), q_.size(1),
                                 q_h * q_w, k_h * k_w)
return torch.nn.functional.scaled_dot_product_attention(q_, k_, v_, attn_mask=attn_bias)
```
With the release of FlexAttention in PyTorch 2.5 (code examples), it should now be possible to express this without the need for custom Triton code.
Not only should FlexAttention support fused implementations for more input shapes, it is also likely to produce better code with better hyperparameters. This kind of fused attention yielded an end-to-end improvement of about 1.15x on top of a fused-SDPA, torch.compile'd (with CUDA graphs) baseline.
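A minimal sketch of what a FlexAttention version could look like, assuming `rel_h` and `rel_w` are kept with their singleton dimensions squeezed out; the function name and the explicit `k_w` argument are hypothetical and not the actual `flash_4` interface:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def attention_rel_h_rel_w_flex(q_, k_, v_, rel_h, rel_w, k_w):
    # Assumed shapes (hypothetical, mirroring the snippet above):
    #   q_, k_, v_: (batch, num_heads, q_h * q_w, head_dim)
    #   rel_h:      (batch, num_heads, q_h * q_w, k_h)
    #   rel_w:      (batch, num_heads, q_h * q_w, k_w)
    # The additive bias for key index kv_idx = kh * k_w + kw decomposes as
    # rel_h[..., kh] + rel_w[..., kw], so it can be recomputed per score inside
    # score_mod instead of materializing a (batch, num_heads, q_h*q_w, k_h*k_w) mask.
    def score_mod(score, b, h, q_idx, kv_idx):
        return score + rel_h[b, h, q_idx, kv_idx // k_w] + rel_w[b, h, q_idx, kv_idx % k_w]

    return flex_attention(q_, k_, v_, score_mod=score_mod)

# FlexAttention is meant to be used under torch.compile to get a fused kernel:
# attention_rel_h_rel_w_flex = torch.compile(attention_rel_h_rel_w_flex)
```

Whether this matches or beats the hand-written Triton kernel is exactly what the task below should determine.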
The task:
Copy over the relevant files from segment-anything-fast into torchao's model folder and follow the README to rerun the benchmarks if needed.
Write a FlexAttention version of flash_4 and measure the difference in performance. If it helps, we can land it in torchao immediately, but at a minimum it could influence FlexAttention development.
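For the performance measurement, a rough harness along these lines might be a starting point. The shapes are ViT-B-like placeholders rather than the actual benchmark configuration from the README, and `attention_rel_h_rel_w_flex` refers to the hypothetical sketch above:

```python
import torch
from torch.utils import benchmark

B, H, L, D = 1, 12, 64 * 64, 64      # placeholder ViT-B-like shapes
k_h = k_w = 64
dt, dev = torch.bfloat16, "cuda"
q = torch.randn(B, H, L, D, dtype=dt, device=dev)
k, v = torch.randn_like(q), torch.randn_like(q)
rel_h = torch.randn(B, H, L, k_h, dtype=dt, device=dev)
rel_w = torch.randn(B, H, L, k_w, dtype=dt, device=dev)

def sdpa_baseline():
    # Materialize the additive bias, as in the snippet at the top of this issue
    bias = (rel_h.unsqueeze(-1) + rel_w.unsqueeze(-2)).view(B, H, L, k_h * k_w)
    return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=bias)

flex_fn = torch.compile(attention_rel_h_rel_w_flex)  # sketch from above

for label, fn in [("sdpa + materialized bias", sdpa_baseline),
                  ("flex_attention", lambda: flex_fn(q, k, v, rel_h, rel_w, k_w))]:
    t = benchmark.Timer(stmt="fn()", globals={"fn": fn}).blocked_autorange()
    print(label, t.median)
```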