RDNA3 support #27
A CK discussion has shown a branch that has a flash-attention kernel implementation and already works in AIT: ROCm/composable_kernel#1032 |
Hi @WilliamGazeley and @sdli1995, I wanted to update you on the attention kernels for the NAVI platform. My colleague @aska-0096 did activate them. However, these Flash-Attention kernels were initially developed for MI devices, which operate on a distinct set of CK kernels. In essence, the issue is that we haven't yet integrated these kernels into our current API. I plan to work on this integration in my spare time. |
Great work on the fork, @howiejayz! Will it take too long to port the kernels for gfx1100 and other non-MI architectures? I need the kernels for my project and I'd be willing to help out if I can. |
Thanks for reaching out and offering to help @AlpinDale! I'm currently tied up with a few other projects, so I can't give an exact timeframe for porting the kernels for gfx1100 and other architectures. But I'm definitely planning to tackle this soon. The first step will be creating a new code path for gfx110x, considering their CK kernels are only for forward ops. I'm totally open to suggestions or any help you can provide. It'd be great to have some extra hands on this. Let me know if you're interested! |
I am a complete novice in this field, but a few months ago I managed to make Composable Kernel, Flash Attention and PyTorch work together for my RX 7900 XTX (see here, sort by performance, and look for the first one). Although I was able to get that Flash Attention implementation "working" in the end, the generated images were meaningless, and I gave up because I didn't know how to fix it. Here are the relevant branch links, and I hope they can be of some help to you:
I made a grouped fused kernel by Frankensteining the batched fused kernel, which matched the call signatures in this repo at that time. However, that self-made kernel might just be broken. |
Hi @evshiron! First off, I must say I'm seriously impressed by your work! It's quite an achievement, and the resources you've provided are invaluable. I've had the opportunity to build your implementation on gfx1100, and I'm pleased to report that the build was successful. However, the unit tests do not pass because the forward pass produces incorrect results: `assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()` fails, which likely stems from incorrect parameter settings in the CK kernels. I suspect this is why the output images became meaningless. Despite this, your work has been immensely helpful! This will massively speed up the Navi porting process for the v2 implementation. |
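The assertion quoted above accepts the kernel if its largest error against a higher-precision reference is at most twice the largest error of plain PyTorch attention. A minimal pure-Python sketch of that check, assuming the tensors have been flattened to lists of floats; the names mirror the assertion, and this is not the repo's actual test code:

```python
def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equal-length flat sequences."""
    if len(a) != len(b):
        raise ValueError("shape mismatch")
    return max(abs(x - y) for x, y in zip(a, b))


def passes_reference_check(output, output_pt, output_ref, factor=2.0):
    """Sketch of the unit-test criterion quoted in the comment.

    output:     result of the kernel under test (e.g. the CK/WMMA path)
    output_pt:  plain PyTorch attention in the same low precision
    output_ref: higher-precision reference result
    """
    kernel_err = max_abs_diff(output, output_ref)
    baseline_err = max_abs_diff(output_pt, output_ref)
    return kernel_err <= factor * baseline_err


if __name__ == "__main__":
    ref = [1.0, 2.0, 3.0]
    pt = [1.01, 1.99, 3.0]
    print(passes_reference_check([1.015, 2.0, 2.99], pt, ref))  # True
    print(passes_reference_check([1.5, 2.0, 3.0], pt, ref))     # False
```

Tying the tolerance to the PyTorch baseline's own error means the test does not demand more accuracy than FP16/BF16 arithmetic can deliver, while still catching kernels (like the one described here) that are wrong by much more than rounding.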
@howiejayz I'm glad that my humble work could be of some help. I am indeed unfamiliar with this field, so I can only leave it to professionals. |
Guys, I have added batched forward (consistent sequence lengths) support for gfx1100, gfx1101 and gfx1102 under this branch, thanks to @aska-0096's CK kernels. The implementation is still under development and there are a lot of things to fine-tune. For now I see the performance is generally better when … To install, just use: … I only had the chance to test it on gfx1100, but I expect it to work on the other two as well. Let me know if there are any issues! The Docker image I used to test is … |
Using a 7900XTX with …
Results for `benchmark_flash_attention_forward.py`:
Results for `test_flash_attn_wmma_rocm.py`:
Full log: `test_flash_attn_wmma_rocm.log`
Error for `benchmark_flash_attention.py`:
|
Some benchmark results. RTX 4090
7900 XTX
|
The 4090's FP16 tensor-core throughput with FP16 accumulate is 330 TFLOPS, while the 7900 XTX's is 120 TFLOPS, so the RTX 3090 would be a better NVIDIA reference card. |
Thanks for the benchmark data. We are going to launch a new version of Composable Kernel with better flash-attention performance. Adapting the optimizations to RDNA3 is in my plans. |
@sdli1995 here are the benchmarks with a 3090:
|
Any updates on this? |
We need official support for flash attention |
trust bro, be patient don't rush them |
I've been using the howiejayz/navi_support branch on here with stable-diffusion-webui for a few weeks now. The implementation is perfect. On an RX 7800 XT, it speeds it/s up from 1.75 it/s to 2 it/s, all while massively decreasing VRAM usage. |
Could you please provide more information about how? Did you just install the branch and it worked out of the box, or did you have to change the code of the webui you are using? |
@gel-crabs I failed to install flash-attn for Navi. please give more info |
Alright, I'm going to try to give instructions on how I got this to work.
If you're on Arch, I have a very amateur PKGBUILD (requiring --skipinteg) that gets it to work. You need to go into the PKGBUILD and replace GPU_ARCHS=gfx1101 with your GPU's architecture, and set MAX_JOBS to however many CPU cores you have. I can only confirm it will work on gfx11+. The patch just changes the C++ standard from c++20 to c++17 to allow it to build. If you aren't on Arch, you can generally just follow the commands and install the Python wheel file afterwards, in your virtualenv if you're using one. You can clone the repo with:
Now for webui. You will have to use a patch that has been closed, since it will be obsolete once AMD finishes xformers support: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11902.patch That's the link to the raw patch file; the PR itself is AUTOMATIC1111/stable-diffusion-webui#11902. In the discussion for this patch, I posted a long comment on getting it to work (AUTOMATIC1111/stable-diffusion-webui#11902 (comment)). The important info from that is the 2 code blocks you need to change manually.
After that, add --flash-attn to your command line arguments and it should work. If you get slower or the same speed, flash attention isn't working. You may get a HIP OOM at the end of generation if you're using a higher resolution than usual, as it needs to switch back to SDP at the end because it doesn't support a head dim over 128.
If you get an error about the flash-attn .so not loading, rebuild the PKGBUILD but change CC=clang and CXX=clang++ to CC=gcc and CXX=g++. |
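For readers not using the PKGBUILD, the C++20 to C++17 change it applies can be sketched as a text substitution on `setup.py`. This assumes the standard appears as a literal `-std=c++20` flag; the branch's real `setup.py` may spell it differently, so check the result before building.

```python
import re
from pathlib import Path


def downgrade_cxx_standard(source: str, old: str = "c++20", new: str = "c++17") -> str:
    """Rewrite every literal -std=<old> flag in `source` to -std=<new>.

    Assumption: setup.py passes the standard as a plain "-std=c++20" string.
    """
    pattern = re.compile(r"-std=" + re.escape(old) + r"\b")
    return pattern.sub("-std=" + new, source)


def patch_setup_py(path: str = "setup.py") -> bool:
    """Patch the file in place; returns True if anything changed."""
    file = Path(path)
    original = file.read_text()
    patched = downgrade_cxx_standard(original)
    if patched != original:
        file.write_text(patched)
    return patched != original


if __name__ == "__main__":
    print(downgrade_cxx_standard('["-O3", "-std=c++20"]'))  # ["-O3", "-std=c++17"]
```

Running `patch_setup_py()` from the repo root before `pip install .` is the equivalent of the PKGBUILD's patch step under the stated assumption.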
Switching setup.py to c++17 built successfully on gfx1100. It seems to work in transformers: a 7B model OOMs at >8k context using the default attention, but doesn't even crack 20 gigs with FA2 enabled. Interestingly, I lose about 1 t/s though? I'll have to see if I can monkeypatch it into Diffusers... |
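A rough illustration of why the default attention OOMs at long context while FA2 does not: eager attention materializes a (batch, heads, seq_len, seq_len) score tensor that grows quadratically, whereas flash attention computes it tile by tile without storing it. The estimate below assumes a Llama-style 7B model with 32 heads, FP16 scores and batch size 1; the comment does not say which model was used.

```python
GIB = 1024 ** 3


def score_matrix_bytes(batch: int, heads: int, seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one materialized (batch, heads, seq_len, seq_len) attention-score tensor."""
    return batch * heads * seq_len * seq_len * bytes_per_elem


if __name__ == "__main__":
    # Assumed Llama-style 7B shape: 32 attention heads, FP16 scores, batch size 1.
    for n in (2048, 4096, 8192, 16384):
        gib = score_matrix_bytes(1, 32, n) / GIB
        print(f"seq_len={n:>5}: {gib:5.2f} GiB per layer for the scores alone")
```

At 8192 tokens this is 4 GiB for a single score tensor in one layer, and doubling the context quadruples it. Eager implementations usually also keep the softmax output (sometimes upcast to FP32), so the real peak is higher than this lower bound.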
@gel-crabs I did that and also tried `FLASH_ATTENTION_INTERNAL_USE_RTN=1 pip install .` (I use Debian.) |
Do you mean the unit testing? For that you need to export `FLASH_ATTENTION_INTERNAL_UNIT_TEST_MODE=1` and `FLASH_ATTENTION_INTERNAL_DETERMINISTIC=1`. You should also set `GPU_ARCHS` to your GPU architecture (gfx1100, gfx1101, etc.) and try building with GCC and with Clang. I can also only guarantee this will work on ROCm 5.7 and up. For anything other than SD webui, you will likely have to write a forward function yourself, as it is a PyTorch extension and isn't integrated into PyTorch yet. The implementation is here, but keep in mind it requires my changes as well: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11902/files |
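The variables named above can be collected in one place for a build-and-test script. A sketch, assuming you export all three in the environment used for both building and running the tests (the comment does not say which step reads which variable); `tests/test_flash_attn_wmma_rocm.py` is a hypothetical location for the test file mentioned earlier in this thread.

```python
import os
import sys


def flash_attn_test_env(gpu_arch: str, base=None) -> dict:
    """Return a copy of `base` (default: os.environ) with the variables from this comment set."""
    env = dict(os.environ if base is None else base)
    env["GPU_ARCHS"] = gpu_arch  # e.g. gfx1100, gfx1101, gfx1102
    env["FLASH_ATTENTION_INTERNAL_UNIT_TEST_MODE"] = "1"
    env["FLASH_ATTENTION_INTERNAL_DETERMINISTIC"] = "1"
    return env


# Hypothetical path; adjust to where the test file lives in your checkout.
TEST_CMD = [sys.executable, "-m", "pytest", "tests/test_flash_attn_wmma_rocm.py"]

if __name__ == "__main__":
    import subprocess

    env = flash_attn_test_env("gfx1100")
    # Build and install from the repo root, then run the tests with the same env.
    subprocess.run([sys.executable, "-m", "pip", "install", "."], env=env, check=True)
    subprocess.run(TEST_CMD, env=env, check=False)
```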
I don't have the knowledge to contribute to this issue, but I'm really rooting for this support feature! |
Could you share what you found? What size can you go up to before it has to fall back to SDP? |
I found larger resolutions benefit more. On SD1.5, 1088x1088 on an RX 7900 XTX went from 1.67 it/s to 3 it/s while 512px was a more modest 16 it/s to 18 it/s. Note: If using a venv I found I had to build the wheel with the venv activated. Otherwise, the library complained about undefined symbols and failed to load. |
I think I explained wrong; it will always fall back to SDP at the very end of generation (after all the steps have finished), resolution doesn't factor into it. What I meant is that the massively decreased VRAM usage will (currently..?) not allow you to use a higher resolution than with regular SDP attention, as the VRAM usage will jump back up at the end. However, it most likely could... if AMD hooked up CK's Navi branch to their new xformers port. ;) |
Also note: the switch back to SDP at the end (I believe only with SDXL) can be prevented by switching from full VAE to TAESD, or (I assume) a tiled VAE implementation. This allows a 1024x1024 image to be upscaled to 2048x2048 with SDXL on an RX 7800 XT with 16GB of VRAM. |
Hi @gel-crabs Could be that I did not build flash-attn wheel correctly. This is what I used to build it:
Do you think this is correct? |
Alright, I think I've got it. This just copy-pastes @Beinsezii's hijack method in with a small fix for the SDP attnblock, so just change from flash-attn back to SDP, then remove the --flash-attn option from modules/cmd_args.py and your webui startup. This eschews sub-quadratic altogether (without a memory spike, somehow), so it shouldn't be producing garbage anymore. Thank you so much, @Beinsezii |
I forgot to mention here but for ComfyUI people in this thread I already made a ComfyUI addon a few weeks ago that's basically just the SDPA hijack plonked into a fake node.
What's actually changed? It looks like my monkey patch is the same.
Subquad should use less memory than SDPA, but if the tile size is configured to be high enough it can end up using about the same. I've never looked at Auto's code so I can't say what goes on there. I'd still recommend SDPA since it's basically the same function and really only the VAEs have memory issues when I tested up to 3840x2160. ComfyUI and Diffusers both have a tiled VAE flag which fixes that more or less flawlessly at a mild speed cost; I'm assuming Auto does too. I've tested the current SDPA hijack on SD 1.5, 2.1, XL, Stable Cascade, and Pixart Alpha to great success so there shouldn't be anything broken as long as Auto's code is good. Even img2img and multi fused Lora seems to behave well. |
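The "SDPA hijack" discussed above is ordinary monkey patching: replace `torch.nn.functional.scaled_dot_product_attention` with a wrapper that routes supported inputs to flash attention and everything else to the original. The sketch below uses stand-ins (`FakeTensor`, a `SimpleNamespace` module, string-returning attention functions) so it runs without a GPU; it is not the code of the ComfyUI addon or of `sd_hijack_optimizations.py`. The `head_dim <= 128 and attn_mask is None` condition mirrors the AMD_Go_Fast.py snippet quoted in this thread. A real wrapper would also have to transpose SDPA's (batch, heads, seq, dim) layout to the (batch, seq, heads, dim) layout that `flash_attn_func` expects.

```python
import types


class FakeTensor:
    """Stand-in exposing only .shape, laid out (batch, heads, seq_len, head_dim) like SDPA inputs."""

    def __init__(self, shape):
        self.shape = shape


def install_sdpa_hijack(module, fast_attention, max_head_dim=128):
    """Swap module.scaled_dot_product_attention for a dispatching wrapper.

    Supported inputs (head_dim <= max_head_dim, no mask, no extra arguments)
    go to `fast_attention`; everything else goes to the original function.
    Returns the original so the patch can be undone.
    """
    original = module.scaled_dot_product_attention

    def patched(query, key, value, attn_mask=None, *args, **kwargs):
        # Extra arguments (dropout_p, is_causal, scale, ...) fall back instead of
        # being guessed onto a kernel whose parameters are named differently.
        if query.shape[3] <= max_head_dim and attn_mask is None and not args and not kwargs:
            return fast_attention(query, key, value)
        return original(query, key, value, attn_mask, *args, **kwargs)

    module.scaled_dot_product_attention = patched
    return original


if __name__ == "__main__":
    functional = types.SimpleNamespace(
        scaled_dot_product_attention=lambda q, k, v, attn_mask=None, *a, **kw: "sdpa"
    )
    original = install_sdpa_hijack(functional, lambda q, k, v: "flash")
    small = FakeTensor((1, 8, 4096, 64))
    large = FakeTensor((1, 1, 4096, 512))
    print(functional.scaled_dot_product_attention(small, small, small))  # flash
    print(functional.scaled_dot_product_attention(large, large, large))  # sdpa
    functional.scaled_dot_product_attention = original  # undo the patch
```

Falling back to the original for anything unusual is what keeps the hijack safe: VAE attention blocks with large head dims, masked attention, and causal calls all keep working through SDPA.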
In the sdp_attnblock_forward function at the end:
Changed to:
This was in the original flash-attn patch, and I had to add it due to this line in the patch producing a "tuple index out of range" error at the VAE stage:
|
Finally!!! Wow, thank you so much, guys. @Beinsezii and @gel-crabs, it works, but please read on. When I replaced the existing module with this latest hijack file from @gel-crabs, it did not work straight away: Auto1111 continued to produce garbage. Refusing to give up, I uninstalled flash-attn and reinstalled it. This time I had overridden the gfx version as below before reinstalling flash-attn: … and let it build wheels for flash-attn. After this, Auto1111 started working fine with the flash-attn magic. For anyone who wants to try it, here are the commands I ran to get it working after replacing the sd_hijack_optimizations.py file:
Some quick comparison on how tremendous this is for AMD folks who use auto1111. SD1.5 base model on Flash attention SD1.5 base model on Doggetx (default) How ridiculous this looks when you see the VRAM usage.. LOL My specs: Linux Mint, 7800XT, torch==2.3.0.dev20240314+rocm6.0. |
Just a note that this also works perfect. |
Has anyone had luck with gfx1030? I think I am out of luck :/ |
RDNA ≤ 2 cards don't have WMMAs so I'm not sure they'll be supported anytime soon. |
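A small sketch of the distinction made above, mapping a gfx target name to the kernel family discussed in this thread. The sets are illustrative assumptions (RDNA3 targets from this thread plus common Instinct targets), not an official support matrix; `rocminfo` lists the gfx name of your GPU.

```python
# Illustrative grouping only; check ROCm/CK documentation for the real support list.
WMMA_NAVI3 = {"gfx1100", "gfx1101", "gfx1102"}                     # RDNA3, navi_support branch
MFMA_INSTINCT = {"gfx908", "gfx90a", "gfx940", "gfx941", "gfx942"}  # MI100 / MI200 / MI300


def flash_attn_path(arch: str) -> str:
    """Return which CK flash-attention code path (if any) a gfx target would use."""
    base = arch.split(":")[0].lower()  # drop feature suffixes such as ":xnack-"
    if base in WMMA_NAVI3:
        return "navi-wmma (forward only)"
    if base in MFMA_INSTINCT:
        return "instinct-mfma"
    return "unsupported (no WMMA/MFMA kernels); use SDPA or sub-quadratic"


if __name__ == "__main__":
    for name in ("gfx1100", "gfx90a:sramecc+:xnack-", "gfx1030"):
        print(name, "->", flash_attn_path(name))
```

gfx1030 (RDNA2) lands in the last branch because, as noted above, it has no WMMA instructions for these kernels to use.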
@xzuyn I'm having the same error, did you find a fix for it? it's on the backward pass... |
Is this absolutely required? Could it be forked to use some other implementation that does functionally the same thing? |
You could always use a pure pytorch impl like sub quadratic. Tune the tile sizes to your GPU and it should perform okayish and not OOM on large shapes. |
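The tile-size knob mentioned above comes from processing attention in chunks. A pure-Python sketch of the core trick, assuming a single head and no mask: keys and values are visited `key_chunk` rows at a time, with a running max and running softmax denominator, so the result matches full attention up to rounding. This illustrates the idea shared by chunked and flash-style attention; it is not the sub-quadratic implementation itself, and real implementations typically chunk queries as well.

```python
import math


def naive_attention(q, k, v):
    """softmax(q k^T / sqrt(d)) v for lists of row vectors (single head, no mask)."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) / total
                    for c in range(len(v[0]))])
    return out


def chunked_attention(q, k, v, key_chunk=2):
    """Same result, visiting keys/values in tiles of `key_chunk` rows.

    Each query keeps a running max (m), running denominator (l) and an
    unnormalized output accumulator, so no full row of scores is stored.
    """
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for qi in q:
        m, l, acc = -math.inf, 0.0, [0.0] * len(v[0])
        for start in range(0, len(k), key_chunk):
            k_tile, v_tile = k[start:start + key_chunk], v[start:start + key_chunk]
            scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k_tile]
            m_new = max(m, max(scores))
            correction = math.exp(m - m_new)  # rescales earlier partial sums
            weights = [math.exp(s - m_new) for s in scores]
            l = l * correction + sum(weights)
            acc = [a * correction + sum(w * vj[c] for w, vj in zip(weights, v_tile))
                   for c, a in enumerate(acc)]
            m = m_new
        out.append([a / l for a in acc])
    return out


if __name__ == "__main__":
    q = [[0.1, 0.2], [0.3, -0.1]]
    k = [[0.2, 0.1], [-0.3, 0.4], [0.5, 0.5]]
    v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
    full, tiled = naive_attention(q, k, v), chunked_attention(q, k, v, key_chunk=2)
    print(max(abs(a - b) for r, t in zip(full, tiled) for a, b in zip(r, t)))
```

Larger tiles mean fewer passes but more memory per pass, which is why the tile sizes are worth tuning per GPU.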
Any options for something similar, but for LLMs / Llama, rather than Stable Diffusion? |
Best Navi 2 option would probably be Exllama2's Q4 context quantization. You still bleed iteration speed, but the memory climb is greatly reduced. vLLM has a Triton impl that works great, but I don't know if it'll compile against Navi 2, considering the Dockerfile only targets the Instinct cards and Navi 3. |
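To see why a 4-bit context (KV) cache reduces the memory climb without speeding anything up, a rough size estimate: the cache stores one key and one value vector per layer, KV head and token. The model shape below is an assumed Llama-2-7B-like configuration, and the formula ignores the scale and metadata overhead of any real quantized format such as Exllama2's.

```python
GIB = 1024 ** 3


def kv_cache_bytes(layers, kv_heads, head_dim, seq_len, bits_per_value=16, batch=1):
    """Approximate K+V cache size; ignores quantization scales and padding."""
    values = 2 * layers * kv_heads * head_dim * seq_len * batch  # 2 = keys and values
    return values * bits_per_value // 8


if __name__ == "__main__":
    # Assumed Llama-2-7B-like shape: 32 layers, 32 KV heads, head_dim 128.
    for bits in (16, 8, 4):
        size = kv_cache_bytes(32, 32, 128, 8192, bits_per_value=bits)
        print(f"{bits:>2}-bit cache at 8192 tokens: {size / GIB:.2f} GiB")
```

Under these assumptions the cache shrinks from 4 GiB at 16 bits to about 1 GiB at 4 bits. Models with grouped-query attention have fewer KV heads, so their caches are smaller to begin with.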
Navi 3 / GFX1100 is what we use, and it does work for inference, but not training. |
Navi cards are outright unsuitable for training most FFTs and new models right now. There's barely functional flash forward passes in a few spots but backwards is currently completely unsupported by AMD as far as I know. You can make it work with deepspeed, lora, or other tricks but it's still very far from being efficient. |
I'll add that the DaoLab Triton impl runs okay-ish for me in my diffusion app. Because it's pure Triton, I think backwards should work? I'm not sure how useful it would be for LLMs. |
Did you manage to install flash attention from their repo? It errored out for me |
The Triton file I linked is usable directly as a standalone file. The repo's HIP kernels only support Instinct GPUs. |
Big update for those using the CK-based version for forward inference (i.e. Stable Diffusion): I managed to get head dim = 512 working with Flash Attention, which means that switching to SDP/sub-quadratic for the VAE stage isn't needed anymore. I've tested it for the past 2 weeks; no OOMs, no difference in images. It actually works. It can be installed with:
|
What does "direct output" mean? Does it only work with Stable Diffusion? Will FA work with text generation models? |
What do you mean by direct output? |
Hello, could you tell me how to switch setup.py to c++17 so it builds? Is there any documentation? Thank you very much. |
Hi @gel-crabs, I tried your version of FA along with the AMD_Go_Fast.py patch for ComfyUI. In the patch there is some code that references 128: `if query.shape[3] <= 128 and attn_mask is None:` Is this supposed to work if I change it to 512? If I do, it triggers a long error; this is the last line: `DeviceGroupedQueryAttentionForward_Wmma, 256, 128, 128, 64, 8, 8, 128, 64, 64, 8, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled> AEnableLds: 0, B0EnableLds: 1, B1EnableLds: 1, NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem` Not really sure if/how these things work together, or if they can with minimal effort. |
Not sure how i fixed this problem, but i no longer get this error with sd1.5 models. |
I couldn't get anyone else's flash-attention to compile. Yours worked like a charm: a nice 0.7 it/s boost and lower VRAM usage. Used in combination with the amd-go-fast script for ComfyUI that someone else mentioned. |
Great work so far. I'm trying to run vLLM on my 7900XTX cards and was wondering if there were any plans to support RDNA3?