From my understanding, flex attention (using block_mask) gets faster as the number of empty blocks grows. If the inputs (Q, K, V) represent not sequences but graphs with local connectivity (e.g. pixels in an image), the ordering of the elements has a huge impact on the number of empty blocks.
It would be very useful to add helpers that find optimal, or simply better, orderings given a mask. For images, for example, ordering the pixels by small patches (close to the attention window size) is likely better than the standard row-by-row order, as sketched below.
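To make the patch idea concrete, here is a minimal sketch, assuming a recent PyTorch with `torch.nn.attention.flex_attention`; `patch_order_permutation` is a hypothetical helper of the kind this issue asks for, not an existing API. It builds the same local 2D attention mask under a row-major and a patch-major pixel ordering and compares the resulting `BlockMask` sparsity:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

H, W, RADIUS, BLOCK = 64, 64, 8, 64  # image size, attention radius, block size

# Hypothetical helper of the kind proposed here: permute the row-major
# pixel indices into patch-major order (non-overlapping p x p patches),
# so that pixels that attend to each other land in nearby positions.
def patch_order_permutation(h, w, p):
    idx = torch.arange(h * w).reshape(h, w)
    # (h//p, w//p, p, p): enumerate patches, then pixels within each patch
    return idx.unfold(0, p, p).unfold(1, p, p).reshape(-1)

def local_2d_mask_mod(order):
    # order[s] = original pixel index at sequence position s
    y, x = order // W, order % W

    def mask_mod(b, h, q_idx, kv_idx):
        # attend iff the two pixels are within RADIUS along both axes
        return ((y[q_idx] - y[kv_idx]).abs() <= RADIUS) & (
            (x[q_idx] - x[kv_idx]).abs() <= RADIUS
        )

    return mask_mod

for name, order in [
    ("row-major", torch.arange(H * W)),
    ("patch-major (8x8)", patch_order_permutation(H, W, 8)),
]:
    bm = create_block_mask(
        local_2d_mask_mod(order), B=None, H=None,
        Q_LEN=H * W, KV_LEN=H * W, device="cpu", BLOCK_SIZE=BLOCK,
    )
    print(f"{name}: sparsity {bm.sparsity():.1f}%")
```

On this toy configuration the patch-major ordering reports noticeably higher block sparsity, i.e. more blocks that flex attention can skip entirely, even though both orderings describe the exact same attention pattern.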
Note that this is related to the minimum degree algorithm.
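For arbitrary masks where no obvious geometric ordering exists, one generic starting point could be the reordering heuristics already available in SciPy for sparse matrices, e.g. reverse Cuthill-McKee (a relative of minimum degree). A minimal sketch; whether bandwidth reduction is the right objective for maximizing empty blocks is exactly the kind of question such helpers would need to answer:

```python
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import reverse_cuthill_mckee

def bandwidth_reducing_order(mask: np.ndarray) -> np.ndarray:
    # Treat the boolean attention mask as a graph adjacency matrix and
    # ask RCM for a permutation that clusters edges near the diagonal.
    return np.asarray(reverse_cuthill_mckee(csr_matrix(mask.astype(np.int8)),
                                            symmetric_mode=True))

# Toy usage on a small random symmetric mask; Q, K, V (and the mask
# itself) would be reordered with `perm` before building the block mask.
rng = np.random.default_rng(0)
a = rng.random((16, 16)) < 0.2
mask = a | a.T
perm = bandwidth_reducing_order(mask)
reordered = mask[np.ix_(perm, perm)]
```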