[Bug]: For RDNA3 (navi31; gfx1100) VLLM_USE_TRITON_FLASH_ATTN=0 currently must be forced #4514
Comments
After taking a deep look into the code and testing Flash Attention support on AMD GPUs, here is what I found:

AMD Instinct GPUs, gfx90a and gfx942 (MI210, MI250, MI300), support Flash Attention by way of specially written Composable Kernel libraries. Although I haven't tested this myself, it is working, and there are performance numbers showing the 2-3x speedup vLLM gives you with CK Flash Attention.

Radeon RDNA3 GPUs, 7900 XTX and W7900 (gfx1100), lack the necessary Composable Kernel libraries to use the above-mentioned Flash Attention mechanism, so the engineers at AMD opted to have these GPUs use an implementation of Flash Attention written in OpenAI's Triton. This Triton Flash Attention is supposed to be working, but in all tests I've done (using various branches and docker builds) it does not work. Flash Attention forward-pass support for RDNA3 was added thanks to howiejay, however that implementation no longer works in my testing, as it fails to run the hipify_python patch and build on newer versions of pytorch+rocm (tried on ...).

In summary, the only way it seems to get vLLM working on Radeon and Radeon Pro graphics cards at the moment is to build without CK Flash Attention support.

The vLLM repos I've already tried are:

Commands used to run the vLLM docker image and server were as follows (I tried a few other variations of the commands below, like changing shm-size or --max-model-len, with no luck):
Asking that the engineers at AMD look into this and assist in troubleshooting/getting this working for Radeon GPUs (Navi3).
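For context, the workaround the issue title refers to looks roughly like this in Python. This is a sketch only, not an official recipe: it assumes a ROCm build of vLLM on a gfx1100 card, and the model name is just a placeholder.

import os

# The env var must be set before vLLM picks its attention backend.
os.environ["VLLM_USE_TRITON_FLASH_ATTN"] = "0"

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model for a quick smoke test
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)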
If possible, can you try building triton from source?
That won't work, I think. There's a related Flash Attention discussion for gfx1100 here: ROCm/aotriton#16, although according to that, Navi support was upstreamed last month and the appropriate place to file any navi31 Triton issues is the main repo: https://github.com/openai/triton. (The vLLM bug at the moment is just that it's not checking for gfx1100 correctly; it shouldn't be trying to use the Triton FA at all?)
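For reference, a rough sketch of what such a gfx1100 check could look like. This is an assumption on my part rather than vLLM's actual code, and it relies on the gcnArchName field that ROCm builds of PyTorch attach to the device properties:

import torch

def is_navi3x(device_index: int = 0) -> bool:
    # ROCm builds of PyTorch expose the GPU arch string, e.g. "gfx1100".
    if not torch.cuda.is_available():
        return False
    props = torch.cuda.get_device_properties(device_index)
    arch = getattr(props, "gcnArchName", "")
    return arch.startswith("gfx11")

# If this returns True, the Triton flash-attention path should not be selected.
print("gfx11xx (Navi3x) detected:", is_navi3x())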
The howiejay branch should build fine on the latest torch stable running ROCm 6. I have py3.11 and py3.12 wheels built against gfx1100 and ROCm 6.0 here. All I run is pip wheel git+https://github.com/ROCm/flash-attention@howiejay/navi_support --no-deps in my virtualenvs to produce the wheels.
Update: There's an internal gate against using the CK FA for Navi even if it's installed, because there's no ... Additionally, I built ROCm/triton from source as of an hour ago and it still just sits pegging one thread for a small eternity before eventually being killed for blowing up the stack. I guess a person could try to increase the stack size, but I really feel like something's not working...
I think I narrowed it down to this autotune config. Disabling that, I can run without hanging. Patch on top of v0.4.2 if someone else wants to play with it:

diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index 11476641..d5f6bbec 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -219,16 +219,16 @@ def _attn_fwd_inner(
num_stages=1,
num_warps=8,
),
- triton.Config(
- {
- "BLOCK_M": 128,
- "BLOCK_N": 128,
- "waves_per_eu": 2,
- "PRE_LOAD_V": False,
- },
- num_stages=1,
- num_warps=4,
- ),
+ # triton.Config(
+ # {
+ # "BLOCK_M": 128,
+ # "BLOCK_N": 128,
+ # "waves_per_eu": 2,
+ # "PRE_LOAD_V": False,
+ # },
+ # num_stages=1,
+ # num_warps=4,
+ # ),
triton.Config(
{
"BLOCK_M": 256, |
The upstream Triton supports navi3, but attention performance is slow.
Alright, I tried with stable triton and the ROCm triton fork. My patch only helped the official nightly run without hanging:

pip uninstall pytorch-triton-rocm -y; pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps

There might be more configs that need to be disabled to run stable triton? A person could maybe just disable every config with a block dim ≥ 128 and it'd probably work everywhere; I think navi favors the small ones anyway. I also found triton is indeed much faster than naive once you stack the context.
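Instead of hand-commenting each config, the same idea can be expressed as a filter over the autotune list. A sketch only, mirroring the configs shown in the diff above; the variable names are illustrative, not vLLM's actual ones:

import triton

# Configs shaped like the ones in triton_flash_attention.py (values illustrative).
autotune_configs = [
    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False},
                  num_stages=1, num_warps=4),
    triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False},
                  num_stages=1, num_warps=4),
    triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "waves_per_eu": 2, "PRE_LOAD_V": False},
                  num_stages=1, num_warps=8),
]

# Keep only configs whose block dimensions stay below 128, the threshold that
# appears to hang or crash compilation on Navi3x in this thread.
navi_safe_configs = [
    cfg for cfg in autotune_configs
    if max(cfg.kwargs.get("BLOCK_M", 0), cfg.kwargs.get("BLOCK_N", 0)) < 128
]
print(f"kept {len(navi_safe_configs)} of {len(autotune_configs)} configs")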
Is there a solution available? I also encountered this issue with an AMD W6800.
I'd wager it's related to ROCm/triton#596
Commenting out every autotune with a block size ≥128 allows it to compile using ...
@Beinsezii Thank you very much for your help. After following your method, the earlier error (error: triton_flash_attention.py:211:0: stack frame size) no longer appears. However, answer generation runs very slowly, equivalent to the speed with VLLM_USE_TRITON_FLASH_ATTN=0. Is it a problem with my graphics card? AMD W6800
The W6800 has no WMMA accelerators, so I'm not sure it'll be faster. It should still use less memory for long-context models, though.
With ROCm/triton#596 closed, I decided to rebuild ...
The problem is that rocblas is not supported on navi architectures by ROCm. Hence FA won't work in general, I think.
Your current environment
🐛 Describe the bug
I'm able to build the ROCm docker image for AMD via the latest docs: https://docs.vllm.ai/en/latest/getting_started/amd-installation.html#option-1-build-from-source-with-docker-recommended

I am using a W7900 (RDNA3; navi31; gfx1100) and therefore build with BUILD_FA="0", i.e. without Flash Attention. When I run any script (say benchmarks/benchmark_latency.py), I get this error:

It's trying to use Triton, which seems to use an implementation of flash attention?
Stepping through the code, it goes through selector.py, and that sends it to the ROCm backend. In the backend, there is a switch for navi3x:

vllm/vllm/attention/backends/rocm_flash_attn.py, line 167 in d6f4bd7
It sets self.use_naive_attn = True if torch.cuda.get_device_capability()[0] == 11 (gfx11xx). So far so good, but that branch only executes if self.use_triton_flash_attn is false, which is set by VLLM_USE_TRITON_FLASH_ATTN and defaults to True.

So, in order to get this running you need to have VLLM_USE_TRITON_FLASH_ATTN=0 in your env. This isn't in the docs, nor is it set by default when you BUILD_FA="0".

Presumably, the correct fix is for the ROCm implementation to do correct navi3x checking and set the appropriate lib/path based on which kernel is currently supported?
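To make the ordering concrete, here is a simplified sketch of the gating described above (an approximation for illustration, not the actual rocm_flash_attn.py code):

import os
import torch

# Default is "1", so the Triton branch wins before the Navi check is ever reached.
use_triton_flash_attn = os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "1").lower() not in ("0", "false")

if use_triton_flash_attn:
    backend = "triton_flash_attn"   # default path; the one that misbehaves on gfx1100
elif torch.cuda.get_device_capability()[0] == 11:
    backend = "naive_attn"          # only reachable once the env var is set to 0
else:
    backend = "ck_flash_attn"

print("selected attention path:", backend)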