
RDNA3 support #27

Open
WilliamGazeley opened this issue Dec 5, 2023 · 76 comments
Labels
navi hardware

Comments

@WilliamGazeley

Great work so far. I'm trying to run vLLM on my 7900XTX cards and was wondering if there were any plans to support RDNA3?

@sdli1995

sdli1995 commented Dec 6, 2023

A CK discussion has shown a branch with a flash-attention kernel implementation that already works in AIT: ROCm/composable_kernel#1032.
Are there any barriers to RDNA3 support?

@dejay-vu

dejay-vu commented Dec 6, 2023

Hi @WilliamGazeley and @sdli1995, I wanted to update you on the attention kernels for the NAVI platform. My colleague @aska-0096 did activate them. However, these Flash-Attention kernels were initially developed for MI devices, which operate on a distinct set of CK kernels.

In essence, the issue is that we haven't yet integrated these kernels into our current API. I plan to work on this integration in my spare time.

@AlpinDale

Great work on the fork, @howiejayz

Will it take too long to port the kernels for gfx1100 and other non-MI architectures? I need the kernels for my project and I'd be willing to help out if I can.

@dejay-vu

dejay-vu commented Dec 7, 2023

Thanks for reaching out and offering to help @AlpinDale! I'm currently tied up with a few other projects, so I can't give an exact timeframe for porting the kernels for gfx1100 and other architectures. But I'm definitely planning to tackle this soon. The first step will be creating a new code path for gfx110x, considering their CK kernels are only for forward ops.

I'm totally open to suggestions or any help you can provide. It'd be great to have some extra hands on this. Let me know if you're interested!

@evshiron

evshiron commented Dec 7, 2023

I am a complete novice in this field, but a few months ago I managed to make Composable Kernel, Flash Attention and PyTorch work together for my RX 7900 XTX (see here, sort by performance, and look for the first one). Although I was able to get that Flash Attention implementation "working" in the end, the generated images were meaningless, and I gave up because I didn't know how to fix it. Here are the relevant branch links, and I hope they can be of some help to you:

I made a grouped fused kernel by Frankensteining the batched fused kernel, which matched the call signatures in this repo at that time. However, that self-made kernel might just be broken.

@dejay-vu

dejay-vu commented Dec 7, 2023

Hi @evshiron! First off, I must say I'm seriously impressed by your work! It's quite an achievement, and the resources you've provided are invaluable.

I've had the opportunity to build your implementation on gfx1100, and I'm pleased to report that the build was successful. However, I encountered an issue with the unit tests not passing due to incorrect results in the forward pass:

assert (output - output_ref).abs().max().item() <= 2 * (output_pt - output_ref).abs().max().item()
AssertionError: assert 0.109619140625 <= (2 * 0.000244140625)

which likely stems from incorrect parameter settings in the CK kernels. I guess this is also the reason why the output images became meaningless.
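
For reference, that assertion bounds the kernel's error by twice the error of a plain PyTorch fp16 attention against an fp32 reference; a minimal sketch of the check (tensor names are illustrative):

```python
def within_tolerance(output, output_pt, output_ref):
    # output:     the CK/flash kernel result (fp16)
    # output_pt:  plain PyTorch attention in fp16, i.e. the expected fp16 rounding error
    # output_ref: the same attention computed in fp32
    kernel_err = (output - output_ref).abs().max().item()
    baseline_err = (output_pt - output_ref).abs().max().item()
    return kernel_err <= 2 * baseline_err  # 0.1096 vs 2 * 0.000244 fails by a wide margin
```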

Despite this, your work has been immensely helpful! This will massively speed up the navi porting process for the v2 implementation.

@evshiron

evshiron commented Dec 7, 2023

@howiejayz

I'm glad that my humble work could be of some help. I am indeed unfamiliar with this field, so I can only leave it to professionals.
Furthermore, as you can see, even though I managed to compile it, the improvement in the benchmark is quite limited (I didn't use the specific commit shown here). I hope it's just an issue with my implementation and I look forward to better performance in future implementations.

@dejay-vu

dejay-vu commented Dec 9, 2023

Guys, I have added batched forward (consistent sequence lengths) support for gfx1100, gfx1101, and gfx1102 under this branch. Thanks to @aska-0096's CK kernels. The implementation is still under development and there are a lot of things to fine-tune. For now I see the performance is generally better when head dim = 64.

To install, just use `pip install .`

I only had the chance to test it on gfx1100, but I expect it to work on the other two as well. Let me know if there are any issues! The Docker image I used for testing is rocm/pytorch:latest with torch==2.1.0.
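
For anyone trying the branch, here is a minimal forward-only smoke test; it assumes the branch keeps upstream's flash_attn_func interface and that only the forward pass is implemented on gfx110x:

```python
import torch
from flash_attn import flash_attn_func

# Batched forward with consistent sequence lengths; fp16 tensors on the GPU
# (ROCm devices still show up as "cuda" in PyTorch).
batch, seqlen, nheads, headdim = 8, 2048, 16, 64   # head dim 64 is the faster path
q = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

with torch.no_grad():                              # backward is not supported here
    out = flash_attn_func(q, k, v, causal=True)
print(out.shape)                                   # torch.Size([8, 2048, 16, 64])
```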

@xzuyn

xzuyn commented Dec 10, 2023

under this branch.

benchmark_flash_attention_forward.py works, but benchmark_flash_attention.py doesn't. Forward speeds look pretty nice.

Using a 7900XTX with torch 2.2.0.dev20231209+rocm5.7.

Results for `benchmark_flash_attention_forward.py`
### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 38.52 TFLOPs/s
Pytorch fwd: 13.37 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 39.33 TFLOPs/s
Pytorch fwd: 14.22 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 41.23 TFLOPs/s
Pytorch fwd: 16.04 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 42.06 TFLOPs/s
Pytorch fwd: 18.02 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 42.02 TFLOPs/s
Pytorch fwd: 19.16 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 37.27 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=128, batch_size=32, seqlen=512 ###
Flash2 fwd: 28.27 TFLOPs/s
Pytorch fwd: 20.43 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=128, batch_size=16, seqlen=1024 ###
Flash2 fwd: 29.38 TFLOPs/s
Pytorch fwd: 21.05 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=128, batch_size=8, seqlen=2048 ###
Flash2 fwd: 30.49 TFLOPs/s
Pytorch fwd: 25.23 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=128, batch_size=4, seqlen=4096 ###
Flash2 fwd: 31.00 TFLOPs/s
Pytorch fwd: 26.99 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 27.50 TFLOPs/s
Pytorch fwd: 28.47 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=False, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 20.67 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 24.02 TFLOPs/s
Pytorch fwd: 5.07 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 29.08 TFLOPs/s
Pytorch fwd: 5.48 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 33.49 TFLOPs/s
Pytorch fwd: 5.84 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 36.44 TFLOPs/s
Pytorch fwd: 6.21 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 38.54 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 39.70 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=128, batch_size=32, seqlen=512 ###
Flash2 fwd: 17.89 TFLOPs/s
Pytorch fwd: 8.42 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=128, batch_size=16, seqlen=1024 ###
Flash2 fwd: 21.69 TFLOPs/s
Pytorch fwd: 8.68 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=128, batch_size=8, seqlen=2048 ###
Flash2 fwd: 25.64 TFLOPs/s
Pytorch fwd: 9.78 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=128, batch_size=4, seqlen=4096 ###
Flash2 fwd: 27.06 TFLOPs/s
Pytorch fwd: 10.01 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=128, batch_size=2, seqlen=8192 ###
Flash2 fwd: 27.43 TFLOPs/s
Pytorch fwd: 9.87 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
### causal=True, headdim=128, batch_size=1, seqlen=16384 ###
Flash2 fwd: 24.68 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s
Triton fwd: 0.00 TFLOPs/s
Results for `test_flash_attn_wmma_rocm.py`

=============== 125 failed, 2148 passed, 4606 skipped in 46.01s ================

Full Log: test_flash_attn_wmma_rocm.log

Error for `benchmark_flash_attention.py`
> python benchmarks/benchmark_flash_attention.py

Traceback (most recent call last):
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/flash-attention/benchmarks/benchmark_flash_attention.py", line 97, in <module>
    f, b = time_fwd_bwd(
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/flash-attention/benchmarks/benchmark_flash_attention.py", line 66, in time_fwd_bwd
    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/utils/benchmark.py", line 99, in benchmark_fwd_bwd
    benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/utils/benchmark.py", line 53, in benchmark_backward
    m = t.timeit(repeats)
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py", line 274, in timeit
    self._timeit(number=max(int(number // 100), 2))
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/utils/benchmark/utils/timer.py", line 264, in _timeit
    return max(self._timer.timeit(number), 1e-9)
  File "/usr/lib/python3.10/timeit.py", line 178, in timeit
    timing = self.inner(it, self.timer)
  File "<timeit-src>", line 6, in inner
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/utils/benchmark.py", line 46, in f
    y.backward(grad, retain_graph=True)
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/_tensor.py", line 503, in backward
    torch.autograd.backward(
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/torch/autograd/function.py", line 289, in apply
    return user_fn(self, *args)
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 109, in backward
    _flash_attn_backward(
  File "/home/USER/clones/LLaMA-Efficient-Tuning/venv/lib/python3.10/site-packages/flash_attn/flash_attn_interface.py", line 66, in _flash_attn_backward
    dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
TypeError: bwd(): incompatible function arguments. The following argument types are supported:
    1. () -> None

@AlpinDale

Some benchmark results.

RTX 4090

### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 150.90 TFLOPs/s,
Pytorch fwd: 20.23 TFLOPs/s,
### causal=False, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 154.49 TFLOPs/s,
Pytorch fwd: 23.69 TFLOPs/s,
### causal=False, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 171.80 TFLOPs/s,
Pytorch fwd: 26.21 TFLOPs/s,
### causal=False, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 172.81 TFLOPs/s,
Pytorch fwd: 27.89 TFLOPs/s,
### causal=False, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 172.96 TFLOPs/s,
Pytorch fwd: 0.00 TFLOPs/s,
### causal=False, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 173.04 TFLOPs/s,
Pytorch fwd: 0.00 TFLOPs/s,

7900 XTX

### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 42.98 TFLOPs/s
Pytorch fwd: 14.38 TFLOPs/s
### causal=False, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 44.31 TFLOPs/s
Pytorch fwd: 14.83 TFLOPs/s
### causal=False, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 48.25 TFLOPs/s
Pytorch fwd: 17.32 TFLOPs/s
### causal=False, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 47.55 TFLOPs/s
Pytorch fwd: 19.40 TFLOPs/s
### causal=False, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 38.40 TFLOPs/s
Pytorch fwd: 20.19 TFLOPs/s
### causal=False, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 41.01 TFLOPs/s
Pytorch fwd: 0.00 TFLOPs/s

@sdli1995

The 4090's FP16 tensor-core performance with FP16 accumulate is 330 TFLOPs, while the 7900 XTX's is 120 TFLOPs; a better NVIDIA card to compare against is the RTX 3090.

@aska-0096

Thanks for the benchmark data. We are going to launch a new version of Composable Kernel with better flash-attention performance. Adapting that optimization to RDNA3 is in my plan.

@AlpinDale

@sdli1995 here are the benchmarks with a 3090:

### causal=False, headdim=64, batch_size=32, seqlen=512 ###
Flash2 fwd: 65.38 TFLOPs/s,
Pytorch fwd: 18.38 TFLOPs/s,
### causal=False, headdim=64, batch_size=16, seqlen=1024 ###
Flash2 fwd: 72.94 TFLOPs/s,
Pytorch fwd: 21.69 TFLOPs/s,
### causal=False, headdim=64, batch_size=8, seqlen=2048 ###
Flash2 fwd: 74.11 TFLOPs/s,
Pytorch fwd: 18.92 TFLOPs/s,
### causal=False, headdim=64, batch_size=4, seqlen=4096 ###
Flash2 fwd: 74.98 TFLOPs/s,
Pytorch fwd: 22.27 TFLOPs/s,
### causal=False, headdim=64, batch_size=2, seqlen=8192 ###
Flash2 fwd: 75.06 TFLOPs/s,
Pytorch fwd: 0.00 TFLOPs/s,
### causal=False, headdim=64, batch_size=1, seqlen=16384 ###
Flash2 fwd: 75.12 TFLOPs/s,
Pytorch fwd: 0.00 TFLOPs/s,

@AlpinDale

Any updates on this?

@Wintoplay

We need official support for flash attention

@ewof

ewof commented Dec 24, 2023

trust bro, be patient don't rush them

@gel-crabs

I've been using the howiejayz/navi_support branch on here with stable-diffusion-webui for a few weeks now. The implementation is perfect.

On an RX 7800 XT, it speeds it/s up from 1.75 it/s to 2 it/s, all while massively decreasing VRAM usage.

@Kademo15

I've been using the howiejayz/navi_support branch on here with stable-diffusion-webui for a few weeks now. The implementation is perfect.

On an RX 7800 XT, it speeds it/s up from 1.75 it/s to 2 it/s, all while massively decreasing VRAM usage.

Could you please provide more information about how? Did you just install the branch and it worked out of the box, or did you have to change code in the webui you are using?

@Wintoplay

@gel-crabs I failed to install flash-attn for Navi. Please give more info.

@gel-crabs

gel-crabs commented Jan 4, 2024

@Kademo15 @Wintoplay

Alright, I'm going to try to give instructions on how I got this to work. If you're on Arch, I have a very amateur PKGBUILD (requiring --skipinteg) that gets it to work. You need to go into the PKGBUILD, replace GPU_ARCHS=gfx1101 with your GPU's architecture, and set MAX_JOBS to however many CPU cores you have. I can only confirm it will work on gfx11+. The patch just changes the C++ standard from c++20 to c++17 to allow it to build.

python-flash-attention.tar.gz

If you aren't on Arch, you can generally just follow the commands and install the python wheel file afterwards, in your virtualenv if you're using one. You can clone the repo with git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git -b howiejayz/navi_support --depth=1 in this case.

Now for webui. You will have to use a patch that has been closed since it will be obsolete once AMD finishes xformers support.

https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11902.patch

That's the link to the raw patch file, use patch -p1 < 11902.patch in the webui directory to apply it to your webui. Please run patch -p1 --dry-run < 11902.patch first so it won't screw up your installation if it doesn't apply correctly. We're not done yet, however.

AUTOMATIC1111/stable-diffusion-webui#11902

In the discussion for this patch, I posted a long comment on getting it to work. (AUTOMATIC1111/stable-diffusion-webui#11902 (comment)). The important info from that is the 2 code blocks you need to change manually.

After that, add --flash-attn to your command line arguments and it should work. If you get slower or the same speed, flash attention isn't working. You may get a HIP OOM at the end of generation if you're using a higher resolution than usual, as it needs to switch back to SDP at the end due to not supporting a head dim over 128.

If you get an error involving a flash-attn so not loading, rebuild the PKGBUILD but change CC=clang and CXX=clang++ to CC=gcc and CXX=g++.
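
For reference, the C++ standard change the patch makes boils down to swapping the -std flag on the extension build in setup.py; a paraphrased sketch (the extension name matches the flash_attn_cuda module the package imports, the source path and other flags are illustrative):

```python
# setup.py (paraphrased excerpt): build the HIP/C++ extension with C++17 instead of C++20
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension  # hipified on ROCm

setup(
    name="flash_attn",
    ext_modules=[
        CUDAExtension(
            name="flash_attn_cuda",
            sources=["csrc/flash_attn/fmha_api.cpp"],       # illustrative path
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"],               # was -std=c++20
                "nvcc": ["-O3", "-std=c++17"],              # passed to hipcc when building for ROCm
            },
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```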

@Beinsezii

Beinsezii commented Jan 4, 2024

Switching setup.py to c++17 built successfully on gfx1100

Seems to work in transformers, as a 7b model OOMs @ >8k context using the default attention but doesn't even crack 20 gigs with FA2 enabled. Interestingly I lose about 1 t/s though?

I'll have to see if I can monkeypatch it into Diffusers...
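
For anyone else trying it in transformers, FA2 is just a loading flag; a sketch with a placeholder model id (recent transformers releases use attn_implementation, older ones used use_flash_attention_2=True):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"               # placeholder 7B model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,                      # flash-attn only handles fp16/bf16
    attn_implementation="flash_attention_2",        # requires the flash-attn build from this thread
).to("cuda")                                        # ROCm still registers as the "cuda" device
```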

@Wintoplay

@gel-crabs I did that and also tried FLASH_ATTENTION_INTERNAL_USE_RTN=1 pip install .
It just gives FFFFF results for the testing.

(I use Debian.)
I have not tried the SD patch though cuz I want it for inference of LLM.

@gel-crabs

@gel-crabs I did that and also tried FLASH_ATTENTION_INTERNAL_USE_RTN=1 pip install . It just gives FFFFF results for the testing.

(I use Debian.) I have not tried the SD patch though cuz I want it for inference of LLM.

Do you mean the unit testing? For that you need to export FLASH_ATTENTION_INTERNAL_UNIT_TEST_MODE=1 and FLASH_ATTENTION_INTERNAL_DETERMINISTIC=1.

You should also set GPU_ARCHS to your GPU architecture (gfx1100, gfx1101, etc.) and try building with both GCC and Clang. I can also only guarantee this will work on ROCm 5.7 and up.

For anything other than SD webui, you will likely have to create a forward function yourself, as it is a PyTorch extension and isn't integrated into PyTorch yet. The implementation is here, but keep in mind it requires my changes as well:

https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/11902/files
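
As a starting point outside of webui, the forward function being described amounts to routing to flash-attn when the head dim is supported and there is no mask, and falling back to torch SDP otherwise; a hedged sketch (the 128 cutoff follows the patch above, the rest is illustrative):

```python
import torch
from flash_attn import flash_attn_func

def sdp_or_flash(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False):
    # query/key/value: (batch, heads, seqlen, headdim), the layout torch SDP expects.
    if (
        query.shape[3] <= 128
        and attn_mask is None
        and query.dtype in (torch.float16, torch.bfloat16)
    ):
        # flash-attn wants (batch, seqlen, heads, headdim)
        q, k, v = (t.transpose(1, 2) for t in (query, key, value))
        out = flash_attn_func(q, k, v, dropout_p=dropout_p, causal=is_causal)
        return out.transpose(1, 2)
    return torch.nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
    )
```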

@j-dominguez9

I don't have the knowledge to contribute to this issue, but I'm really rooting for this support feature!

@Kademo15

Kademo15 commented Jan 5, 2024

You may get a HIP OOM at the end of generation if you're using a higher resolution than usual, as it needs to switch back to SDP at the end due to not supporting a head dim over 128.

Could you share what you found? What's the largest size you can go to before it has to fall back to SDP?

@feffy380

feffy380 commented Jan 5, 2024

I've been using the howiejayz/navi_support branch on here with stable-diffusion-webui for a few weeks now. The implementation is perfect.

On an RX 7800 XT, it speeds it/s up from 1.75 it/s to 2 it/s, all while massively decreasing VRAM usage.

I found larger resolutions benefit more. On SD1.5, 1088x1088 on an RX 7900 XTX went from 1.67 it/s to 3 it/s while 512px was a more modest 16 it/s to 18 it/s.
VRAM usage also drops dramatically. Generating a 1088x1088 image goes from 18GB down to about 6GB. I don't see any spike at the end, though. Is it specific to SDXL maybe?

Note: If using a venv I found I had to build the wheel with the venv activated. Otherwise, the library complained about undefined symbols and failed to load.

@gel-crabs

gel-crabs commented Jan 5, 2024

You may get a HIP OOM at the end of generation if you're using a higher resolution than usual, as it needs to switch back to SDP at the end due to not supporting a head dim over 128.

Could you share what you found? What's the largest size you can go to before it has to fall back to SDP?

I think I explained it wrong; it will always fall back to SDP at the very end of generation (after all the steps have finished), so resolution doesn't factor into it.

What I meant is that the massively decreased VRAM usage will (currently..?) not allow you to use a higher resolution than with regular SDP attention, as the VRAM usage will jump back up at the end.

However, it most likely could... if AMD hooked up CK's Navi branch to their new xformers port. ;)

@gel-crabs

gel-crabs commented Jan 14, 2024

Also note: the switch back to SDP at the end (I believe only with SDXL) can be prevented by switching from full VAE to TAESD, or (I assume) a tiled VAE implementation.

This allows a 1024x1024 image to be upscaled to 2048x2048 with SDXL on an RX 7800 XT with 16GB of VRAM.

@sabreshao sabreshao added the navi hardware label Jan 16, 2024
@sancspro

sancspro commented Mar 13, 2024

Oh crap, sorry. It was working on my setup so I had no idea about these issues. I updated the files I uploaded above with fixes, is it producing garbage now?

Hi @gel-crabs
It is producing garbage even after replacing with your updated file.

Could be that I did not build flash-attn wheel correctly.

This is what I used to build it:

cd stable-diffusion-webui
python -m venv venv
source venv/bin/activate
pip install -U git+https://github.com/ROCm/flash-attention@howiejay/navi_support

Do you think this is correct?

@gel-crabs

gel-crabs commented Mar 14, 2024

Oh crap, sorry. It was working on my setup so I had no idea about these issues. I updated the files I uploaded above with fixes, is it producing garbage now?

Hi @gel-crabs It is producing garbage even after replacing with your updated file.

Could be that I did not build flash-attn wheel correctly.

This is what I used to build it:

cd stable-diffusion-webui
python -m venv venv
source venv/bin/activate
pip install -U git+https://github.com/ROCm/flash-attention@howiejay/navi_support

Do you think this is correct?

Alright, I think I've got it. This just copy-pastes @Beinsezii's hijack method in with a small fix for the SDP attnblock, so just change from flash-attn back to SDP, then remove the --flash-attn option from modules/cmd_args.py and your webui startup.

This eschews sub-quadratic altogether (without a memory spike, somehow), so it shouldn't be producing garbage anymore. Thank you so much, @Beinsezii

sd_hijack_optimizations.py.txt

@Beinsezii

I forgot to mention here but for ComfyUI people in this thread I already made a ComfyUI addon a few weeks ago that's basically just the SDPA hijack plonked into a fake node.

This just copy-pastes @Beinsezii's hijack method in with a small fix for the SDP attnblock

What's actually changed? It looks like my monkey patch is the same.

This eschews sub-quadratic altogether (without a memory spike, somehow), so it shouldn't be producing garbage anymore. Thank you so much, @Beinsezii

Subquad should use less memory than SDPA, but if the tile size is configured to be high enough it can end up using about the same. I've never looked at Auto's code so I can't say what goes on there.

I'd still recommend SDPA since it's basically the same function and really only the VAEs have memory issues when I tested up to 3840x2160. ComfyUI and Diffusers both have a tiled VAE flag which fixes that more or less flawlessly at a mild speed cost; I'm assuming Auto does too.

I've tested the current SDPA hijack on SD 1.5, 2.1, XL, Stable Cascade, and Pixart Alpha to great success so there shouldn't be anything broken as long as Auto's code is good. Even img2img and multi fused Lora seems to behave well.
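
For the Diffusers side, the tiled VAE flag mentioned above is a one-liner; a sketch assuming a stock SDXL pipeline (the model id is just an example):

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.enable_vae_tiling()  # decode the VAE in tiles to avoid the end-of-generation memory spike

image = pipe("an astronaut riding a horse", height=2048, width=2048).images[0]
```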

@gel-crabs

gel-crabs commented Mar 14, 2024

I forgot to mention here but for ComfyUI people in this thread I already made a ComfyUI addon a few weeks ago that's basically just the SDPA hijack plonked into a fake node.

This just copy-pastes @Beinsezii's hijack method in with a small fix for the SDP attnblock

What's actually changed? It looks like my monkey patch is the same.

This eschews sub-quadratic altogether (without a memory spike, somehow), so it shouldn't be producing garbage anymore. Thank you so much, @Beinsezii

Subquad should use less memory than SDPA, but if the tile size is configured to be high enough it can end up using about the same. I've never looked at Auto's code so I can't say what goes on there.

I'd still recommend SDPA since it's basically the same function and really only the VAEs have memory issues when I tested up to 3840x2160. ComfyUI and Diffusers both have a tiled VAE flag which fixes that more or less flawlessly at a mild speed cost; I'm assuming Auto does too.

I've tested the current SDPA hijack on SD 1.5, 2.1, XL, Stable Cascade, and Pixart Alpha to great success so there shouldn't be anything broken as long as Auto's code is good. Even img2img and multi fused Lora seems to behave well.

In the sdp_attnblock_forward function at the end:

q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
out = rearrange(out, 'b (h w) c -> b c h w', h=h)

Changed to:

q, k, v = (rearrange(t, 'b c h w -> b (h w) 1 c') for t in (q, k, v))
out = rearrange(out, 'b (h w) 1 c -> b c h w', h=h)

This was in the original flash-attn patch, and I had to add it due to this line in the patch producing a "tuple index out of range" error at the VAE stage:

if query.shape[3] <= 128 and attn_mask is None:

@sancspro

sancspro commented Mar 14, 2024

sd_hijack_optimizations.py.txt

Finally!!! Wow thank you so much guys. @Beinsezii and @gel-crabs

It works but please continue to read below.

When I replaced the existing module with this latest hijack file from @gel-crabs, it did not work straightaway. Auto1111 continued to produce garbage. Refusing to give up, I uninstalled flash-attn and reinstalled it again. This time I had overridden the gfx version as below before reinstalling flash-attn:
export HSA_OVERRIDE_GFX_VERSION=11.0.0

And let it build wheels for flash-attn. After this auto1111 started working fine with the flash-attn magic.

For anyone who wants to try it, here are the commands I ran to get it working after replacing the sd_hijack_optimizations.py file:

export HSA_OVERRIDE_GFX_VERSION=11.0.0
cd stable-diffusion-webui
python -m venv venv
source venv/bin/activate

pip install -U git+https://github.com/ROCm/flash-attention@howiejay/navi_support

python launch.py --opt-sdp-attention

A quick comparison of how tremendous this is for AMD folks who use auto1111.

SD1.5 base model on Flash attention
1024 – 2.31 it/s VRAM used 5.5 GB
768 – 5.56 it/s VRAM 3.5 GB
512 – 15 it/s VRAM 3.5 GB
SDXL Juggernaut on Flash attention
1024 – 1.90 it/s VRAM used 9.0 GB

SD1.5 base model on Doggettx (default)
1024 – 1.50 it/s VRAM used 15.3 GB
768 – 3.86 it/s VRAM 8.5 GB
512 – 12 it/s VRAM 4.5 GB
SDXL Juggernaut on Doggettx (default)
1024 – 1.50 it/s VRAM used 9.5 GB

How ridiculous this looks when you see the VRAM usage.. LOL

My specs: Linux Mint, 7800XT, torch==2.3.0.dev20240314+rocm6.0.

@sancspro

sancspro commented Mar 14, 2024

Big update for Stable Diffusion WebUI users!!
So as it turns out, it was actually super easy to replace SDP with Doggettx/sub-quadratic the whole time, I was just looking in the wrong place. XFormers does the exact same thing, just in the attnblock forward instead of the attention forward.
11902.patch.txt
Above is an updated version of the WebUI patch if you haven't applied it already (rename it to 11902.patch).
If you've already applied it, you can just replace sd_hijack_optimizations.py with this copy (rename it to sd_hijack_optimizations.py):
sd_hijack_optimizations.py.txt
Note: I chose Sub-quadratic as Doggettx has similar VRAM usage as SDP, and it only switches at the end of generation anyway so VRAM use matters more than speed here.

Just a note that this also works perfectly.

@nonetrix

Has anyone had luck with GFX 1030? I think I am out of luck :/

@Beinsezii

RDNA ≤ 2 cards don't have WMMAs so I'm not sure they'll be supported anytime soon.

@hbfreed

hbfreed commented Apr 22, 2024

@xzuyn I'm having the same error, did you find a fix for it? It's on the backward pass...

@nonetrix

RDNA ≤ 2 cards don't have WMMAs so I'm not sure they'll be supported anytime soon.

Is this absolutely required? Could it be forked to use some other implementation that does functionally the same thing?

@Beinsezii

Is this absolutely required? Could it be forked to use some other implementation that does functionally the same thing?

You could always use a pure pytorch impl like sub quadratic. Tune the tile sizes to your GPU and it should perform okayish and not OOM on large shapes.
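
For illustration, a minimal chunked attention in plain PyTorch: it only tiles over the queries (a full sub-quadratic implementation also tiles the keys with a running softmax), but it already keeps the score matrix at q_chunk x seqlen instead of seqlen x seqlen, and the chunk size is the knob you tune per GPU:

```python
import torch

def chunked_attention(q, k, v, q_chunk=1024, scale=None):
    # q, k, v: (batch, heads, seqlen, headdim). Queries are processed in tiles so the
    # attention matrix is only (q_chunk, seqlen) at a time instead of (seqlen, seqlen).
    scale = scale if scale is not None else q.shape[-1] ** -0.5
    out = torch.empty_like(q)
    for i in range(0, q.shape[2], q_chunk):
        scores = torch.matmul(q[:, :, i:i + q_chunk] * scale, k.transpose(-2, -1))
        out[:, :, i:i + q_chunk] = torch.matmul(scores.softmax(dim=-1), v)
    return out
```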

@Googulator

Is this absolutely required? Could it be forked to use some other implementation that does functionally the same thing?

You could always use a pure pytorch impl like sub quadratic. Tune the tile sizes to your GPU and it should perform okayish and not OOM on large shapes.

Any options for something similar, but for LLMs / Llama, rather than Stable Diffusion?

@Beinsezii

You could always use a pure pytorch impl like sub quadratic. Tune the tile sizes to your GPU and it should perform okayish and not OOM on large shapes.

Any options for something similar, but for LLMs / Llama, rather than Stable Diffusion?

Best Navi2 option would probably be Exllama2's Q4 context quantization. You still bleed iteration speed but the memory climb is greatly reduced.

vLLM has a Triton impl that works great, but I don't know if it'll compile against Navi 2, considering the Dockerfile only targets the Instinct cards and Navi 3.

@Googulator

Navi 3 / GFX1100 is what we use, and it does work for inference, but not training.

@Beinsezii

Navi cards are outright unsuitable for training most FFTs and new models right now. There are barely functional flash forward passes in a few spots, but backwards is currently completely unsupported by AMD as far as I know.

You can make it work with deepspeed, lora, or other tricks but it's still very far from being efficient.

@Beinsezii

I'll add that the DaoLab Triton impl runs okay-ish for me in my diffusion app. Because it's pure Triton, I think backwards should work? I'm not sure how useful it would be for LLMs.

@ZhenyaPav

I'll add that the DaoLab Triton impl runs okay-ish for me in my diffusion app. Because it's pure Triton, I think backwards should work? I'm not sure how useful it would be for LLMs.

Did you manage to install flash attention from their repo? It errored out for me

@Beinsezii

Did you manage to install flash attention from their repo?

The triton file I linked is usable directly as a standalone file. The repo's HIP kernels only support Instinct GPUs
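
If anyone wants to try that standalone Triton file: the call signature below is what I remember from Dao-AILab's flash_attn_triton.py, positional (q, k, v, bias, causal, softmax_scale), so double-check it against the copy you download; the module name assumes you drop the file next to your script:

```python
import torch
from flash_attn_triton import flash_attn_func  # the standalone file saved as flash_attn_triton.py

# (batch, seqlen, heads, headdim), fp16 on the GPU; head dim 64 is a safe choice.
q = torch.randn(2, 1024, 16, 64, dtype=torch.float16, device="cuda", requires_grad=True)
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attn_func(q, k, v, None, True)  # bias=None, causal=True
out.sum().backward()                        # the Triton backward kernel runs here
print(q.grad.shape)
```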

@gel-crabs

gel-crabs commented Aug 20, 2024

@Beinsezii

Big update for those using the CK-based version for forward inference (i.e. Stable Diffusion):

I managed to get head dim = 512 working with Flash Attention, which means that switching to SDP/sub-quadratic for the VAE stage isn't needed anymore. I've tested it for the past 2 weeks; no OOMs, no difference in images. It actually works.

gel-crabs@05aa137

It can be installed with:

pip install -U git+https://github.com/gel-crabs/flash-attention-gfx11@headdim512
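
A quick way to sanity-check that the larger head dim took, assuming the fork keeps the usual flash_attn_func entry point: feed it a VAE-shaped input (single head, head dim 512, i.e. the 'b c h w -> b (h w) 1 c' layout discussed above) and make sure it doesn't raise:

```python
import torch
from flash_attn import flash_attn_func

# One attention head with head dim 512, as produced by the VAE attnblock rearrange.
q = torch.randn(1, 64 * 64, 1, 512, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attn_func(q, k, v)  # raises on builds that cap the head dim at 128
print(out.shape)                # torch.Size([1, 4096, 1, 512])
```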

@hackey

hackey commented Aug 31, 2024

@Beinsezii

Big update for those using the CK-based version for forward inference (i.e. Stable Diffusion):

I managed to get head dim = 512 working with Flash Attention, which means that switching to SDP/sub-quadratic for the VAE stage isn't needed anymore. I've tested it for the past 2 weeks; no OOMs, no difference in images. It actually works.

gel-crabs@05aa137

It can be installed with:

pip install -U git+https://github.com/gel-crabs/flash-attention-gfx11@headdim512

What does "direct output" mean? Does it only work with Stable Diffussion? Will FA work with text generation models?

@gel-crabs

@Beinsezii
Big update for those using the CK-based version for forward inference (i.e. Stable Diffusion):
I managed to get head dim = 512 working with Flash Attention, which means that switching to SDP/sub-quadratic for the VAE stage isn't needed anymore. I've tested it for the past 2 weeks; no OOMs, no difference in images. It actually works.
gel-crabs@05aa137
It can be installed with:
pip install -U git+https://github.com/gel-crabs/flash-attention-gfx11@headdim512

What does "direct output" mean? Does it only work with Stable Diffussion? Will FA work with text generation models?

What do you mean by direct output?

@sdfasfsdfasfasafd

Switching setup.py to c++17 built successfully on gfx1100

Seems to work in transformers, as a 7b model OOMs @ >8k context using the default attention but doesn't even crack 20 gigs with FA2 enabled. Interestingly I lose about 1 t/s though?

I'll have to see if I can monkeypatch it into Diffusers...

Hello, could you tell me how to switch setup.py to C++17 for the build? Is there any documentation? Thank you very much.

@sleppyrobot

Hi @gel-crabs, I tried your version of FA along with the AMD_Go_Fast.py patch for ComfyUI.

In the patch there is some code that references 128 (see below)

if query.shape[3] <= 128 and attn_mask is None:

Is this supposed to work if I change it to 512? Because if I do, it triggers a long error; this is the last line.
I was using an SD1.5 model.

DeviceGroupedQueryAttentionForward_Wmma, 256, 128, 128, 64, 8, 8, 128, 64, 64, 8, Default, ASpecDefault, B0SpecDefault, B1SpecDefault, CSpecDefault, MaskDisabled> AEnableLds: 0, B0EnableLds: 1, B1EnableLds: 1, NumPrefetch: 1, LoopScheduler: Default, PipelineVersion: v1 does not support this problem

Not really sure if/how these things work together or if they can with minimal effort.
Not a big deal either way.

@sleppyrobot

sleppyrobot commented Sep 19, 2024

Not sure how I fixed this problem, but I no longer get this error with SD1.5 models.

@jnolck

jnolck commented Nov 22, 2024

@Beinsezii

Big update for those using the CK-based version for forward inference (i.e. Stable Diffusion):

I managed to get head dim = 512 working with Flash Attention, which means that switching to SDP/sub-quadratic for the VAE stage isn't needed anymore. I've tested it for the past 2 weeks; no OOMs, no difference in images. It actually works.

gel-crabs@05aa137

It can be installed with:

pip install -U git+https://github.com/gel-crabs/flash-attention-gfx11@headdim512

I couldn't get anyone else's flash-attention to compile. Yours worked like a charm. Nice 0.7 it/s boost and lower VRAM usage. Used in combination with the amd-go-fast script for ComfyUI that someone else mentioned.
