Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ if seq_len > keep_window_size:
attention_mask.zero_()
attention_mask.scatter_(-1, topk_indices, 1.0)

# Select backend
flash_dmattn_func = flash_dmattn_func_auto(backend="cuda")
# Select backend (auto-selects the best available backend)
flash_dmattn_func = flash_dmattn_func_auto() # Automatically chooses CUDA, Triton, or Flex

# Run Flash Dynamic Mask Attention
output = flash_dmattn_func(
Expand Down Expand Up @@ -232,12 +232,40 @@ python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
```python
# Test basic import
try:
from flash_dmattn import flash_dmattn_func, get_available_backends
from flash_dmattn import flash_dmattn_func_auto, get_available_backends
print("✅ Flash Dynamic Mask Attention imported successfully")
print(f"Available backends: {get_available_backends()}")

# Test auto backend selection (recommended)
func = flash_dmattn_func_auto()
print("✅ Auto backend selection works")
except ImportError as e:
print(f"❌ Import failed: {e}")
print("Please ensure the package is properly installed with: pip install -e .")
except RuntimeError as e:
print(f"❌ Backend error: {e}")
# The error message will provide specific guidance on what to install
```

**CUDA Extension Issues**
```python
# If you get "CUDA flash_dmattn_func is not available" error:
# This usually means the CUDA extension was partially installed

# Solution 1: Rebuild the CUDA extension
# pip install -e . --force-reinstall

# Solution 2: Use alternative backends
# pip install triton # For Triton backend
# pip install transformers # For Flex backend

# Test with fallback backends
from flash_dmattn import flash_dmattn_func_auto
try:
func = flash_dmattn_func_auto() # Will auto-select working backend
print("✅ Using backend:", func.__name__)
except RuntimeError as e:
print(f"❌ No backends available: {e}")
```

**Performance Issues**
Expand Down
24 changes: 18 additions & 6 deletions flash_dmattn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,15 @@
]


def _is_cuda_fully_available():
    """Report whether the CUDA backend is fully usable.

    The CUDA path requires two things: the compiled extension module was
    importable (``CUDA_AVAILABLE``) *and* the high-level interface function
    ``flash_dmattn_func`` was successfully bound (it is ``None`` when the
    extension is only partially installed).

    Returns:
        bool-like: truthy only when both conditions hold.
    """
    if not CUDA_AVAILABLE:
        return False
    # Extension imported, but the wrapper may still be missing after a
    # partial/failed build — require the function to actually exist.
    return flash_dmattn_func is not None


def get_available_backends():
"""Return a list of available backends."""
backends = []
if CUDA_AVAILABLE:
if _is_cuda_fully_available():
backends.append("cuda")
if TRITON_AVAILABLE:
backends.append("triton")
Expand All @@ -94,21 +99,28 @@ def flash_dmattn_func_auto(backend: Optional[str] = None, **kwargs):
The attention function for the specified or auto-selected backend.
"""
if backend is None:
# Auto-select backend
if CUDA_AVAILABLE:
# Auto-select backend - use the first fully working backend
if _is_cuda_fully_available():
backend = "cuda"
elif TRITON_AVAILABLE:
backend = "triton"
elif FLEX_AVAILABLE:
backend = "flex"
else:
raise RuntimeError("No flash attention backend is available. Please install at least one of: triton, transformers, or build the CUDA extension.")
# Provide helpful error message based on what's partially available
error_parts = ["No flash attention backend is fully available."]
if CUDA_AVAILABLE and flash_dmattn_func is None:
error_parts.append("CUDA extension was found but interface functions are not available - please rebuild the CUDA extension with: pip install -e .")
else:
error_parts.append("CUDA extension is not built - please install with: pip install -e .")
error_parts.append("Alternatively, install alternative backends: pip install triton (for Triton backend) or pip install transformers (for Flex backend).")
raise RuntimeError(" ".join(error_parts))

if backend == "cuda":
if not CUDA_AVAILABLE:
raise RuntimeError("CUDA backend is not available. Please build the CUDA extension.")
raise RuntimeError("CUDA backend is not available. Please build the CUDA extension with: pip install -e .")
if flash_dmattn_func is None:
raise RuntimeError("CUDA flash_dmattn_func is not available. Please check the installation.")
raise RuntimeError("CUDA extension was found but interface functions are not available. This may indicate an incomplete installation. Please rebuild the CUDA extension with: pip install -e .")
return flash_dmattn_func

elif backend == "triton":
Expand Down