
Commit 1cb0cc2

[FIX] Make flash_attn optional (vllm-project#3269)

Parent: 99c3cfb

File tree: 5 files changed, +41 -78 lines

.gitignore
Lines changed: 0 additions & 3 deletions

@@ -184,6 +184,3 @@ _build/
 
 # Benchmark dataset
 *.json
-
-# Third-party Python packages.
-vllm/thirdparty_files/

setup.py
Lines changed: 3 additions & 45 deletions

@@ -3,7 +3,6 @@
 import os
 import re
 import subprocess
-import sys
 import warnings
 from pathlib import Path
 from typing import List, Set
@@ -15,8 +14,6 @@
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
 
 ROOT_DIR = os.path.dirname(__file__)
-# This is a temporary directory to store third-party packages.
-THIRDPARTY_SUBDIR = "vllm/thirdparty_files"
 
 # If you are developing the C++ backend of vLLM, consider building vLLM with
 # `python setup.py develop` since it will give you incremental builds.
@@ -327,46 +324,8 @@ def get_torch_arch_list() -> Set[str]:
                     "nvcc": NVCC_FLAGS_PUNICA,
                 },
             ))
-
-    # Download the FlashAttention package.
-    # Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/setup.py#L518-L530
-    flash_attn_version = "2.5.6"
-    install_dir = os.path.join(ROOT_DIR, THIRDPARTY_SUBDIR)
-    subprocess.check_call(
-        [
-            sys.executable,
-            "-m",
-            "pip",
-            "install",
-            "-q",
-            f"--target={install_dir}",
-            "einops",  # Dependency of flash-attn.
-            f"flash-attn=={flash_attn_version}",
-            "--no-dependencies",  # Required to avoid re-installing torch.
-        ],
-        env=dict(os.environ, CC="gcc"),
-    )
-
-    # Copy the FlashAttention package into the vLLM package after build.
-    class build_ext(BuildExtension):
-
-        def run(self):
-            super().run()
-            target_dir = os.path.join(self.build_lib, THIRDPARTY_SUBDIR)
-            if not os.path.exists(target_dir):
-                os.makedirs(target_dir)
-            self.copy_tree(install_dir, target_dir)
-
-    class BinaryDistribution(setuptools.Distribution):
-
-        def has_ext_modules(self):
-            return True
-
-else:
-    build_ext = BuildExtension
-    BinaryDistribution = setuptools.Distribution
-    if _is_neuron():
-        neuronxcc_version = get_neuronxcc_version()
+elif _is_neuron():
+    neuronxcc_version = get_neuronxcc_version()
 
 vllm_extension_sources = [
     "csrc/cache_kernels.cu",
@@ -509,7 +468,6 @@ def get_requirements() -> List[str]:
     python_requires=">=3.8",
     install_requires=get_requirements(),
     ext_modules=ext_modules,
-    cmdclass={"build_ext": build_ext} if not _is_neuron() else {},
-    distclass=BinaryDistribution,
+    cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {},
     package_data=package_data,
 )
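Note: with the build-time vendoring removed, setup.py no longer pip-installs FlashAttention into vllm/thirdparty_files/, and the custom build_ext/BinaryDistribution classes are gone; BuildExtension is used directly. Users who want the FlashAttention backend now install the package themselves. Below is a minimal sketch (not part of this commit) mirroring the pip invocation the removed code ran, minus the --target redirect; the 2.5.6 pin and the --no-dependencies flag come from that removed code.

# install_flash_attn.py -- optional, illustrative only
import subprocess
import sys

subprocess.check_call([
    sys.executable, "-m", "pip", "install",
    "einops",             # dependency of flash-attn
    "flash-attn==2.5.6",
    "--no-dependencies",  # avoid re-installing torch, as the removed code did
])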

vllm/__init__.py
Lines changed: 7 additions & 23 deletions

@@ -1,28 +1,12 @@
 """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
 
-
-# Adapted from https://github.com/ray-project/ray/blob/f92928c9cfcbbf80c3a8534ca4911de1b44069c0/python/ray/__init__.py#L11
-def _configure_system():
-    import os
-    import sys
-
-    # Importing flash-attn.
-    thirdparty_files = os.path.join(os.path.abspath(os.path.dirname(__file__)),
-                                    "thirdparty_files")
-    sys.path.insert(0, thirdparty_files)
-
-
-_configure_system()
-# Delete configuration function.
-del _configure_system
-
-from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs  # noqa: E402
-from vllm.engine.async_llm_engine import AsyncLLMEngine  # noqa: E402
-from vllm.engine.llm_engine import LLMEngine  # noqa: E402
-from vllm.engine.ray_utils import initialize_cluster  # noqa: E402
-from vllm.entrypoints.llm import LLM  # noqa: E402
-from vllm.outputs import CompletionOutput, RequestOutput  # noqa: E402
-from vllm.sampling_params import SamplingParams  # noqa: E402
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.llm_engine import LLMEngine
+from vllm.engine.ray_utils import initialize_cluster
+from vllm.entrypoints.llm import LLM
+from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.sampling_params import SamplingParams
 
 __version__ = "0.3.3"
 
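Note: with _configure_system() gone, importing vllm no longer prepends a thirdparty_files directory to sys.path, which is also why the # noqa: E402 markers on the imports could be dropped. A small illustrative check (assumes a vLLM build containing this change is installed):

# check_no_path_hack.py -- illustrative only
import sys

before = list(sys.path)
import vllm  # noqa: E402,F401

added = [p for p in sys.path if p not in before]
# The vllm/thirdparty_files entry that _configure_system() used to insert
# should no longer appear here.
print("sys.path entries added by importing vllm:", added)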

vllm/model_executor/layers/attention/attention.py
Lines changed: 31 additions & 6 deletions

@@ -1,12 +1,16 @@
 """Attention layer."""
+from functools import lru_cache
 from typing import List, Optional
 
 import torch
 import torch.nn as nn
 
+from vllm.logger import init_logger
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.utils import is_hip
 
+logger = init_logger(__name__)
+
 
 class Attention(nn.Module):
     """Attention layer.
@@ -30,17 +34,12 @@ def __init__(
         sliding_window: Optional[int] = None,
     ) -> None:
         super().__init__()
-        if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and
-                torch.get_default_dtype() in (torch.float16, torch.bfloat16)):
-            # Ampere or later NVIDIA GPUs.
-            # NOTE(woosuk): FlashAttention does not support FP32.
+        if _use_flash_attn():
             from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend
             self.backend = FlashAttentionBackend(num_heads, head_size, scale,
                                                  num_kv_heads, alibi_slopes,
                                                  sliding_window)
         else:
-            # Turing and Volta NVIDIA GPUs or AMD GPUs.
-            # Or FP32 on any GPU.
             from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend
             self.backend = XFormersBackend(num_heads, head_size, scale,
                                            num_kv_heads, alibi_slopes,
@@ -57,3 +56,29 @@ def forward(
     ) -> torch.Tensor:
         return self.backend.forward(query, key, value, key_cache, value_cache,
                                     input_metadata)
+
+
+@lru_cache(maxsize=1)
+def _use_flash_attn() -> bool:
+    try:
+        import flash_attn  # noqa: F401
+    except ImportError:
+        logger.info("flash_attn is not found. Using xformers backend.")
+        return False
+
+    if is_hip():
+        # AMD GPUs.
+        return False
+    if torch.cuda.get_device_capability()[0] < 8:
+        # Volta and Turing NVIDIA GPUs.
+        logger.info("flash_attn is not supported on Turing or older GPUs. "
+                    "Using xformers backend.")
+        return False
+    if torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
+        logger.info(
+            "flash_attn only supports torch.float16 or torch.bfloat16. "
+            "Using xformers backend.")
+        return False
+
+    logger.info("Using flash_attn backend.")
+    return True
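Note: the new _use_flash_attn() helper is what makes the dependency optional at runtime: it probes for the flash_attn module and falls back to the xformers backend when the import fails or the GPU/dtype is unsupported, and @lru_cache(maxsize=1) means the decision is made only once per process. Below is a standalone sketch (hypothetical helper name, not vLLM API) that mirrors those checks so you can see which backend your environment would likely get without constructing an Attention layer.

# backend_probe.py -- illustrative sketch only
import importlib.util

import torch


def probe_attention_backend() -> str:
    """Mirror the checks in _use_flash_attn() and report the likely backend."""
    if importlib.util.find_spec("flash_attn") is None:
        return "xformers (flash_attn not installed)"
    if not torch.cuda.is_available():
        # Extra guard for CPU-only machines; get_device_capability() needs a GPU.
        return "xformers (no CUDA device)"
    if torch.version.hip is not None:
        # AMD GPUs (what is_hip() checks in vLLM).
        return "xformers (ROCm GPU)"
    if torch.cuda.get_device_capability()[0] < 8:
        return "xformers (pre-Ampere GPU)"
    if torch.get_default_dtype() not in (torch.float16, torch.bfloat16):
        return "xformers (FlashAttention does not support FP32)"
    return "flash_attn"


if __name__ == "__main__":
    print("Selected backend:", probe_attention_backend())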

vllm/model_executor/layers/attention/backends/flash_attn.py
Lines changed: 0 additions & 1 deletion

@@ -1,7 +1,6 @@
 """Attention layer with Flash and PagedAttention."""
 from typing import List, Optional
 
-# NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/.
 from flash_attn import flash_attn_func
 import torch
 