Replace flash_4 with FlexAttention #639
I would like to give this a shot. Could you help me clarify something? Is the goal to make a fork of segment-anything-fast that uses Flex Attention, and test that in …? What I could do is make a fork of segment-anything-fast that uses Flex Attention and use that as an alternative pip install to … Let me know if this makes sense, or if you meant something else.
@tobiasvanderwerff - Yes, we could also get started with an experimental PR against https://github.com/pytorch-labs/segment-anything-fast . Eventually it could be convenient to be able to vendor the changes in SAM-fast and make them more easily accessible via torchao packaging and distribution. What do you think about this?
@cpuhrsch that sounds like a plan. Let me try to get started on this in the next few days. I already tried to run the SAM benchmark today, but realized that my current GPU (NVIDIA T4) does not support Flash Attention, since it requires compute capability >= sm_80 (e.g. an A100). However, I intend to get access to a cloud A100 GPU instance in the next few days. If that doesn't work out, I don't think I'll be able to work on this, and I'll let you know.
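A quick way to check whether a GPU meets the sm_80 requirement mentioned above is to compare its compute capability tuple. This is a minimal sketch, assuming the threshold is exactly (8, 0) as stated in the comment; the helper names are hypothetical.

```python
import torch

# Assumption: the FlashAttention path needs compute capability >= 8.0 (sm_80).
MIN_CAPABILITY = (8, 0)


def supports_sm80(capability: tuple[int, int]) -> bool:
    """Return True if a (major, minor) compute capability is at least sm_80."""
    return tuple(capability) >= MIN_CAPABILITY


def current_gpu_supports_sm80() -> bool:
    """Check the current CUDA device; returns False when no GPU is visible."""
    if not torch.cuda.is_available():
        return False
    # get_device_capability() returns a (major, minor) tuple, e.g. (8, 0) on an A100.
    return supports_sm80(torch.cuda.get_device_capability())


print(supports_sm80((7, 5)))  # NVIDIA T4 -> False
print(supports_sm80((8, 0)))  # NVIDIA A100 -> True
print(current_gpu_supports_sm80())
```

Comparing tuples lexicographically handles both the major and minor version, so (7, 5) is rejected and (8, 0) or (9, 0) are accepted.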
@cpuhrsch as discussed, I've created a fork of the segment-anything-fast repo that uses Flex Attention instead of the custom Triton kernel. I've also added a test to check for correctness. You can see the changes here. I'm posting benchmark results from …

As a side note, Flex Attention only accepts embedding sizes that are powers of two, so I had to add padding to make it work. It's possible that the padding causes the drop in performance, although the Triton kernel pads as well.

Torch version: …

Baseline results (using Triton kernel):
Flex Attention results (I omitted the last two rows because running the benchmark was taking a long time):
Hm, very interesting. Thanks for doing this work. Do you mind attaching GPU traces for, say, the first setup, both with and without FlexAttention? You can gather traces using https://github.com/pytorch-labs/segment-anything-fast/tree/e6aadeb86f3ae1f58c3f98e2a91e251716e0f2aa/experiments#kernel-traces . Just ensure that …
Using the GPU traces, it is also possible to annotate (using https://pytorch.org/docs/main/generated/torch.autograd.profiler.record_function.html#record-function and https://pytorch.org/docs/main/generated/torch.cuda.synchronize.html#torch-cuda-synchronize) the section that was changed and compare the GPU kernel runtime of that section alone. This way we can double-check that the slowdown is caused precisely by this change. I'd create two versions of these traces, one with annotation and sync and one without. That means four traces in total: a) Baseline without annotation …
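The annotate-and-sync variant could look roughly like the sketch below: `torch.profiler.record_function` labels the changed region, and `torch.cuda.synchronize()` is called on both sides of it so that the region's GPU time is not mixed with surrounding work. `attention_block` is a hypothetical stand-in for the modified attention code, and the sketch falls back to CPU-only profiling when no GPU is available.

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function


def maybe_sync():
    # Only synchronize when CUDA is present; on CPU this is a no-op.
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def attention_block(x, weight):
    # Hypothetical stand-in for the section that was changed.
    return torch.softmax(x @ weight, dim=-1)


x = torch.randn(64, 64)
w = torch.randn(64, 64)
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)
    x, w = x.cuda(), w.cuda()

with profile(activities=activities) as prof:
    maybe_sync()  # drain earlier work so it is not attributed to the region
    with record_function("changed_attention_section"):
        y = attention_block(x, w)
        maybe_sync()  # wait for the region's kernels before the range closes

labels = [evt.key for evt in prof.key_averages()]
print("changed_attention_section" in labels)
# prof.export_chrome_trace("trace.json") writes a trace viewable in Perfetto.
```

The synchronized version distorts end-to-end timing, which is why it makes sense to keep an unannotated trace alongside it.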
Tracing results indicate that in the Flex Attention version, a lot of time is spent on a padding kernel (…). The trace shows that the Flex Attention implementation spends 2 seconds in the image encoder, whereas the baseline spends only 1.35 seconds, so there is a clear slowdown in the part of the code where SDPA is used. Padding does not seem to take nearly as much time in the baseline (in the trace, the largest purple blocks under the …).

So the padding seems to be a large source of the slowdown. As I mentioned earlier, the Triton kernel does the same padding, but it somehow does so more efficiently. At the top of the function, it says:
So it seems like they manage to make the padding more efficient by fusing it into earlier operations. I'm currently trying to figure out whether this can also be done for the Flex Attention kernel, but it's not obvious to me how. (NB: I also tried running the tracing with the annotations, as you suggested @cpuhrsch, but they did not seem to show up in the trace output, perhaps because of torch.compile?)
@cpuhrsch - Hm, the way you're using FlexAttention, it should also be a composite (as in …). Since this is needed specifically for vit_h, does that mean the gap narrows for vit_b, or that FlexAttention is even faster there?
Baseline:
Flex Attention:
I looked at the profile traces, but it is difficult to extract any useful information. Most of the kernels in the Flex Attention version have generic names like …
I may have found a clue as to where the performance bottleneck lies. Replacing this line in the …

`attn_bias = self.rel_h[batch, head, q_idx, h_idx] + self.rel_w[batch, head, q_idx, w_idx]`

with this:

`attn_bias = h_idx + w_idx`

leads to a massive speedup (38 img/s -> 97 img/s). So it seems that the indexing into …
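To see what the two variants compute, here is a sketch of the index arithmetic. It assumes (this is not shown in the thread) that keys are laid out row-major on a k_h x k_w grid, so that `h_idx = kv_idx // k_w` and `w_idx = kv_idx % k_w`, and that `rel_h` and `rel_w` have shapes (B, heads, L_q, k_h) and (B, heads, L_q, k_w). Under those assumptions, gathering with these indices reproduces the fully materialized bias. The simplified `h_idx + w_idx` variant performs no loads from `rel_h`/`rel_w` at all, which is consistent with the slowdown being tied to that indexing.

```python
import torch

# Assumed shapes: rel_h is (B, heads, L_q, k_h) and rel_w is (B, heads, L_q, k_w).
B, heads, k_h, k_w = 1, 2, 4, 5
L_q = k_h * k_w
torch.manual_seed(0)
rel_h = torch.randn(B, heads, L_q, k_h)
rel_w = torch.randn(B, heads, L_q, k_w)

# Materialized bias: broadcast-add, then flatten the (k_h, k_w) key grid row-major.
dense_bias = (rel_h.unsqueeze(-1) + rel_w.unsqueeze(-2)).flatten(-2)

# Per-key lookup, as a score_mod would do it: recover (h_idx, w_idx) from kv_idx.
kv_idx = torch.arange(k_h * k_w)
h_idx = kv_idx // k_w
w_idx = kv_idx % k_w
gathered_bias = rel_h[..., h_idx] + rel_w[..., w_idx]

print(dense_bias.shape)  # torch.Size([1, 2, 20, 20])
print(torch.equal(dense_bias, gathered_bias))
```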
Unfortunately, using …
Great, thank you for the investigation @tobiasvanderwerff!
@tobiasvanderwerff - For what it's worth, indexing into the …
@cpuhrsch an update: I've tried the fix pushed by @Chillee, but unfortunately I still get an error (see output below). It looks like the minified code sample I referred to in the issue does not quite transfer to the more complicated setup of the SAM-fast model. I'm not sure how to resolve this right now, and unfortunately it is not very feasible for me to keep using an A100 for testing due to the cost (sorry). So the best strategy may be to put this on hold for now and wait until FlexAttention addresses this issue.
@tobiasvanderwerff - Thank you for testing this. I'll update pytorch-labs/attention-gym#45 as well. At least with the most recent fix we're one step closer.
https://github.com/pytorch-labs/segment-anything-fast/ uses custom Triton code to implement a variant of SDPA that supports the additive attention bias required by the image_encoder.
In a nutshell, the custom Triton kernel implements the following: …
With the release of FlexAttention in PyTorch 2.5 (code examples), it should now be possible to express this without the need for custom Triton code.
Not only will FlexAttention be able to support fused implementations for more input shapes, it is also likely to produce better-optimized code with better hyperparameters. This kind of fused attention produced an end-to-end improvement of about 1.15x on top of a fused SDPA and torch.compile'd (with CUDA graphs) baseline.
The task:
1. Copy the relevant files from segment-anything-fast into torchao's model folder and follow the README to rerun if needed.
2. Write a FlexAttention version of flash_4 and measure the difference in performance. If it helps, we can land it in torchao immediately; at a minimum, it could inform FlexAttention development.
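For the measurement step, a small harness that first checks that two implementations agree and then times them with `torch.utils.benchmark` might look like the following. Both functions are hypothetical CPU stand-ins (a fused SDPA call with a dense bias and an unfused math version). In practice, the candidate would be the FlexAttention path, and meaningful numbers would come from the SAM benchmark on the target GPU rather than from this toy setup.

```python
import torch
import torch.nn.functional as F
import torch.utils.benchmark as benchmark


def sdpa_with_bias(q, k, v, bias):
    # Baseline stand-in: fused SDPA with a materialized additive mask.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=bias)


def math_attention_with_bias(q, k, v, bias):
    # Candidate stand-in: unfused softmax(q @ k^T / sqrt(d) + bias) @ v.
    scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5 + bias
    return torch.softmax(scores, dim=-1) @ v


def compare(baseline, candidate, *args, min_run_time=0.2):
    """Check both callables agree, then return (baseline_ms, candidate_ms, speedup)."""
    torch.testing.assert_close(candidate(*args), baseline(*args), atol=1e-4, rtol=1e-4)
    times_ms = []
    for fn in (baseline, candidate):
        timer = benchmark.Timer(stmt="fn(*args)", globals={"fn": fn, "args": args})
        times_ms.append(timer.blocked_autorange(min_run_time=min_run_time).median * 1e3)
    # speedup > 1 means the candidate is faster than the baseline.
    return times_ms[0], times_ms[1], times_ms[0] / times_ms[1]


torch.manual_seed(0)
q, k, v = (torch.randn(1, 4, 256, 64) for _ in range(3))
bias = torch.randn(1, 4, 256, 256)
base_ms, cand_ms, speedup = compare(sdpa_with_bias, math_attention_with_bias, q, k, v, bias)
print(f"baseline {base_ms:.3f} ms, candidate {cand_ms:.3f} ms, speedup {speedup:.2f}x")
```

Checking correctness inside the harness ensures that a speedup is never reported for an implementation that silently computes something different.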