
[Bug]: For RDNA3 (navi31; gfx1100) VLLM_USE_TRITON_FLASH_ATTN=0 currently must be forced #4514

Open
lhl opened this issue May 1, 2024 · 13 comments
Labels
bug Something isn't working rocm

Comments

@lhl

lhl commented May 1, 2024

Your current environment

Collecting environment information...
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/cuda/__init__.py:611: UserWarning: Can't initialize NVML
  warnings.warn("Can't initialize NVML")
PyTorch version: 2.1.1+git011de5c
Is debug build: False
CUDA used to build PyTorch: N/A
ROCM used to build PyTorch: 6.0.32830-d62f6a171

OS: Ubuntu 20.04.6 LTS (x86_64)
GCC version: (Ubuntu 9.4.0-1ubuntu1~20.04.2) 9.4.0
Clang version: 17.0.0 (https://github.com/RadeonOpenCompute/llvm-project roc-6.0.0 23483 7208e8d15fbf218deb74483ea8c549c67ca4985e)
CMake version: version 3.29.2
Libc version: glibc-2.31

Python version: 3.9.18 (main, Sep 11 2023, 13:41:44)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-6.5.0-28-generic-x86_64-with-glibc2.31
Is CUDA available: True
CUDA runtime version: 10.1.243
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: AMD Radeon PRO W7900NoGCNArchNameOnOldPyTorch
Nvidia driver version: Could not collect
cuDNN version: Could not collect
HIP runtime version: 6.0.32830
MIOpen runtime version: 3.0.0
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Byte Order:                         Little Endian
Address sizes:                      48 bits physical, 48 bits virtual
CPU(s):                             12
On-line CPU(s) list:                0-11
Thread(s) per core:                 2
Core(s) per socket:                 6
Socket(s):                          1
NUMA node(s):                       1
Vendor ID:                          AuthenticAMD
CPU family:                         25
Model:                              80
Model name:                         AMD Ryzen 5 5600G with Radeon Graphics
Stepping:                           0
CPU MHz:                            3558.363
CPU max MHz:                        4464.0000
CPU min MHz:                        400.0000
BogoMIPS:                           7799.51
Virtualization:                     AMD-V
L1d cache:                          192 KiB
L1i cache:                          192 KiB
L2 cache:                           3 MiB
L3 cache:                           16 MiB
NUMA node0 CPU(s):                  0-11
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Not affected
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Vulnerable: Safe RET, no microcode
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Retpolines, IBPB conditional, IBRS_FW, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Not affected
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm

Versions of relevant libraries:
[pip3] mypy==1.4.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.22.4
[pip3] torch==2.1.1+git011de5c
[pip3] torchvision==0.16.1+fdea156
[pip3] triton==2.1.0
[conda] No relevant packages
ROCM Version: 6.0.32830-d62f6a171
Neuron SDK Version: N/A
vLLM Version: 0.4.1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
Could not collect

🐛 Describe the bug

I'm able to build the ROCm docker image for AMD via the latest docs: https://docs.vllm.ai/en/latest/getting_started/amd-installation.html#option-1-build-from-source-with-docker-recommended

I am using a W7900 (RDNA3; navi31; gfx1100) and therefore build with BUILD_FA="0", i.e. without Flash Attention.

When I run any script (say benchmarks/benchmark_latency.py), I get this error:

error: triton_flash_attention.py:211:0: stack frame size (277288) exceeds limit (262136) in function 'attn_fwd_0d1d2d3de45de6d7de8de9de10c11de12de13de14c15de16de17de18c19de20de21de22c23de24de25de26de27d28d29303132de'

It's trying to use Triton, which appears to provide its own implementation of flash attention?

Stepping through the code, it goes through selector.py:

INFO 05-01 05:11:43 selector.py:59] flash_atten is not supported on NAVI GPUs.

And that sends it to the ROCm backend:

INFO 05-01 05:11:43 selector.py:38] Using ROCmFlashAttention backend.

In the backend, there is a switch for navi3x:

# if not using triton, navi3x not use flash-attn either

It sets self.use_naive_attn = True if torch.cuda.get_device_capability()[0] == 11 (gfx11xx) - so far so good, but that branch only executes if self.use_triton_flash_attn is false, which is controlled by VLLM_USE_TRITON_FLASH_ATTN and defaults to True.
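
Roughly, the gating described above boils down to something like this (a paraphrased sketch of the behavior I stepped through, not the verbatim vLLM source; the exact env-var parsing is an assumption):

# Paraphrased sketch of the ROCm attention backend gating (not verbatim vLLM code).
import os
import torch

use_triton_flash_attn = os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")
use_naive_attn = False

if not use_triton_flash_attn:
    # Only reached when VLLM_USE_TRITON_FLASH_ATTN=0:
    # gfx11xx (Navi3x) reports device capability major == 11 on ROCm PyTorch.
    if torch.cuda.get_device_capability()[0] == 11:
        use_naive_attn = True

So on a gfx1100 card with the default env, the naive branch is never reached and the Triton kernel gets compiled anyway.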

So, in order to get this running you need to have VLLM_USE_TRITON_FLASH_ATTN=0 in your env.

This isn't in the docs, nor is it set by default when you build with BUILD_FA="0".

Presumably, the correct fix is for the ROCm implementation to do proper navi3x checking and pick the appropriate library/path based on which kernels are currently supported?
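
For illustration, the kind of check I have in mind might look like this (a sketch only; where it would live in vLLM, and the use of gcnArchName, which newer ROCm builds of PyTorch expose on get_device_properties, are assumptions):

# Hedged sketch of the proposed default: detect Navi3x up front and avoid the
# Triton flash-attention path without requiring VLLM_USE_TRITON_FLASH_ATTN=0.
import torch

def is_navi3x() -> bool:
    props = torch.cuda.get_device_properties(0)
    arch = getattr(props, "gcnArchName", "")  # e.g. "gfx1100" on newer ROCm PyTorch
    return arch.startswith("gfx11") or torch.cuda.get_device_capability(0)[0] == 11

if is_navi3x():
    use_triton_flash_attn = False  # fall back to the naive/supported path by default
    use_naive_attn = True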

@lhl lhl added the bug Something isn't working label May 1, 2024
@farshadghodsian

farshadghodsian commented May 18, 2024

After taking a deep look into the code and testing Flash Attention support on AMD GPUs, here is what I found:

AMD Instinct GPUs, gfx90a and gfx942 (MI210, MI250, MI300), support Flash Attention by way of specially written Composable Kernel (CK) libraries. Although I haven't tested this myself, it is working, and there are published performance numbers showing the 2-3x speedup vLLM gets from CK Flash Attention.

Radeon RDNA3 GPUs, 7900 XTX and W7900 (gfx1100), lack the necessary Composable Kernel libraries to use the above-mentioned Flash Attention mechanism, so the engineers at AMD opted to have these GPUs use an implementation of Flash Attention written in OpenAI's Triton. This Triton Flash Attention is supposed to be working, but every test I've done (using various branches and docker builds) with VLLM_USE_TRITON_FLASH_ATTN=1 hits the same "stack frame size exceeds limit" issue during the Triton JIT compile at runtime. I am sure the compile is not failing due to system resources, as I have tested the Radeon Pro W7900 on two powerful systems, a Ryzen 9 7950X with 64 GB of RAM and a Threadripper Pro 5975WX with 128 GB of RAM; in both cases the Triton compile takes a very long time (upwards of several hours) and still fails with the same stack frame size error (see screenshot).

[Screenshot from 2024-05-18 06-36-37: the Triton compile failing with the stack frame size error]

Flash Attention forward-pass support for RDNA3 was added thanks to howiejay; however, this implementation no longer works in my testing, as it fails to run the hipify_python patch and build on newer versions of pytorch+rocm (tried on rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1 and rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1).

In summary, the only way I've found to get vLLM working on Radeon and Radeon Pro graphics cards at the moment is to build without CK Flash Attention support (BUILD_FA="0") and disable the Triton Flash Attention implementation (VLLM_USE_TRITON_FLASH_ATTN=0). This gets vLLM running, but you don't get any of the speedups vLLM is known for, and in my testing inference is the same speed or slower than llama.cpp and Ollama.

The vLLM repos I've already tried are:
https://github.com/vllm-project/vllm (main branch)
https://github.com/ROCm/vllm (main, bf16_temp_fix_navi, TunableOp_Integration_ROCm6.0 branches)
https://github.com/hongxiayang/vllm (main branch)

Commands used to run the vLLM docker image and server were as follows (I tried a few other variations of the commands below, like changing --shm-size or --max-model-len, with no luck):

# Run vllm-rocm Docker image 
docker run -it --network=host --device=/dev/kfd --device=/dev/dri --group-add=video --ipc=host \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G --name vllm-rocm -v /home/${USER}/Downloads/models:/app/model \
vllm-rocm bash

# Run vllm api server
VLLM_USE_TRITON_FLASH_ATTN=1 CUDA_VISIBLE_DEVICES=0 python -m vllm.entrypoints.openai.api_server --max-model-len 3072 --download-dir /app/model --quantization=gptq --tensor-parallel-size=1 --enforce-eager --trust-remote-code --dtype=auto --kv-cache-dtype=auto --quantization-param-path=None --device=cuda --block-size=16 --model TechxGenus/Meta-Llama-3-70B-Instruct-GPTQ

I'm asking that the engineers at AMD look into this and assist in troubleshooting/getting this working for Radeon GPUs (Navi3).

@DhruvDh

DhruvDh commented May 19, 2024

If possible, can you try building Triton from source?

@lhl

lhl commented May 19, 2024

That won't work, I think. There's a related Flash Attention discussion on gfx1100 here: ROCm/aotriton#16, although according to that thread, Navi support was upstreamed last month and the appropriate place to file any navi31 Triton issues is the main repo: https://github.com/openai/triton

(The vLLM bug atm is just that it isn't checking for gfx1100 correctly; it shouldn't be trying to use the Triton FA at all?)

@Beinsezii

Beinsezii commented May 20, 2024

The howiejay branch should build fine on the latest torch stable running ROCm 6. I have py3.11 and py3.12 wheels built against gfx1100 and ROCm 6.0 here. All I run is

pip wheel git+https://github.com/ROCm/flash-attention@howiejay/navi_support --no-deps

in my virtualenvs to produce the wheels.

So I built vLLM with defaults then set VLLM_USE_TRITON_FLASH_ATTN=0 at runtime. On unquantized Llama3 8B I peaked at something like 1550 T/S with BS=96 and 0.95 memory allocation on a 7900 XTX 24G. 400 token response with a few hundred in context. Seems okay-ish?

Update: There's an internal gate against using the CK FA for Navi even if it's installed, because there's no varlen_fwd() support. You can build and install the howiejay flash-attn fine, but it seems to only be useful for diffusion models atm.
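
If anyone wants to see what their installed build actually exposes, something like this works as a quick probe (the compiled extension name flash_attn_2_cuda is an assumption and may differ across ROCm flash-attention branches; vLLM's gate may also key off the GPU arch rather than the symbol):

# Hedged sketch: list varlen-related symbols in the compiled flash-attn extension.
import importlib

ext = importlib.import_module("flash_attn_2_cuda")  # extension module name is an assumption
print([name for name in dir(ext) if "varlen" in name])  # empty list -> no varlen_fwd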

Additionally, I built ROCm/triton from source as of an hour ago and it still just sits pegging one thread for a small eternity before eventually being killed for blowing up the stack. I guess a person could try to increase the stack size, but I really feel like something's not working...

@Beinsezii

I think I narrowed it to this autotune:
triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "waves_per_eu": 2, "PRE_LOAD_V": False}, num_stages=1, num_warps=4)

Disabling that, I can run without VLLM_USE_TRITON_FLASH_ATTN=0. I'm using Triton nightly as of an hour ago to make sure it has any possible Navi fixes. Though if anything it feels slower? I'll try stable Triton in a bit.

Patch on top of v0.4.2 if someone else wants to play with it.

diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py
index 11476641..d5f6bbec 100644
--- a/vllm/attention/ops/triton_flash_attention.py
+++ b/vllm/attention/ops/triton_flash_attention.py
@@ -219,16 +219,16 @@ def _attn_fwd_inner(
             num_stages=1,
             num_warps=8,
         ),
-        triton.Config(
-            {
-                "BLOCK_M": 128,
-                "BLOCK_N": 128,
-                "waves_per_eu": 2,
-                "PRE_LOAD_V": False,
-            },
-            num_stages=1,
-            num_warps=4,
-        ),
+        # triton.Config(
+        #     {
+        #         "BLOCK_M": 128,
+        #         "BLOCK_N": 128,
+        #         "waves_per_eu": 2,
+        #         "PRE_LOAD_V": False,
+        #     },
+        #     num_stages=1,
+        #     num_warps=4,
+        # ),
         triton.Config(
             {
                 "BLOCK_M": 256,

@sdli1995

> The howiejay branch should build fine on the latest torch stable running ROCm 6. [...] Additionally I built ROCM/triton from source as of an hour ago and it still just sits peaking one thread for a small eternity before eventually being killed for blowing up the stack. [...]

Upstream Triton supports Navi3 now, but attention performance is slow.

@Beinsezii

Beinsezii commented May 22, 2024

Alright, I tried with stable Triton and the ROCm Triton fork. My patch only helped the official nightly run without hanging.

pip uninstall pytorch-triton-rocm -y; pip install --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps

There might be more configs that need to be disabled to run stable triton? A person could maybe just disable every config with a block dim ≥ 128 and it'd probably work everywhere. I think navi favors the small ones anyways?
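
If someone wants to try that programmatically rather than commenting configs out by hand, here is a sketch (assuming the config list in vllm/attention/ops/triton_flash_attention.py and that triton.Config keeps its meta-parameters on .kwargs; untested):

# Hedged sketch: keep only autotune configs whose tile dims stay below 128.
def drop_large_tiles(configs):
    # triton.Config stores the meta-parameter dict (BLOCK_M, BLOCK_N, ...) on .kwargs
    return [
        cfg for cfg in configs
        if cfg.kwargs.get("BLOCK_M", 0) < 128 and cfg.kwargs.get("BLOCK_N", 0) < 128
    ]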

I also found triton is indeed much faster than naive once you stack the context.

@taikai-zz

Is there a solution available? I also encountered this issue with an AMD W6800.

@Beinsezii

I'd wager it's related to ROCm/triton#596

> Is there a solution available? I also encountered this issue with AMD w6800

Commenting out every autotune with a block size ≥128 allows it to compile using pytorch-triton-rocm==2.3.1 for me on gfx1100.

triton-flash-stable-patch.txt

@taikai-zz

@Beinsezii Thank you very much for your help. After following your method, the earlier error (error: triton_flash_attention.py:211:0: stack frame size) no longer appears. However, generation runs very slowly, about the same speed as with VLLM_USE_TRITON_FLASH_ATTN=0. Is it a problem with my graphics card (AMD W6800)?

@Beinsezii

The W6800 has no WMMA accelerators, so I'm not sure it'll be faster. It should still use less memory for long-context models, though.

@Beinsezii

With ROCm/triton#596 closed, I decided to rebuild triton-lang/triton and was able to run VLLM_USE_TRITON_FLASH_ATTN=1 on unmodified vLLM 0.4.2 + gfx1100.

@B0-B

B0-B commented Sep 8, 2024

The problem is that rocBLAS is not supported on Navi architectures by ROCm. Hence FA won't work in general, I think.
