flash-attention-3 #33522

Draft · wants to merge 1 commit into base: main
5 changes: 5 additions & 0 deletions .github/workflows/push-important-models.yml
@@ -87,6 +87,11 @@ jobs:
run:
pytest -rsfE -m "flash_attn_test" --make-reports=${{ matrix.model-name }}_fa2_tests/ tests/${{ matrix.model-name }}/test_modeling_*

- name: Run FA3 tests
id: run_fa3_tests
run:
pytest -rsfE -m "flash_attn_3_test" --make-reports=${{ matrix.model-name }}_fa3_tests/ tests/${{ matrix.model-name }}/test_modeling_*

- name: "Test suite reports artifacts: ${{ matrix.model-name }}_fa2_tests"
if: ${{ always() }}
uses: actions/upload-artifact@v4
18 changes: 18 additions & 0 deletions docs/source/en/llm_optims.md
@@ -348,6 +348,24 @@ model = AutoModelForCausalLM.from_pretrained(
)
```

### FlashAttention-3

FlashAttention and [FlashAttention-3](./perf_infer_gpu_one#flashattention-3) break up the attention computation into smaller chunks and reduce the number of intermediate read/write operations to GPU memory to speed up inference. FlashAttention-3 improves on the FlashAttention-2 algorithm by taking advantage of new features on Hopper GPUs to maximize performance.

To use FlashAttention-3, set `attn_implementation="flash_attention_3"` in the [`~PreTrainedModel.from_pretrained`] method.

```py
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

quant_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
"google/gemma-2b",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_3",
)
```

### PyTorch scaled dot product attention

Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation.
136 changes: 136 additions & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -201,6 +201,142 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
<img src="https://huggingface.co/datasets/ybelkada/documentation-images/resolve/main/llama-2-large-seqlen-padding.png">
</div>


## FlashAttention-3

<Tip>

FlashAttention-3 is experimental and may change considerably in future versions.

</Tip>

[FlashAttention-3](https://huggingface.co/papers/2407.08608) improves on the FlashAttention-2 algorithm by taking advantage of new features on Hopper GPUs to maximize performance:

1. overlap overall computation and data movement via warp-specialization
2. interleave block-wise matmul and softmax operations
3. block quantization and incoherent processing that leverages hardware support for FP8 low-precision

FlashAttention-3 is currently supported for the following architectures:
* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
* [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel)
* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
* [GraniteMoe](https://huggingface.co/docs/transformers/model_doc/granitemoe#transformers.GraniteMoeModel)
* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)
* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
* [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel)
* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
* [Qwen2VL](https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel)
* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
* [Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)

You can request to add FlashAttention-3 support for another model by opening a GitHub Issue or Pull Request.

Before you begin, make sure you have FlashAttention-3 installed.

<hfoptions id="install">
<hfoption id="NVIDIA">

```bash
git clone https://github.com/Dao-AILab/flash-attention
cd flash-attention/hopper
python setup.py install
```

</hfoption>
</hfoptions>
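
If the build succeeded, the kernels should be importable from Python (this assumes the `flash_attn_interface` module name used by the integration in this PR):

```py
# Quick sanity check that the Hopper build of FlashAttention-3 is visible.
import flash_attn_interface

print(flash_attn_interface.flash_attn_func)
```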

To enable FlashAttention-3, pass the argument `attn_implementation="flash_attention_3"` to [`~AutoModelForCausalLM.from_pretrained`]:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=torch.bfloat16,
attn_implementation="flash_attention_3",
)
```
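
A short generation call confirms the model runs end to end with the FA3 backend (the prompt below is arbitrary, and the model needs to sit on a Hopper GPU):

```python
model = model.to("cuda")
inputs = tokenizer("Hello, my name is", return_tensors="pt").to("cuda")

outputs = model.generate(**inputs, max_new_tokens=30)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```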

<Tip>

FlashAttention-3 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load it on a supported device before using FlashAttention-3.

<br>

</Tip>

FlashAttention-3 can be combined with other optimization techniques like quantization to further speed up inference. For example, you can combine FlashAttention-3 with 8-bit or 4-bit quantization:

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM

model_id = "tiiuae/falcon-7b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# load in 8bit
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_8bit=True,
attn_implementation="flash_attention_3",
)

# load in 4bit
model = AutoModelForCausalLM.from_pretrained(
model_id,
load_in_4bit=True,
attn_implementation="flash_attention_3",
)
```
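
Based on the implementation in this PR, the FP8 forward path is opt-in through the `FLASH_ATTENTION_3_FP8` environment variable and is inference-only for now, since the FP8 kernel has no backward pass yet. A minimal sketch, assuming the environment variable is read at attention call time:

```py
import os
import torch
from transformers import AutoModelForCausalLM

os.environ["FLASH_ATTENTION_3_FP8"] = "1"  # route attention through the FP8 forward kernel

model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_3",
)
```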

## PyTorch scaled dot product attention

PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -49,6 +49,7 @@ addopts = "--doctest-glob='**/*.md'"
doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
markers = [
"flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
"flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')",
"bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
"generate: marks tests that use the GenerationTesterMixin"
]
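
As a sketch of how the new marker could be applied (the test name and body below are hypothetical; real tests live under `tests/<model_name>/test_modeling_*` and are selected in CI with `-m "flash_attn_3_test"`):

```python
import pytest


@pytest.mark.flash_attn_3_test
def test_generate_with_flash_attention_3():
    # Placeholder body: a real test would load a model with
    # attn_implementation="flash_attention_3" and compare its outputs against eager attention.
    ...
```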
69 changes: 56 additions & 13 deletions src/transformers/modeling_flash_attention_utils.py
@@ -20,7 +20,7 @@
import torch
import torch.nn.functional as F

from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal
from .utils import is_flash_attn_2_available, is_flash_attn_3_available, is_flash_attn_greater_or_equal


if is_flash_attn_2_available():
@@ -29,6 +29,11 @@

_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)

if is_flash_attn_3_available():
from flash_attn_interface import _flash_attn_forward as _flash_attn_3_forward
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_3_func


def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""
@@ -194,6 +199,7 @@ def _flash_attention_forward(
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
deterministic: bool = None,
use_flash_attn_3: bool = False,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -219,18 +225,34 @@
Softcap for the attention logits, used e.g. in gemma2.
deterministic (`bool`, *optional*):
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
use_flash_attn_3 (`bool`, *optional*):
Determines if Flash Attention v3 should be used.
"""
if not use_top_left_mask:
causal = is_causal
else:
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
# TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__.
causal = is_causal and query_length != 1

# FAv3 FP8 path uses `_flash_attn_3_forward` which doesn't set the default value.
softmax_scale = softmax_scale or query_states.shape[-1] ** (-0.5)

use_fp8 = use_flash_attn_3 and os.environ.get("FLASH_ATTENTION_3_FP8", "0") == "1"

flash_kwargs = {}
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
use_sliding_windows = (
_flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
)
flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
use_sliding_windows = sliding_window is not None and key_states.shape[1] > sliding_window
if not _flash_supports_window_size:
flash_kwargs = {}
elif use_sliding_windows:
flash_kwargs["window_size"] = (sliding_window, sliding_window)
else:
# Needs default values for FP8 `_flash_attn_forward` path.
# Can be removed when FP8 bwd is supported and FP8 path uses `flash_attn_func`.
flash_kwargs["window_size"] = (-1, -1)

if not use_flash_attn_3:
flash_kwargs["dropout_p"] = dropout

if is_flash_attn_greater_or_equal("2.4.1"):
if deterministic is None:
@@ -249,19 +271,22 @@
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output_unpad = flash_attn_varlen_func(
func = flash_attn_varlen_3_func if use_flash_attn_3 else flash_attn_varlen_func

attn_output_unpad = func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)
if use_flash_attn_3:
attn_output_unpad = attn_output_unpad[0]
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)

# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
@@ -276,25 +301,43 @@
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

attn_output = flash_attn_varlen_func(
func = flash_attn_varlen_3_func if use_flash_attn_3 else flash_attn_varlen_func

attn_output = func(
query_states,
key_states,
value_states,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)
if use_flash_attn_3:
attn_output = attn_output[0]

attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))

else:
attn_output = flash_attn_func(
query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
)
if use_fp8:
# NOTE: uses `_flash_attn_forward` instead of `flash_attn_func` because no bwd for FP8 yet.
# `deterministic` is part of `flash_attn_func`/`FlashAttnFunc`, used in bwd, so not present here.
attn_output = _flash_attn_3_forward(
query_states.to(torch.float8_e4m3fn),
key_states.to(torch.float8_e4m3fn),
value_states.to(torch.float8_e4m3fn),
softmax_scale=softmax_scale,
causal=causal,
window_size=flash_kwargs["window_size"],
)[0]
else:
func = flash_attn_3_func if use_flash_attn_3 else flash_attn_func
attn_output = func(
query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
)
if use_flash_attn_3:
attn_output = attn_output[0]

return attn_output
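
For reference, a minimal sketch of how an attention module might call this helper with the new flag, assuming 4D `[batch, seq_len, num_heads, head_dim]` inputs, the signature shown in the diff above, and an FA3 build on a Hopper GPU:

```python
import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward

batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)

# attention_mask=None takes the dense path, i.e. flash_attn_3_func when use_flash_attn_3=True.
out = _flash_attention_forward(
    q, k, v,
    attention_mask=None,
    query_length=seq_len,
    is_causal=True,
    use_flash_attn_3=True,
)
print(out.shape)  # torch.Size([2, 128, 8, 64])
```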