diff --git a/.github/workflows/push-important-models.yml b/.github/workflows/push-important-models.yml
index 41bcd43fcc6fc2..17c1b6c86fb066 100644
--- a/.github/workflows/push-important-models.yml
+++ b/.github/workflows/push-important-models.yml
@@ -87,6 +87,11 @@ jobs:
run:
pytest -rsfE -m "flash_attn_test" --make-reports=${{ matrix.model-name }}_fa2_tests/ tests/${{ matrix.model-name }}/test_modeling_*
+ - name: Run FA3 tests
+ id: run_fa3_tests
+ run:
+ pytest -rsfE -m "flash_attn_3_test" --make-reports=${{ matrix.model-name }}_fa3_tests/ tests/${{ matrix.model-name }}/test_modeling_*
+
- name: "Test suite reports artifacts: ${{ matrix.model-name }}_fa2_tests"
if: ${{ always() }}
uses: actions/upload-artifact@v4
diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md
index 16be638498dfd4..4357e96735cd6d 100644
--- a/docs/source/en/llm_optims.md
+++ b/docs/source/en/llm_optims.md
@@ -348,6 +348,24 @@ model = AutoModelForCausalLM.from_pretrained(
)
```
+### FlashAttention-3
+
+FlashAttention and [FlashAttention-3](./perf_infer_gpu_one#flashattention-3) break up the attention computation into smaller chunks and reduce the number of intermediate read/write operations to GPU memory to speed up inference. FlashAttention-3 improves on the FlashAttention-2 algorithm by taking advantage of new features on Hopper GPUs to maximize performance.
+
+To use FlashAttention-3, set `attn_implementation="flash_attention_3"` in the [`~PreTrainedModel.from_pretrained`] method.
+
+```py
+import torch
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+model = AutoModelForCausalLM.from_pretrained(
+ "google/gemma-2b",
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_3",
+)
+```
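+
+Generation then works as usual. Below is a minimal sketch using the quantized model loaded above (the prompt text is only illustrative):
+
+```py
+from transformers import AutoTokenizer
+
+tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
+inputs = tokenizer("The secret to baking a good cake is", return_tensors="pt").to(model.device)
+outputs = model.generate(**inputs, max_new_tokens=30)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```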
+
### PyTorch scaled dot product attention
Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation.
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index dd3433f2cd4862..804986332c601c 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -198,6 +198,140 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
+
+## FlashAttention-3
+
+<Tip warning={true}>
+
+FlashAttention-3 is experimental and may change considerably in future versions.
+
+</Tip>
+
+[FlashAttention-3](https://huggingface.co/papers/2407.08608) improves on the FlashAttention-2 algorithm by taking advantage of new features on Hopper GPUs to maximize performance:
+
+1. overlap overall computation and data movement via warp-specialization
+2. interleave block-wise matmul and softmax operations
+3. block quantization and incoherent processing that leverages hardware support for FP8 low-precision
+
+FlashAttention-3 is currently supported for the following architectures:
+* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
+* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
+* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
+* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
+* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
+* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
+* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
+* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
+* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
+* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
+* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
+* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
+* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
+* [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel)
+* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
+* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
+* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
+* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
+* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
+* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
+* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
+* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
+* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
+* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
+* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
+* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
+* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)
+* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
+* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
+* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
+* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
+* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
+* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
+* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
+* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
+* [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel)
+* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
+* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
+* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
+* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
+* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
+* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
+* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
+* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
+* [Qwen2VL](https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel)
+* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
+* [Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
+* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
+* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
+* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
+* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
+* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
+* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
+
+You can request to add FlashAttention-3 support for another model by opening a GitHub Issue or Pull Request.
+
+Before you begin, make sure you have FlashAttention-3 installed.
+
+
+
+
+```bash
+git clone https://github.com/Dao-AILab/flash-attention
+cd flash-attention/hopper
+python setup.py install
+```
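+
+To quickly confirm the build is usable, check that the `flash_attn_interface` module produced by the Hopper build above can be imported (this is the module Transformers looks for):
+
+```py
+import flash_attn_interface
+print(flash_attn_interface.flash_attn_func)
+```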
+
+
+
+
+To enable FlashAttention-3, pass the argument `attn_implementation="flash_attention_3"` to [`~AutoModelForCausalLM.from_pretrained`]:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_3",
+)
+```
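+
+The FlashAttention-3 kernels only run on the GPU, so move the model there before generating. A minimal sketch using the `model` and `tokenizer` loaded above (the prompt is only illustrative):
+
+```python
+model = model.to("cuda")
+inputs = tokenizer("My favourite condiment is", return_tensors="pt").to("cuda")
+outputs = model.generate(**inputs, max_new_tokens=20)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```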
+
+<Tip>
+
+FlashAttention-3 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load it on a supported device before using FlashAttention-3.
+
+</Tip>
+
+
+
+FlashAttention-3 can be combined with other optimization techniques like quantization to further speed up inference. For example, you can combine FlashAttention-3 with 8-bit or 4-bit quantization:
+
+```py
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# load in 8bit
+model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ load_in_8bit=True,
+ attn_implementation="flash_attention_3",
+)
+
+# load in 4bit
+model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ load_in_4bit=True,
+ attn_implementation="flash_attention_3",
+)
+```
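+
+The `load_in_8bit`/`load_in_4bit` shortcuts above are equivalent to passing a [`BitsAndBytesConfig`]. A sketch of the 4-bit variant written that way:
+
+```py
+from transformers import BitsAndBytesConfig
+
+quant_config = BitsAndBytesConfig(load_in_4bit=True)
+model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ quantization_config=quant_config,
+ attn_implementation="flash_attention_3",
+)
+```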
+
## PyTorch scaled dot product attention
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
diff --git a/pyproject.toml b/pyproject.toml
index bf78e0174394f5..fdcccd96884caa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,6 +49,7 @@ addopts = "--doctest-glob='**/*.md'"
doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
markers = [
"flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
+ "flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')",
"bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
"generate: marks tests that use the GenerationTesterMixin"
]
diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py
index 36775d8454ab8c..ef735b95c5ef99 100755
--- a/src/transformers/__init__.py
+++ b/src/transformers/__init__.py
@@ -1332,6 +1332,7 @@
"convert_and_export_with_cache",
]
+ _import_structure["modeling_flash_attention_3_utils"] = []
_import_structure["modeling_flash_attention_utils"] = []
_import_structure["modeling_outputs"] = []
_import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"]
diff --git a/src/transformers/modeling_flash_attention_3_utils.py b/src/transformers/modeling_flash_attention_3_utils.py
new file mode 100644
index 00000000000000..0a3faf784a06af
--- /dev/null
+++ b/src/transformers/modeling_flash_attention_3_utils.py
@@ -0,0 +1,295 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+
+from .utils import is_flash_attn_3_available
+
+
+if is_flash_attn_3_available():
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
+ from flash_attn_interface import _flash_attn_forward, flash_attn_func, flash_attn_varlen_func
+
+
+def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
+ """
+ Retrieves indexing data required to repad unpadded (ragged) tensors.
+
+ Arguments:
+ attention_mask (`torch.Tensor`):
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+
+ Return:
+ indices (`torch.Tensor`):
+ The indices of non-masked tokens from the flattened input sequence.
+ cu_seqlens (`torch.Tensor`):
+ The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ max_seqlen_in_batch (`int`):
+ Maximum sequence length in batch.
+ """
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
+ return (
+ indices,
+ cu_seqlens,
+ max_seqlen_in_batch,
+ )
+
+
+def _upad_input(
+ query_layer: torch.Tensor,
+ key_layer: torch.Tensor,
+ value_layer: torch.Tensor,
+ attention_mask: torch.Tensor,
+ query_length: int,
+):
+ """
+ Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches.
+
+ This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary
+ tensors for query, key, value tensors.
+
+ Arguments:
+ query_layer (`torch.Tensor`):
+ Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
+ key_layer (`torch.Tensor`):
+ Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+ value_layer (`torch.Tensor`):
+ Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+ attention_mask (`torch.Tensor`):
+ Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid.
+ query_length (`int`):
+ Target length.
+
+ Return:
+ query_layer (`torch.Tensor`):
+ Query state without padding. Shape: (total_target_length, num_heads, head_dim).
+ key_layer (`torch.Tensor`):
+ Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+ value_layer (`torch.Tensor`):
+ Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+ indices_q (`torch.Tensor`):
+ The indices of non-masked tokens from the flattened input target sequence.
+ (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+ """
+ indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+ batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+ key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k)
+ value_layer = index_first_axis(
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+ )
+ if query_length == kv_seq_len:
+ query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k)
+ cu_seqlens_q = cu_seqlens_k
+ max_seqlen_in_batch_q = max_seqlen_in_batch_k
+ indices_q = indices_k
+ elif query_length == 1:
+ max_seqlen_in_batch_q = 1
+ cu_seqlens_q = torch.arange(
+ batch_size + 1, dtype=torch.int32, device=query_layer.device
+ ) # There is a memcpy here, that is very bad.
+ indices_q = cu_seqlens_q[:-1]
+ query_layer = query_layer.squeeze(1)
+ else:
+ # The -q_len: slice assumes left padding.
+ attention_mask = attention_mask[:, -query_length:]
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+ return (
+ query_layer,
+ key_layer,
+ value_layer,
+ indices_q,
+ (cu_seqlens_q, cu_seqlens_k),
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+ )
+
+
+def prepare_fa2_from_position_ids(query, key, value, position_ids):
+ """
+ This function returns the arguments necessary to call `flash_attn_varlen_func`.
+ All three query, key, value states will be flattened.
+ Cumulative sequence lengths for each example in the batch are extracted from position_ids.
+
+ NOTE: ideally, cumulative lengths should be prepared at the data collator stage
+
+ Arguments:
+ query (`torch.Tensor`):
+ Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim).
+ key (`torch.Tensor`):
+ Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+ value (`torch.Tensor`):
+ Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim).
+ position_ids (`torch.Tensor`):
+ Token position indices of shape (batch_size, sequence_length). Positions restart from 0 at the start of each packed sub-sequence, which is what allows the cumulative sequence lengths to be recovered.
+
+ Return:
+ query (`torch.Tensor`):
+ Query state without padding. Shape: (total_target_length, num_heads, head_dim).
+ key (`torch.Tensor`):
+ Key state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+ value (`torch.Tensor`):
+ Value state without padding. Shape: (total_source_length, num_key_value_heads, head_dim).
+ indices_q (`torch.Tensor`):
+ The indices of non-masked tokens from the flattened input target sequence.
+ (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`):
+ The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,).
+ (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`):
+ Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value).
+ """
+ query = query.view(-1, query.size(-2), query.size(-1))
+ key = key.view(-1, key.size(-2), key.size(-1))
+ value = value.view(-1, value.size(-2), value.size(-1))
+ position_ids = position_ids.flatten()
+ indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
+
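+ # Sequence starts are the flattened positions where position_ids == 0. Illustrative example
+ # (not executed): position_ids = [[0, 1, 2, 0, 1], [0, 1, 2, 3, 4]] flattens to
+ # [0, 1, 2, 0, 1, 0, 1, 2, 3, 4], giving cu_seq_lens = [0, 3, 5, 10] and max_length = 5.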
+ cu_seq_lens = torch.cat(
+ (
+ indices_q[position_ids == 0],
+ torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
+ )
+ )
+
+ max_length = position_ids.max() + 1
+
+ return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))
+
+
+def _flash_attention_3_forward(
+ query_states: torch.Tensor,
+ key_states: torch.Tensor,
+ value_states: torch.Tensor,
+ attention_mask: torch.Tensor,
+ query_length: int,
+ is_causal: bool,
+ position_ids: Optional[torch.Tensor] = None,
+ softmax_scale: Optional[float] = None,
+ use_top_left_mask: bool = False,
+ deterministic: Optional[bool] = None,
+ descale: float = 1.0,
+ use_fp8: bool = False,
+):
+ """
+ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token,
+ the input is first unpadded, the attention scores are computed, and the output is then padded back to the original shape.
+
+ Args:
+ query_states (`torch.Tensor`):
+ Input query states to be passed to Flash Attention API
+ key_states (`torch.Tensor`):
+ Input key states to be passed to Flash Attention API
+ value_states (`torch.Tensor`):
+ Input value states to be passed to Flash Attention API
+ attention_mask (`torch.Tensor`):
+ The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+ position of padding tokens and 1 for the position of non-padding tokens.
+ softmax_scale (`float`, *optional*):
+ The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+ use_top_left_mask (`bool`, defaults to `False`):
+ flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is the bottom-right alignment that became the default in flash_attn>=2.1. This attribute is used to handle this difference.
+ deterministic (`bool`, *optional*):
+ Determines whether the deterministic option is enabled.
+ """
+ # FP8 execution can be requested either through the `use_fp8` argument or the FLASH_ATTENTION_3_FP8 environment variable
+ use_fp8 = use_fp8 or os.environ.get("FLASH_ATTENTION_3_FP8", "0") == "1"
+
+ softmax_scale = softmax_scale or query_states.shape[-1] ** (-0.5)
+
+ if not use_top_left_mask:
+ causal = is_causal
+ else:
+ causal = is_causal and query_length != 1
+
+ flash_kwargs = {}
+
+ if deterministic is None:
+ deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
+ flash_kwargs["deterministic"] = deterministic
+
+ # Contains at least one padding token in the sequence
+ if attention_mask is not None:
+ batch_size = query_states.shape[0]
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input(
+ query_states, key_states, value_states, attention_mask, query_length
+ )
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output_unpad, _ = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ **flash_kwargs,
+ )
+ attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+
+ # If position_ids is provided and the positions are not monotonically increasing within each row, the batch
+ # most likely contains packed (multiple) sequences. Additionally check that we are in the pre-fill/training
+ # stage (query_length != 1). In that case, use `flash_attn_varlen_func` to prevent cross-example attention
+ # and to allow a padding-free approach.
+ elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1:
+ batch_size = query_states.size(0)
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
+ query_states, key_states, value_states, position_ids
+ )
+
+ cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+ max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+ attn_output, _ = flash_attn_varlen_func(
+ query_states,
+ key_states,
+ value_states,
+ cu_seqlens_q=cu_seqlens_q,
+ cu_seqlens_k=cu_seqlens_k,
+ max_seqlen_q=max_seqlen_in_batch_q,
+ max_seqlen_k=max_seqlen_in_batch_k,
+ softmax_scale=softmax_scale,
+ causal=causal,
+ **flash_kwargs,
+ )
+
+ attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
+
+ else:
+ if use_fp8:
+ # NOTE: descale?
+ attn_output = _flash_attn_forward(
+ query_states.to(torch.float8_e4m3fn),
+ key_states.to(torch.float8_e4m3fn),
+ value_states.to(torch.float8_e4m3fn),
+ softmax_scale=softmax_scale,
+ causal=causal,
+ )[0]
+ else:
+ attn_output, _ = flash_attn_func(
+ query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
+ )
+
+ return attn_output
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index d4069766636053..314e3ee22cafdf 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -82,6 +82,7 @@
is_accelerate_available,
is_bitsandbytes_available,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_offline_mode,
is_optimum_available,
is_peft_available,
@@ -1367,6 +1368,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Flash Attention 2 support
_supports_flash_attn_2 = False
+ # Flash Attention 3 support
+ _supports_flash_attn_3 = False
+
# SDPA support
_supports_sdpa = False
@@ -1550,10 +1554,12 @@ def _autoset_attn_implementation(
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
)
- if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]:
+ if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2", "flash_attention_3"]:
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
+ if cls._supports_flash_attn_3:
+ message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
raise ValueError(message + ".")
@@ -1575,6 +1581,14 @@ def _autoset_attn_implementation(
hard_check_only=False,
check_device_map=check_device_map,
)
+ elif config._attn_implementation == "flash_attention_3":
+ cls._check_and_enable_flash_attn_3(
+ config,
+ torch_dtype=torch_dtype,
+ device_map=device_map,
+ hard_check_only=False,
+ check_device_map=check_device_map,
+ )
elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
config = cls._check_and_enable_sdpa(
@@ -1734,6 +1748,83 @@ def _check_and_enable_flash_attn_2(
config._attn_implementation = "flash_attention_2"
return config
+ @classmethod
+ def _check_and_enable_flash_attn_3(
+ cls,
+ config,
+ torch_dtype: Optional[torch.dtype] = None,
+ device_map: Optional[Union[str, Dict[str, int]]] = None,
+ check_device_map: bool = True,
+ hard_check_only: bool = False,
+ ) -> PretrainedConfig:
+ """
+ Checks the availability of Flash Attention 3 and compatibility with the current model.
+
+ If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module.
+ """
+ if not cls._supports_flash_attn_3:
+ raise ValueError(
+ f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
+ f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
+ " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
+ )
+
+ if not is_flash_attn_3_available():
+ preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:"
+ # TODO: docs
+ install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-3 to install Flash Attention 3."
+
+ if importlib.util.find_spec("flash_attn_interface") is None:
+ raise ImportError(
+ f"{preface} the package flash_attn_interface seems to be not installed. {install_message}"
+ )
+
+ _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
+
+ if _is_bettertransformer:
+ raise ValueError(
+ "Flash Attention 3 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
+ )
+
+ if torch_dtype is None:
+ logger.warning_once(
+ "You are attempting to use Flash Attention 3.0 without specifying a torch dtype. This might lead to unexpected behaviour"
+ )
+ elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
+ logger.warning_once(
+ "Flash Attention 3.0 only supports torch.float16 and torch.bfloat16 dtypes, but"
+ f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
+ ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`'
+ )
+
+ # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
+ # or the model may be initialized under the context manager `with torch.device("cuda"):`.
+ if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
+ if torch.cuda.is_available():
+ logger.warning_once(
+ "You are attempting to use Flash Attention 3.0 with a model not initialized on GPU. Make sure to move the model to GPU"
+ " after initializing it on CPU with `model.to('cuda')`."
+ )
+ else:
+ raise ValueError(
+ "You are attempting to use Flash Attention 3.0 with a model not initialized on GPU and with no GPU available. "
+ "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
+ "or initialising the model on CPU and then moving it to GPU."
+ )
+ elif (
+ check_device_map
+ and device_map is not None
+ and isinstance(device_map, dict)
+ and ("cpu" in device_map.values() or "disk" in device_map.values())
+ ):
+ raise ValueError(
+ "You are attempting to use Flash Attention 3.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
+ "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
+ )
+ if not hard_check_only:
+ config._attn_implementation = "flash_attention_3"
+ return config
+
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
"""
diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py
index 5aad7b23a8a672..f6fd3e911d95d0 100644
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -35,6 +35,7 @@
add_start_docstrings_to_model_forward,
is_accelerate_available,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
@@ -56,6 +57,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -279,9 +283,92 @@ def forward(
return outputs
+class BarkSelfFlashAttention3(BarkSelfAttention):
+ """
+ Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def _split_heads(self, tensor, num_heads, attn_head_size):
+ """
+ Splits hidden_size dim into attn_head_size and num_heads
+ """
+ new_shape = tensor.size()[:-1] + (num_heads, attn_head_size)
+ tensor = tensor.view(new_shape)
+ # Flash attention requires the input to have the shape
+ # (batch_size, seq_length, num_heads, head_dim) - (batch, seq_length, head, head_features)
+ return tensor
+
+ def _merge_heads(self, tensor, num_heads, attn_head_size):
+ """
+ Merges attn_head_size dim and num_attn_heads dim into hidden_size
+ """
+ # re-assemble all head outputs side by side
+ # (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size)
+ tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,))
+ return tensor
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ past_key_values=None,
+ head_mask=None,
+ use_cache=False,
+ output_attentions=False,
+ ):
+ batch_size, query_len, _ = hidden_states.size()
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2)
+
+ query = self._split_heads(query, self.num_heads, self.head_dim)
+ key = self._split_heads(key, self.num_heads, self.head_dim)
+ value = self._split_heads(value, self.num_heads, self.head_dim)
+
+ if past_key_values is not None:
+ # (batch, head, seq_length, head_features) -> (batch, seq_length, head, head_features)
+ past_key = past_key_values[0].transpose(1, 2)
+ past_value = past_key_values[1].transpose(1, 2)
+ # and merge on seq_length
+ key = torch.cat((past_key, key), dim=1)
+ value = torch.cat((past_value, value), dim=1)
+
+ if use_cache is True:
+ # (batch, head, seq_length, head_features)
+ present = (key.transpose(1, 2), value.transpose(1, 2))
+ else:
+ present = None
+
+ attn_output = _flash_attention_3_forward(
+ query,
+ key,
+ value,
+ attention_mask,
+ query_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
+ attn_output = self.out_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+
+ outputs = (attn_output, present)
+ if output_attentions:
+ attn_weights = None
+ outputs += (attn_weights,)
+
+ return outputs
+
+
BARK_ATTENTION_CLASSES = {
"eager": BarkSelfAttention,
"flash_attention_2": BarkSelfFlashAttention2,
+ "flash_attention_3": BarkSelfFlashAttention3,
}
@@ -376,6 +463,7 @@ class BarkPreTrainedModel(PreTrainedModel):
config_class = BarkConfig
supports_gradient_checkpointing = False
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
def _init_weights(self, module):
"""Initialize the weights."""
@@ -561,6 +649,7 @@ def __init__(self, config):
self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias)
@@ -702,7 +791,7 @@ def forward(
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
attention_mask = attention_mask if 0 in attention_mask else None
else:
attention_mask = attention_mask.view(batch_size, -1)
@@ -1157,6 +1246,7 @@ def __init__(self, config):
self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self.layernorm_final = nn.LayerNorm(config.hidden_size)
@@ -1341,7 +1431,7 @@ def forward(
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
@@ -1811,3 +1901,39 @@ def _check_and_enable_flash_attn_2(
config.coarse_acoustics_config._attn_implementation = config._attn_implementation
config.fine_acoustics_config._attn_implementation = config._attn_implementation
return config
+
+ @classmethod
+ def _check_and_enable_flash_attn_3(
+ cls,
+ config,
+ torch_dtype: Optional[torch.dtype] = None,
+ device_map: Optional[Union[str, Dict[str, int]]] = None,
+ hard_check_only: bool = False,
+ check_device_map: bool = False,
+ ):
+ """
+ `_check_and_enable_flash_attn_3` originally doesn't expand flash attention enabling to the model
+ sub-configurations. We override the original method to make sure that Bark sub-models are using Flash Attention
+ if necessary.
+
+ If you don't know about Flash Attention, check out the official repository of flash attention:
+ https://github.com/Dao-AILab/flash-attention
+
+ To use Flash Attention 1.0, you can do so directly via the `BetterTransformer` API; have a look at this
+ specific section of the documentation to learn more about it:
+ https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
+
+ The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
+ half precision and not run on CPU.
+
+ If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_3" so that the model
+ can initialize the correct attention module
+ """
+ config = super()._check_and_enable_flash_attn_3(
+ config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map
+ )
+
+ config.semantic_config._attn_implementation = config._attn_implementation
+ config.coarse_acoustics_config._attn_implementation = config._attn_implementation
+ config.fine_acoustics_config._attn_implementation = config._attn_implementation
+ return config
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index fa928d05caa89b..2e31469f9c3c49 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -47,6 +47,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -57,6 +58,8 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
logger = logging.get_logger(__name__)
@@ -413,6 +416,125 @@ def forward(
return attn_output, attn_weights, past_key_value
+class BartFlashAttention3(BartAttention):
+ """
+ Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # BartFlashAttention3 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("BartFlashAttention3 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2)
+ value_states = past_key_value[1].transpose(1, 2)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class BartSdpaAttention(BartAttention):
def forward(
self,
@@ -523,6 +645,7 @@ def forward(
"eager": BartAttention,
"sdpa": BartSdpaAttention,
"flash_attention_2": BartFlashAttention2,
+ "flash_attention_3": BartFlashAttention3,
}
@@ -748,6 +871,7 @@ class BartPreTrainedModel(PreTrainedModel):
_no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
@@ -979,6 +1103,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
)
self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
@@ -1067,7 +1192,7 @@ def forward(
# expand attention_mask
if attention_mask is not None:
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
attention_mask = attention_mask if 0 in attention_mask else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
@@ -1163,6 +1288,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
)
self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
@@ -1283,7 +1409,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input)
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
@@ -1303,7 +1429,7 @@ def forward(
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index d32ab95f51c7da..bdc86c34fb6bec 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -790,7 +790,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index 23334311ca9511..d1c33df2b42fd7 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -27,6 +27,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
@@ -538,6 +539,112 @@ def forward(
return attn_output, attn_weights, past_key_value
+class ChameleonFlashAttention3(ChameleonAttention):
+ """
+ Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
+ query_states = self.q_norm(query_states)
+
+ key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
+ key_states = self.k_norm(key_states)
+
+ # Flash attention requires the input to have the shape
+ # (batch_size, seq_length, num_heads, head_dim); we first move to (batch_size, num_heads, seq_length, head_dim)
+ # to apply the rotary embeddings and update the cache, and transpose back afterwards
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim].
+ # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (ChameleonRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class ChameleonSdpaAttention(ChameleonAttention):
"""
Chameleon attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -634,6 +741,7 @@ def forward(
CHAMELEON_ATTENTION_CLASSES = {
"eager": ChameleonAttention,
"flash_attention_2": ChameleonFlashAttention2,
+ "flash_attention_3": ChameleonFlashAttention3,
"sdpa": ChameleonSdpaAttention,
}
@@ -1155,6 +1263,7 @@ class ChameleonPreTrainedModel(PreTrainedModel):
_no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"]
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_cache_class = True
@@ -1433,7 +1542,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py
index 64eb027e9e220c..3d120d3fb2b119 100644
--- a/src/transformers/models/clip/modeling_clip.py
+++ b/src/transformers/models/clip/modeling_clip.py
@@ -33,6 +33,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -43,6 +44,8 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
logger = logging.get_logger(__name__)
@@ -432,6 +435,82 @@ def forward(
return attn_output, attn_weights
+class CLIPFlashAttention3(CLIPAttention):
+ """
+ CLIPAttention flash attention module. This module inherits from `CLIPAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ causal_attention_mask: Optional[torch.Tensor] = None,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ output_attentions = False
+
+ batch_size, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # (batch_size, seq_length, num_heads, head_dim)
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim)
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim)
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference so it is recommended to not cast the LayerNorms
+ # in fp32.
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=causal_attention_mask is not None,
+ )
+
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+
class CLIPSdpaAttention(CLIPAttention):
"""
SDPA attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -509,6 +588,7 @@ def forward(
"eager": CLIPAttention,
"sdpa": CLIPSdpaAttention,
"flash_attention_2": CLIPFlashAttention2,
+ "flash_attention_3": CLIPFlashAttention3,
}
@@ -588,6 +668,7 @@ class CLIPPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_sdpa = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
def _init_weights(self, module):
"""Initialize the weights"""
@@ -857,6 +938,7 @@ def __init__(self, config: CLIPTextConfig):
# For attention mask, it differs between `flash_attention_2` and other attention implementations
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
@@ -894,7 +976,7 @@ def forward(
)
# expand attention_mask
- if attention_mask is not None and not self._use_flash_attention_2:
+ if attention_mask is not None and not self._use_flash_attention_2 and not self._use_flash_attention_3:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
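Note on the CLIP change above: for both flash-attention implementations the 2D padding mask is passed through as-is instead of being expanded to 4D. A minimal standalone sketch of that branch, assuming only that `_prepare_4d_attention_mask` is the real helper from `transformers` (the wrapper function is illustrative):

```py
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

def expand_mask_if_needed(attention_mask, attn_implementation, dtype):
    # Flash-attention kernels (v2 and v3) consume the 2D padding mask directly,
    # so the [bsz, 1, tgt_len, src_len] expansion is only needed for eager/SDPA.
    if attention_mask is None or attn_implementation in ("flash_attention_2", "flash_attention_3"):
        return attention_mask
    return _prepare_4d_attention_mask(attention_mask, dtype)

mask = torch.tensor([[1, 1, 1, 0]])
print(expand_mask_if_needed(mask, "flash_attention_3", torch.float16).shape)  # torch.Size([1, 4])
print(expand_mask_if_needed(mask, "sdpa", torch.float16).shape)               # torch.Size([1, 1, 4, 4])
```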
diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py
index 46eea43e1285f8..2b3a953601a682 100644
--- a/src/transformers/models/codegen/modeling_codegen.py
+++ b/src/transformers/models/codegen/modeling_codegen.py
@@ -633,7 +633,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
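The `_update_causal_mask` tweak repeated across the decoder-only models applies the same rule to FA2 and FA3: if the batch contains no padding, no mask is materialized at all and causality is left to the kernel. A self-contained sketch of that branch:

```py
from typing import Optional

import torch

def flash_attn_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Padded batch: keep the 2D mask so the kernel can unpad the sequences.
    if attention_mask is not None and 0.0 in attention_mask:
        return attention_mask
    # No padding anywhere: rely on is_causal inside the flash-attention kernel.
    return None

print(flash_attn_mask(torch.ones(2, 8)))                 # None
print(flash_attn_mask(torch.tensor([[1.0, 1.0, 0.0]])))  # tensor([[1., 1., 0.]])
```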
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index ae84a9ec2d1a43..c8cb6103f86f39 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -44,6 +44,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
@@ -55,6 +56,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -539,6 +543,118 @@ def forward(
return attn_output, attn_weights, past_key_value
+class CohereFlashAttention3(CohereAttention):
+ """
+ Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim)
+ if self.use_qk_norm:
+ query_states = self.q_norm(query_states)
+ key_states = self.k_norm(key_states)
+
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (CohereLayerNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class CohereSdpaAttention(CohereAttention):
"""
Cohere attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -645,6 +761,7 @@ def forward(
COHERE_ATTENTION_CLASSES = {
"eager": CohereAttention,
"flash_attention_2": CohereFlashAttention2,
+ "flash_attention_3": CohereFlashAttention3,
"sdpa": CohereSdpaAttention,
}
@@ -751,6 +868,7 @@ class CoherePreTrainedModel(PreTrainedModel):
_no_split_modules = ["CohereDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -1005,7 +1123,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
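As with the FA2 classes, the new FA3 attention modules reject a `StaticCache` up front. A small sketch of that guard; `StaticCache` is the real class, the helper function is illustrative:

```py
from transformers import StaticCache  # real class; the helper below is illustrative

def check_cache_compatibility(past_key_value, attn_implementation: str = "flash_attention_3") -> None:
    # FA3 (like FA2) is incompatible with the pre-allocated static cache used for
    # torch.compile-friendly generation; callers should fall back to sdpa instead.
    if attn_implementation == "flash_attention_3" and isinstance(past_key_value, StaticCache):
        raise ValueError(
            "`static` cache implementation is not compatible with "
            "`attn_implementation==flash_attention_3`; use `sdpa` in the meantime."
        )

check_cache_compatibility(None)  # a dynamic cache (or no cache) passes silently
```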
diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py
index dd2a676b26c27f..8b45b788a9a253 100755
--- a/src/transformers/models/data2vec/modeling_data2vec_audio.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -40,6 +40,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_peft_available,
logging,
@@ -50,6 +51,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -608,6 +612,125 @@ def forward(
return attn_output, attn_weights, past_key_value
+class Data2VecAudioFlashAttention3(Data2VecAudioAttention):
+ """
+ Data2VecAudio flash attention module. This module inherits from `Data2VecAudioAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # Data2VecAudioFlashAttention3 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("Data2VecAudioFlashAttention3 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2)
+ value_states = past_key_value[1].transpose(1, 2)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class Data2VecAudioSdpaAttention(Data2VecAudioAttention):
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Data2VecAudio
def forward(
@@ -719,6 +842,7 @@ def forward(
"eager": Data2VecAudioAttention,
"sdpa": Data2VecAudioSdpaAttention,
"flash_attention_2": Data2VecAudioFlashAttention2,
+ "flash_attention_3": Data2VecAudioFlashAttention3,
}
@@ -794,6 +918,7 @@ def __init__(self, config):
self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -810,7 +935,7 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -931,6 +1056,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
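For the audio encoder above, padded frames are zeroed before the layers, and for either flash-attention implementation the 2D mask is only kept when the batch actually contains padding. An illustrative sketch with toy shapes:

```py
import torch

hidden_states = torch.randn(2, 5, 8)
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]], dtype=torch.bool)

# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0

# the 2D mask is passed through the layers only if some position is actually masked
flash_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
print(flash_mask is not None)  # True: the second sample is padded
```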
diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index 43bac44ba1be20..215a6b0862f493 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -30,6 +30,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -40,6 +41,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DbrxConfig"
@@ -484,6 +488,117 @@ def forward(
return attn_output, attn_weights, past_key_value
+class DbrxFlashAttention3(DbrxAttention):
+ """Dbrx flash attention module.
+
+ This module inherits from `DbrxAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it
+ calls the public API of flash attention.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs: Any,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+ logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.")
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv_states = self.Wqkv(hidden_states)
+ if self.clip_qkv is not None:
+ qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv)
+
+ query_states, key_states, value_states = qkv_states.split(
+ [
+ self.hidden_size,
+ self.num_key_value_heads * self.head_dim,
+ self.num_key_value_heads * self.head_dim,
+ ],
+ dim=2,
+ )
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout
+ # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.Wqkv.weight.dtype
+
+ logger.warning_once(
+ "The input hidden states seems to be silently casted in float32, this might be "
+ + "related to the fact you have upcasted embedding or layer norm layers in "
+ + f"float32. We will cast back the input in {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class DbrxSdpaAttention(DbrxAttention):
"""
Dbrx attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -582,6 +697,7 @@ def forward(
DBRX_ATTENTION_CLASSES = {
"eager": DbrxAttention,
"flash_attention_2": DbrxFlashAttention2,
+ "flash_attention_3": DbrxFlashAttention3,
"sdpa": DbrxSdpaAttention,
}
@@ -882,6 +998,7 @@ class DbrxPreTrainedModel(PreTrainedModel):
_no_split_modules = ["DbrxBlock"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -1164,7 +1281,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
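The fp32 fallback preceding every `_flash_attention_3_forward` call picks a half-precision target in a fixed order: the autocast dtype, then the pre-quantization dtype, then the projection weight's dtype. A condensed sketch of that chain (`pick_target_dtype` and `DummyConfig` are illustrative names):

```py
import torch

class DummyConfig:
    """Stand-in for a PretrainedConfig; may carry `_pre_quantization_dtype` when quantized."""

def pick_target_dtype(input_dtype: torch.dtype, config, weight_dtype: torch.dtype) -> torch.dtype:
    if input_dtype != torch.float32:
        return input_dtype
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype  # quantized model
    return weight_dtype  # fall back to the projection weights

print(pick_target_dtype(torch.float32, DummyConfig(), torch.bfloat16))  # torch.bfloat16
```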
diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index e80e3c41d22cb6..a91cd4f8cb6ea1 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -44,6 +44,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -54,6 +55,8 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "distilbert-base-uncased"
@@ -329,6 +332,93 @@ def reshape(x: torch.Tensor) -> torch.Tensor:
return (attn_output,)
+class DistilBertFlashAttention3(MultiHeadSelfAttention):
+ """
+ DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module
+ stay untouched. The only required change would be on the forward pass where it needs to correctly call the public
+ API of flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ mask: torch.Tensor,
+ head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, ...]:
+ """
+ Parameters:
+ query: torch.tensor(bs, seq_length, dim)
+ key: torch.tensor(bs, seq_length, dim)
+ value: torch.tensor(bs, seq_length, dim)
+ mask: torch.tensor(bs, seq_length)
+
+ Returns:
+ weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs,
+ seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True`
+ """
+ batch_size, q_length, dim = query.size()
+
+ dim_per_head = self.dim // self.n_heads
+
+ def reshape(x: torch.Tensor) -> torch.Tensor:
+ """separate heads"""
+ return x.view(batch_size, -1, self.n_heads, dim_per_head)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ query_states = reshape(self.q_lin(query))
+ key_states = reshape(self.k_lin(key))
+ value_states = reshape(self.v_lin(value))
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ if query_states.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_lin.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_weights = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ mask,
+ q_length,
+ is_causal=self.is_causal,
+ )
+
+ attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head)
+ attn_output = self.out_lin(attn_weights_reshaped)
+
+ if output_attentions:
+ return (attn_output, attn_weights)
+ else:
+ return (attn_output,)
+
+
class FFN(nn.Module):
def __init__(self, config: PretrainedConfig):
super().__init__()
@@ -353,6 +443,7 @@ def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
DISTILBERT_ATTENTION_CLASSES = {
"eager": MultiHeadSelfAttention,
"flash_attention_2": DistilBertFlashAttention2,
+ "flash_attention_3": DistilBertFlashAttention3,
}
@@ -503,6 +594,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
base_model_prefix = "distilbert"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
def _init_weights(self, module: nn.Module):
"""Initialize the weights."""
@@ -589,6 +681,7 @@ def __init__(self, config: PretrainedConfig):
self.embeddings = Embeddings(config) # Embeddings
self.transformer = Transformer(config) # Encoder
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
# Initialize weights and apply final processing
self.post_init()
@@ -694,7 +787,7 @@ def forward(
embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim)
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
if attention_mask is None:
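Every touched modeling file gains the same guarded import so that environments without the FA3 kernels still import cleanly. A sketch of the pattern from outside the library; the `try/except` is only there because `is_flash_attn_3_available` is introduced by this PR and may not exist in an installed release:

```py
from transformers.utils import is_flash_attn_2_available

try:  # added by this PR; not present in older transformers releases
    from transformers.utils import is_flash_attn_3_available
except ImportError:
    def is_flash_attn_3_available() -> bool:
        return False

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
if is_flash_attn_3_available():
    from transformers.modeling_flash_attention_3_utils import _flash_attention_3_forward

print("FA2 available:", is_flash_attn_2_available(), "- FA3 available:", is_flash_attn_3_available())
```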
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 73e8806352ebbd..ac511ea145a58f 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -43,6 +43,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
@@ -55,6 +56,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -688,6 +692,110 @@ def forward(
return attn_output, layer_past, attn_weights
+class FalconFlashAttention3(FalconAttention):
+ """
+ Falcon flash attention module. This module inherits from `FalconAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ alibi: Optional[torch.Tensor],
+ attention_mask: torch.Tensor,
+ position_ids: Optional[torch.LongTensor] = None,
+ layer_past: Optional[Cache] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ use_cache: bool = False,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ):
+ fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
+ num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
+ # 3 x [batch_size, seq_length, num_heads, head_dim]
+ (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
+
+ batch_size, query_length, _, _ = query_layer.shape
+
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim)
+ key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim)
+
+ if alibi is None:
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_layer, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
+
+ if layer_past is not None:
+ cache_kwargs = {"cache_position": cache_position}
+ if alibi is None:
+ cache_kwargs.update({"sin": sin, "cos": cos})
+ key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_layer = query_layer.transpose(1, 2)
+ key_layer = key_layer.transpose(1, 2)
+ value_layer = value_layer.transpose(1, 2)
+
+ if alibi is not None:
+ raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_layer.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.query_key_value.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_layer = query_layer.to(target_dtype)
+ key_layer = key_layer.to(target_dtype)
+ value_layer = value_layer.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_layer,
+ key_layer,
+ value_layer,
+ attention_mask,
+ query_length,
+ position_ids=position_ids,
+ is_causal=self.is_causal,
+ )
+
+ attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
+ attn_output = self.dense(attn_weights)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, layer_past, attn_weights
+
+
class FalconMLP(nn.Module):
def __init__(self, config: FalconConfig):
super().__init__()
@@ -708,6 +816,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
"eager": FalconAttention,
"sdpa": FalconAttention, # FalconAttention originally implemented both a forward with & without SDPA
"flash_attention_2": FalconFlashAttention2,
+ "flash_attention_3": FalconFlashAttention3,
}
@@ -909,6 +1018,7 @@ class FalconPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["FalconDecoderLayer"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -971,6 +1081,7 @@ def __init__(self, config: FalconConfig):
# Transformer blocks
self.h = nn.ModuleList([FalconDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
# Final Layer Norm
@@ -1157,7 +1268,10 @@ def _update_causal_mask(
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
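Falcon keeps the FA2 restriction that alibi-based checkpoints cannot use the flash path, so only RoPE configurations should opt into FA3. A toy selector illustrating the constraint; the preference order is an assumption for illustration, not library behaviour:

```py
from transformers import FalconConfig

def pick_attn_implementation(config: FalconConfig) -> str:
    # alibi biases are not supported by the flash-attention kernels,
    # so those checkpoints must stay on eager/sdpa.
    return "sdpa" if config.alibi else "flash_attention_3"

print(pick_attn_implementation(FalconConfig(alibi=True)))   # sdpa
print(pick_attn_implementation(FalconConfig(alibi=False)))  # flash_attention_3
```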
diff --git a/src/transformers/models/gemma/diff_gemma.py b/src/transformers/models/gemma/diff_gemma.py
index fcdb0f0b3d7d8f..53b656282eb46a 100644
--- a/src/transformers/models/gemma/diff_gemma.py
+++ b/src/transformers/models/gemma/diff_gemma.py
@@ -34,6 +34,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
+from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import CausalLMOutputWithPast
from ...pytorch_utils import ALL_LAYERNORM_LAYERS
@@ -441,6 +442,103 @@ def forward(
return attn_output, attn_weights, past_key_value
+# TODO felix: does this inheritance really work out in the end to GemmaFlashAttention3 inheriting from GemmaAttention?
+class GemmaFlashAttention3(LlamaFlashAttention2):
+ """
+ Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (GemmaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class GemmaModel(LlamaModel):
def forward(
self,
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index b14e0a4b3d8ca5..b327f0dd31f2da 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -30,6 +30,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
@@ -467,6 +468,106 @@ def forward(
return attn_output, attn_weights, past_key_value
+class GemmaFlashAttention3(GemmaAttention):
+ """
+ Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (GemmaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class GemmaSdpaAttention(GemmaAttention):
"""
Gemma attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -558,6 +659,7 @@ def forward(
GEMMA_ATTENTION_CLASSES = {
"eager": GemmaAttention,
"flash_attention_2": GemmaFlashAttention2,
+ "flash_attention_3": GemmaFlashAttention3,
"sdpa": GemmaSdpaAttention,
}
@@ -665,6 +767,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GemmaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -932,7 +1035,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
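Registering the class in `GEMMA_ATTENTION_CLASSES` is what makes `attn_implementation="flash_attention_3"` resolve to the new module at load time. A toy version of that lookup, with strings standing in for the real classes:

```py
ATTENTION_CLASSES = {
    "eager": "GemmaAttention",
    "flash_attention_2": "GemmaFlashAttention2",
    "flash_attention_3": "GemmaFlashAttention3",
    "sdpa": "GemmaSdpaAttention",
}

def resolve_attention_class(attn_implementation: str) -> str:
    # config._attn_implementation is validated upstream against flags like _supports_flash_attn_3
    return ATTENTION_CLASSES[attn_implementation]

print(resolve_attention_class("flash_attention_3"))  # GemmaFlashAttention3
```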
diff --git a/src/transformers/models/gemma2/diff_gemma2.py b/src/transformers/models/gemma2/diff_gemma2.py
index 30f371a1b61267..235e43c1ee0539 100644
--- a/src/transformers/models/gemma2/diff_gemma2.py
+++ b/src/transformers/models/gemma2/diff_gemma2.py
@@ -34,12 +34,15 @@
from ...cache_utils import Cache
from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging
+from ...utils import is_flash_attn_2_available, is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -177,6 +180,101 @@ def forward(
return attn_output, attn_weights, past_key_value
+class Gemma2FlashAttention3(Gemma2Attention):
+ """
+ Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (Gemma2RMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ ########### ONLY DIFFERENCE IS WE PASS THE SOFTMAX SCALING
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ softmax_scale=self.scaling,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class Gemma2SdpaAttention(Gemma2Attention):
"""
Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -439,7 +537,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 1909ef78501559..5665ce23548f10 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -39,6 +39,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
@@ -51,6 +52,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -442,6 +446,110 @@ def forward(
return attn_output, attn_weights, past_key_value
+class Gemma2FlashAttention3(Gemma2Attention):
+ """
+ Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "sliding_window": self.sliding_window,
+ "cache_position": cache_position,
+ }
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ if attention_mask is not None:
+ seq_len = attention_mask.shape[1]
+ key_states = key_states[:, :, :seq_len]
+ value_states = value_states[:, :, :seq_len]
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states gets silently casted in float32. Hence, we need
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (Gemma2RMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ softmax_scale=self.scaling,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class Gemma2SdpaAttention(Gemma2Attention):
"""
Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -535,6 +643,7 @@ def forward(
GEMMA2_ATTENTION_CLASSES = {
"eager": Gemma2Attention,
"flash_attention_2": Gemma2FlashAttention2,
+ "flash_attention_3": Gemma2FlashAttention3,
"sdpa": Gemma2SdpaAttention,
}
@@ -568,7 +677,10 @@ def forward(
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# Flash-attn is a 2D tensor
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if past_key_value is not None: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
else:
@@ -641,6 +753,7 @@ class Gemma2PreTrainedModel(PreTrainedModel):
_no_split_modules = ["Gemma2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = False
@@ -919,7 +1032,10 @@ def _update_causal_mask(
# So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape
# to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
# as it doesn't cause dynamic control issues.
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
return attention_mask
dtype, device = input_tensor.dtype, input_tensor.device
@@ -1130,6 +1246,7 @@ def prepare_inputs_for_generation(
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
and not self.config._attn_implementation == "flash_attention_2"
+ and not self.config._attn_implementation == "flash_attention_3"
):
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 8dfbfb9064444d..91e75a1d6f4ede 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -45,6 +45,7 @@
add_start_docstrings_to_model_forward,
get_torch_version,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -56,6 +57,8 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
logger = logging.get_logger(__name__)
@@ -454,6 +457,109 @@ def forward(
return outputs
+class GPT2FlashAttention3(GPT2Attention):
+ """
+ GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: Optional[Tuple[torch.FloatTensor]],
+ layer_past: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
+ bsz, _, _ = hidden_states.size()
+ if encoder_hidden_states is not None:
+ if not hasattr(self, "q_attn"):
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
+ )
+
+ query = self.q_attn(hidden_states)
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+ attention_mask = encoder_attention_mask
+ else:
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
+
+ query = self._split_heads(query, self.num_heads, self.head_dim)
+ key = self._split_heads(key, self.num_heads, self.head_dim)
+ value = self._split_heads(value, self.num_heads, self.head_dim)
+
+ if layer_past is not None:
+ past_key = layer_past[0]
+ past_value = layer_past[1]
+ key = torch.cat((past_key, key), dim=-2)
+ value = torch.cat((past_value, value), dim=-2)
+
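+ # Keep the updated key/value states in the legacy tuple cache format so they can be reused at the next decoding step.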
+ present = None
+ if use_cache is True:
+ present = (key, value)
+
+ query_length = query.shape[2]
+ tgt_len = key.shape[2]
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim)
+ key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
+ value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability,
+ # so the input hidden states may have been silently cast to float32. We therefore
+ # cast them back to the correct dtype, just to be sure everything works as expected.
+ # This might slow down training and inference, so it is recommended not to cast the
+ # LayerNorms to fp32.
+
+ if query.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.c_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query = query.to(target_dtype)
+ key = key.to(target_dtype)
+ value = value.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query,
+ key,
+ value,
+ attention_mask,
+ query_length,
+ is_causal=self.is_causal,
+ )
+
+ attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
+ attn_output = self.c_proj(attn_weights_reshaped)
+ attn_output = self.resid_dropout(attn_output)
+
+ outputs = (attn_output, present)
+ if output_attentions:
+ outputs += (attn_weights_reshaped,)
+
+ return outputs
+
+
class GPT2SdpaAttention(GPT2Attention):
"""
GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -578,7 +684,12 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
return hidden_states
-GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention}
+GPT2_ATTENTION_CLASSES = {
+ "eager": GPT2Attention,
+ "flash_attention_2": GPT2FlashAttention2,
+ "flash_attention_3": GPT2FlashAttention3,
+ "sdpa": GPT2SdpaAttention,
+}
class GPT2Block(nn.Module):
@@ -674,6 +785,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPT2Block"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def __init__(self, *inputs, **kwargs):
@@ -1031,7 +1143,7 @@ def forward(
# Attention mask.
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
- if self._attn_implementation == "flash_attention_2":
+ if self._attn_implementation == "flash_attention_2" or self._attn_implementation == "flash_attention_3":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif _use_sdpa:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
@@ -1068,7 +1180,10 @@ def forward(
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
)
- elif not self._attn_implementation == "flash_attention_2":
+ elif (
+ not self._attn_implementation == "flash_attention_2"
+ and not self._attn_implementation == "flash_attention_3"
+ ):
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index 0f927a72469dc9..8233a2a459e4c1 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -36,6 +36,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
@@ -45,6 +46,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -396,6 +400,122 @@ def forward(
return outputs # a, present, (attentions)
+class GPTBigCodeFlashAttention3(GPTBigCodeAttention):
+ """
+ GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module
+ stays untouched. The only required change would be on the forward pass where it needs to correctly call the public
+ API of flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ layer_past: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ ) -> Union[
+ Tuple[torch.Tensor, Optional[torch.Tensor]],
+ Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]],
+ ]:
+ if encoder_hidden_states is not None:
+ if not hasattr(self, "q_attn") or not self.is_cross_attention:
+ raise ValueError(
+ "If class is used as cross attention, the weights `q_attn` have to be defined. "
+ "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`."
+ )
+
+ query = self.q_attn(hidden_states)
+ key_value = self.c_attn(encoder_hidden_states)
+ attention_mask = encoder_attention_mask
+ elif self.multi_query:
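+ # Multi-query attention: the fused c_attn projection yields one query per head but a single shared key/value pair.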
+ query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2)
+ else:
+ # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim),
+ # i.e., the memory layout is not the same as GPT2.
+ # This makes the concatenation with past_key_value more efficient.
+ query, key_value = (
+ self.c_attn(hidden_states)
+ .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim)
+ .transpose(1, 2)
+ .split((self.head_dim, 2 * self.head_dim), dim=3)
+ )
+
+ if layer_past is not None:
+ key_value = torch.cat((layer_past, key_value), dim=-2)
+ present = key_value if use_cache else None
+
+ key, value = key_value.split((self.head_dim, self.head_dim), dim=-1)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ if self.multi_query:
+ batch_size, query_length, _ = query.shape
+ query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim)
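+ # There is a single key/value head, so insert a head axis of size 1 and let the flash attention kernel broadcast it across all query heads.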
+ key = key.unsqueeze(2)
+ value = value.unsqueeze(2)
+ else:
+ query_length = query.shape[2]
+ batch_size, _, tgt, _ = key.shape
+ query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim)
+ key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
+ value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim)
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability,
+ # so the input hidden states may have been silently cast to float32. We therefore
+ # cast them back to the correct dtype, just to be sure everything works as expected.
+ input_dtype = query.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.c_attn.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+ query = query.to(target_dtype)
+ key = key.to(target_dtype)
+ value = value.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query,
+ key,
+ value,
+ attention_mask,
+ query_length,
+ is_causal=self.is_causal,
+ )
+
+ attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
+ attn_output = self.c_proj(attn_weights_reshaped)
+ attn_output = self.resid_dropout(attn_output)
+
+ outputs = (attn_output, present)
+
+ if output_attentions:
+ if self.multi_query:
+ # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length)
+ attn_weights_reshaped = attn_weights_reshaped.transpose(1, 2)
+ else:
+ attn_weights_reshaped = None
+
+ outputs += (attn_weights_reshaped,)
+
+ return outputs # a, present, (attentions)
+
+
class GPTBigCodeSdpaAttention(GPTBigCodeAttention):
def _attn(self, query, key, value, attention_mask=None, head_mask=None):
if head_mask is not None:
@@ -561,6 +681,7 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
GPTBIGCODE_ATTENTION_CLASSES = {
"eager": GPTBigCodeAttention,
"flash_attention_2": GPTBigCodeFlashAttention2,
+ "flash_attention_3": GPTBigCodeFlashAttention3,
"sdpa": GPTBigCodeSdpaAttention,
}
@@ -666,6 +787,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTBigCodeBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def __init__(self, *inputs, **kwargs):
@@ -810,6 +932,7 @@ def __init__(self, config):
self._use_sdpa = config._attn_implementation == "sdpa"
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
# Initialize weights and apply final processing
self.post_init()
@@ -891,7 +1014,7 @@ def forward(
key_length = past_length + query_length
self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None
encoder_attention_mask = (
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 72590862b749f0..571d15ae4f3cd1 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -41,6 +41,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_fx_available,
logging,
@@ -51,6 +52,8 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
@@ -426,9 +429,102 @@ def forward(
return outputs
+class GPTNeoFlashAttention3(GPTNeoSelfAttention):
+ """
+ GPTNeo flash attention module. This module inherits from `GPTNeoSelfAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ layer_past=None,
+ head_mask=None,
+ use_cache=False,
+ output_attentions=False,
+ cache_position=None,
+ ):
+ bsz, _, _ = hidden_states.size()
+
+ query = self.q_proj(hidden_states)
+ key = self.k_proj(hidden_states)
+ value = self.v_proj(hidden_states)
+
+ query = self._split_heads(query, self.num_heads, self.head_dim)
+ key = self._split_heads(key, self.num_heads, self.head_dim)
+ value = self._split_heads(value, self.num_heads, self.head_dim)
+
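+ # If a cache is provided, append the new key/value states for this layer and use the full cached tensors.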
+ if layer_past is not None:
+ cache_kwargs = {"cache_position": cache_position}
+ key, value = layer_past.update(key, value, self.layer_id, cache_kwargs)
+
+ query_length = query.shape[2]
+ tgt_len = key.shape[2]
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim)
+ key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
+ value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
+
+ if attention_mask is not None: # no matter the length, we just slice it
+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability,
+ # so the input hidden states may have been silently cast to float32. We therefore
+ # cast them back to the correct dtype, just to be sure everything works as expected.
+ # This might slow down training and inference, so it is recommended not to cast the
+ # LayerNorms to fp32.
+
+ if query.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query = query.to(target_dtype)
+ key = key.to(target_dtype)
+ value = value.to(target_dtype)
+
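+ # GPT-Neo does not scale the attention scores, hence the explicit softmax_scale of 1.0 below.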
+ attn_output = _flash_attention_3_forward(
+ query,
+ key,
+ value,
+ attention_mask,
+ query_length,
+ softmax_scale=1.0,
+ is_causal=self.is_causal,
+ )
+
+ attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
+ attn_output = self.out_proj(attn_weights_reshaped)
+ attn_output = self.resid_dropout(attn_output)
+
+ outputs = (attn_output, layer_past)
+ if output_attentions:
+ outputs += (attn_weights_reshaped,)
+
+ return outputs
+
+
GPT_NEO_ATTENTION_CLASSES = {
"eager": GPTNeoSelfAttention,
"flash_attention_2": GPTNeoFlashAttention2,
+ "flash_attention_3": GPTNeoFlashAttention3,
}
@@ -550,6 +646,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTNeoBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = False # TODO: needs a HybridCache
@@ -847,7 +944,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index e88302efa7bb04..d935ac65a3a722 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -43,6 +43,7 @@
from ...utils import (
get_torch_version,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
@@ -52,6 +53,10 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM"
@@ -125,6 +130,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTNeoXLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
@@ -459,6 +465,100 @@ def forward(
return outputs
+class GPTNeoXFlashAttention3(GPTNeoXAttention):
+ """
+ GPTNeoX flash attention module. This module inherits from `GPTNeoXAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ attention_mask: torch.FloatTensor,
+ position_ids: torch.LongTensor,
+ head_mask: Optional[torch.FloatTensor] = None,
+ layer_past: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ):
+ # Apply attention-specific projections and rope
+ query, key, value, present = self._attn_projections_and_rope(
+ hidden_states=hidden_states,
+ position_ids=position_ids,
+ layer_past=layer_past,
+ use_cache=use_cache,
+ cache_position=cache_position,
+ position_embeddings=position_embeddings,
+ )
+
+ query_length = query.shape[-2]
+
+ # GPT-NeoX casts query and key to fp32 to apply the rotary embedding in full precision
+ target_dtype = value.dtype
+ if query.dtype != target_dtype:
+ query = query.to(target_dtype)
+ if key.dtype != target_dtype:
+ key = key.to(target_dtype)
+
+ # Permute to get the expected shape for Flash Attention
+ query = query.permute(0, 2, 1, 3)
+ key = key.permute(0, 2, 1, 3)
+ value = value.permute(0, 2, 1, 3)
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability,
+ # so the input hidden states may have been silently cast to float32. We therefore
+ # cast them back to float16 / bfloat16, just to be sure everything works as expected.
+ # This might slow down training and inference, so it is recommended not to cast the LayerNorms to fp32.
+ input_dtype = query.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.query_key_value.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query = query.to(target_dtype)
+ key = key.to(target_dtype)
+ value = value.to(target_dtype)
+
+ # Compute attention
+ attn_weights = _flash_attention_3_forward(
+ query,
+ key,
+ value,
+ attention_mask,
+ query_length,
+ softmax_scale=self.norm_factor,
+ is_causal=self.is_causal,
+ )
+
+ # Reshape outputs
+ attn_output = attn_weights.reshape(
+ attn_weights.shape[0], attn_weights.shape[1], self.num_attention_heads * self.head_size
+ )
+ attn_output = self.dense(attn_output)
+
+ outputs = (attn_output, layer_past)
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
class GPTNeoXSdpaAttention(GPTNeoXAttention):
"""
GPTNeoX attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -727,6 +827,7 @@ def forward(self, hidden_states):
GPT_NEOX_ATTENTION_CLASSES = {
"eager": GPTNeoXAttention,
"flash_attention_2": GPTNeoXFlashAttention2,
+ "flash_attention_3": GPTNeoXFlashAttention3,
"sdpa": GPTNeoXSdpaAttention,
}
@@ -1044,7 +1145,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
index bf832195b4efc3..801a6f0d778141 100755
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -748,7 +748,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index bd7ce5696fa077..7a4f69e41d9bba 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -38,6 +38,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_torch_fx_proxy,
logging,
@@ -50,6 +51,10 @@
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "hf-internal-testing/tiny-random-gptj"
@@ -457,9 +462,145 @@ def forward(
return outputs
+class GPTJFlashAttention3(GPTJAttention):
+ """
+ GPTJ flash attention module. This module inherits from `GPTJAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ layer_past: Optional[Cache] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[
+ Tuple[torch.Tensor, Tuple[torch.Tensor]],
+ Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
+ ]:
+ query = self.q_proj(hidden_states)
+ key = self.k_proj(hidden_states)
+ value = self.v_proj(hidden_states)
+
+ query = self._split_heads(query, self.num_attention_heads, self.head_dim, True)
+ key = self._split_heads(key, self.num_attention_heads, self.head_dim, True)
+ value = self._split_heads(value, self.num_attention_heads, self.head_dim, False)
+
+ if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing():
+ # The logic to conditionally copy to GPU could not be traced, so we do this
+ # every time in the torch.fx case
+ embed_positions = get_embed_positions(self.embed_positions, position_ids)
+ else:
+ embed_positions = self._get_embed_positions(position_ids)
+
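+ # Gather the rotary embedding entries for each position, then split them into their sin and cos halves.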
+ repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1])
+ sincos = torch.gather(embed_positions, 1, repeated_position_ids)
+ sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
+
+ if self.rotary_dim is not None:
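+ # Rotary position embeddings are only applied to the first `rotary_dim` channels; the remaining channels pass through unchanged.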
+ k_rot = key[:, :, :, : self.rotary_dim]
+ k_pass = key[:, :, :, self.rotary_dim :]
+
+ q_rot = query[:, :, :, : self.rotary_dim]
+ q_pass = query[:, :, :, self.rotary_dim :]
+
+ k_rot = apply_rotary_pos_emb(k_rot, sin, cos)
+ q_rot = apply_rotary_pos_emb(q_rot, sin, cos)
+
+ key = torch.cat([k_rot, k_pass], dim=-1)
+ query = torch.cat([q_rot, q_pass], dim=-1)
+ else:
+ key = apply_rotary_pos_emb(key, sin, cos)
+ query = apply_rotary_pos_emb(query, sin, cos)
+
+ # transpose to have the desired shape
+ # before transpose: batch_size x seq_length x num_attention_heads x head_dim
+ # after transpose: batch_size x num_attention_heads x seq_length x head_dim
+ key = key.permute(0, 2, 1, 3)
+ query = query.permute(0, 2, 1, 3)
+ # value: batch_size x num_attention_heads x seq_length x head_dim
+
+ if layer_past is not None:
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "partial_rotation_size": self.rotary_dim,
+ "cache_position": cache_position,
+ }
+ key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_attention_heads x head_dim,
+ # so permute query, key and value back to that layout.
+ key = key.permute(0, 2, 1, 3).contiguous()
+ query = query.permute(0, 2, 1, 3).contiguous()
+ value = value.permute(0, 2, 1, 3).contiguous()
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability,
+ # so the input hidden states may have been silently cast to float32. We therefore
+ # cast them back to the correct dtype, just to be sure everything works as expected.
+ # This might slow down training and inference, so it is recommended not to cast the
+ # LayerNorms to fp32.
+
+ input_dtype = query.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query = query.to(target_dtype)
+ key = key.to(target_dtype)
+ value = value.to(target_dtype)
+
+ query_length = query.shape[1]
+
+ # Compute attention
+ attn_weights = _flash_attention_3_forward(
+ query,
+ key,
+ value,
+ attention_mask,
+ query_length,
+ is_causal=self.is_causal,
+ )
+
+ # Reshape outputs
+ attn_output = attn_weights.reshape(
+ attn_weights.shape[0], attn_weights.shape[1], attn_weights.shape[2] * attn_weights.shape[3]
+ )
+ attn_output = self.out_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+
+ outputs = (attn_output, layer_past)
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
+
+
GPTJ_ATTENTION_CLASSES = {
"eager": GPTJAttention,
"flash_attention_2": GPTJFlashAttention2,
+ "flash_attention_3": GPTJFlashAttention3,
}
@@ -540,6 +681,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTJBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
@@ -720,6 +862,7 @@ def __init__(self, config):
self.post_init()
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
@@ -942,7 +1085,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py
index 876f5ed2a7c8da..d8929470a912f6 100644
--- a/src/transformers/models/granite/modeling_granite.py
+++ b/src/transformers/models/granite/modeling_granite.py
@@ -23,6 +23,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
@@ -453,6 +454,102 @@ def forward(
return attn_output, attn_weights, past_key_value
+class GraniteFlashAttention3(GraniteAttention):
+ """
+ Granite flash attention module. This module inherits from `GraniteAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the layout batch_size x seq_length x num_heads x head_dim,
+ # but RoPE and the KV cache expect batch_size x num_heads x seq_length x head_dim,
+ # so we transpose here and transpose back just before calling flash attention.
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
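+ # Rotary (cos, sin) tensors are precomputed in the model forward and passed down through `position_embeddings`.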
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view operations.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability,
+ # so the input hidden states may have been silently cast to float32. We therefore
+ # cast them back to the correct dtype, just to be sure everything works as expected.
+ # This might slow down training and inference, so it is recommended not to cast the
+ # LayerNorms to fp32. (GraniteRMSNorm handles it correctly.)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ softmax_scale=self.scaling,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class GraniteSdpaAttention(GraniteAttention):
"""
Granite attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -547,6 +644,7 @@ def forward(
GRANITE_ATTENTION_CLASSES = {
"eager": GraniteAttention,
"flash_attention_2": GraniteFlashAttention2,
+ "flash_attention_3": GraniteFlashAttention3,
"sdpa": GraniteSdpaAttention,
}
@@ -662,6 +760,7 @@ class GranitePreTrainedModel(PreTrainedModel):
_no_split_modules = ["GraniteDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -934,7 +1033,10 @@ def _update_causal_mask(
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py
index da79c2894877b4..70ec00e809635d 100755
--- a/src/transformers/models/hubert/modeling_hubert.py
+++ b/src/transformers/models/hubert/modeling_hubert.py
@@ -32,6 +32,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -42,6 +43,8 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
logger = logging.get_logger(__name__)
@@ -678,6 +681,125 @@ def forward(
return attn_output, attn_weights, past_key_value
+class HubertFlashAttention3(HubertAttention):
+ """
+ Hubert flash attention module. This module inherits from `HubertAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # HubertFlashAttention3 does not support output_attentions
+ if output_attentions:
+ raise ValueError("HubertFlashAttention3 does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
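+ # Cached key/value states are stored as (batch, num_heads, seq_len, head_dim); transpose them to the (batch, seq_len, num_heads, head_dim) layout flash attention expects.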
+ key_states = past_key_value[0].transpose(1, 2)
+ value_states = past_key_value[1].transpose(1, 2)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability,
+ # so the input hidden states may have been silently cast to float32. We therefore
+ # cast them back to the correct dtype, just to be sure everything works as expected.
+ # This might slow down training and inference, so it is recommended not to cast the
+ # LayerNorms to fp32.
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class HubertSdpaAttention(HubertAttention):
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Hubert
def forward(
@@ -789,6 +911,7 @@ def forward(
"eager": HubertAttention,
"sdpa": HubertSdpaAttention,
"flash_attention_2": HubertFlashAttention2,
+ "flash_attention_3": HubertFlashAttention3,
}
@@ -936,6 +1059,7 @@ def __init__(self, config):
self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -952,7 +1076,7 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1024,6 +1148,7 @@ def __init__(self, config):
)
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -1040,7 +1165,7 @@ def forward(
# make sure padded tokens are not attended to
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1112,6 +1237,7 @@ class HubertPreTrainedModel(PreTrainedModel):
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index 0d26c4fc65de78..714dfbc29053ac 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -1475,7 +1475,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py
index 6108f0e8a42e8f..1f529bc3bac866 100644
--- a/src/transformers/models/idefics2/modeling_idefics2.py
+++ b/src/transformers/models/idefics2/modeling_idefics2.py
@@ -32,6 +32,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
@@ -44,6 +45,8 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
logger = logging.get_logger(__name__)
@@ -363,9 +366,98 @@ def forward(
return attn_output, attn_weights
+class Idefics2VisionFlashAttention3(Idefics2VisionAttention):
+ """
+ Idefics2Vision flash attention module. This module inherits from `Idefics2VisionAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+ key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view operations.
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability,
+ # so the input hidden states may have been silently cast to float32. We therefore
+ # cast them back to the correct dtype, just to be sure everything works as expected.
+ # This might slow down training and inference, so it is recommended not to cast the
+ # LayerNorms to fp32. (Idefics2VisionRMSNorm handles it correctly.)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+
IDEFICS_VISION_ATTENTION_CLASSES = {
"eager": Idefics2VisionAttention,
"flash_attention_2": Idefics2VisionFlashAttention2,
+ "flash_attention_3": Idefics2VisionFlashAttention3,
}
@@ -582,6 +674,7 @@ def __init__(self, config: Idefics2VisionConfig):
self.encoder = Idefics2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def get_input_embeddings(self):
return self.embeddings
@@ -623,7 +716,7 @@ def forward(
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
if not torch.any(~patch_attention_mask):
patch_attention_mask = None
- elif not self._use_flash_attention_2:
+ elif not self._use_flash_attention_2 and not self._use_flash_attention_3:
patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
@@ -909,9 +1002,125 @@ def forward(
return attn_output, attn_weights, past_key_value
+class Idefics2PerceiverFlashAttention3(Idefics2PerceiverAttention):
+ """
+ Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ latents: torch.Tensor,
+ context: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ bsz, q_len, _ = latents.size()
+ kv_seq_len = q_len + context.size()[1]
+
+ # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn!
+ # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents`
+ query_states = self.q_proj(latents)
+ key_states = self.k_proj(torch.cat([context, latents], dim=-2))
+ value_states = self.v_proj(torch.cat([context, latents], dim=-2))
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+ key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ if past_key_value is not None:
+ # Only slice the cache if the config has a `sliding_window` attribute
+ if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window:
+ slicing_tokens = kv_seq_len - self.config.sliding_window
+
+ past_key = past_key_value[0]
+ past_value = past_key_value[1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ "past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1,"
+ f" head_dim`), got {past_key.shape}"
+ )
+
+ past_key_value = (past_key, past_value)
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
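+ # Store the full key/value states as the cache for the next step when use_cache is enabled.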
+ past_key_value = (key_states, value_states) if use_cache else None
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability,
+ # so the input hidden states may have been silently cast to float32. We therefore
+ # cast them back to the correct dtype, just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
IDEFICS2_PERCEIVER_ATTENTION_CLASSES = {
"eager": Idefics2PerceiverAttention,
"flash_attention_2": Idefics2PerceiverFlashAttention2,
+ "flash_attention_3": Idefics2PerceiverFlashAttention3,
}
@@ -1010,6 +1219,7 @@ def __init__(self, config) -> None:
self.norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -1025,7 +1235,7 @@ def forward(
attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
attention_mask = (
_prepare_4d_attention_mask(attention_mask, latents.dtype, tgt_len=self.n_latents)
- if not self._use_flash_attention_2
+ if not self._use_flash_attention_2 and not self._use_flash_attention_3
else attention_mask
)
@@ -1093,6 +1303,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):
_no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
def _init_weights(self, module):
@@ -1227,6 +1438,7 @@ def __init__(self, config: Idefics2Config):
self.image_token_id = self.config.image_token_id
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self.post_init()
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 60e1670a3c2784..0725d4f058570a 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -48,6 +48,7 @@
from ...utils.import_utils import (
is_causal_conv1d_available,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_mamba_ssm_available,
is_torchdynamo_compiling,
@@ -58,6 +59,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
if is_mamba_ssm_available():
from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn
@@ -498,6 +502,119 @@ def forward(
return attn_output, attn_weights, past_key_value
+class JambaFlashAttention3(JambaAttention):
+ """
+ Jamba flash attention module. This module inherits from `JambaAttention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim,
+ # therefore we just need to keep the original shape for the query states
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
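+ # cache_position[-1] is the position of the last token in the full sequence; it is used below for the sliding-window check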
+ kv_seq_len = cache_position[-1]
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a `sliding_window` attribute
+ cache_has_contents = cache_position[0] > 0
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability reasons,
+ # so the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
# Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba
class JambaSdpaAttention(JambaAttention):
"""
@@ -584,6 +701,7 @@ def forward(
JAMBA_ATTENTION_CLASSES = {
"eager": JambaAttention,
"flash_attention_2": JambaFlashAttention2,
+ "flash_attention_3": JambaFlashAttention3,
"sdpa": JambaSdpaAttention,
}
@@ -1121,6 +1239,7 @@ class JambaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache
_is_stateful = True
@@ -1377,7 +1496,10 @@ def forward(
)
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index c273b021d73664..ab2b46e8e6cb93 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -36,6 +36,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
@@ -47,6 +48,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "jetmoe"
@@ -804,9 +808,116 @@ def forward(
return attn_output, attn_weights, past_key_value, router_logits
+class JetMoeFlashAttention3(JetMoeAttention):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: Optional[torch.FloatTensor],
+ attention_mask: Optional[torch.FloatTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ use_cache: Optional[bool] = False,
+ output_attentions: Optional[bool] = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[
+ Tuple[torch.Tensor, Tuple[torch.Tensor]],
+ Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]],
+ ]:
+ """
+ Forward pass of the JetMoeFlashAttention3 module.
+
+ Args:
+ hidden_states (Optional[torch.FloatTensor]): Input hidden states.
+ attention_mask (Optional[torch.FloatTensor]): Attention mask.
+ position_ids (Optional[torch.LongTensor]): Position indices of the input tokens.
+ past_key_value (Optional[Cache]): Cached key/value states from previous decoding steps.
+ use_cache (Optional[bool]): Whether to use cached states.
+ output_attentions (Optional[bool]): Whether to output attention weights.
+ cache_position (Optional[torch.LongTensor]): Position of the cache.
+
+ Returns:
+ Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[...]]]: Tuple containing outputs.
+ """
+ output_attentions = False
+ bsz, q_len, hidden_size = hidden_states.size()
+
+ # calculate query, key, values
+ query_states, router_logits, topo_info = self.experts.map(hidden_states)
+ key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads for top-k attention experts
+ key_states = key_states.repeat(1, self.top_k, 1, 1)
+ value_states = value_states.repeat(1, self.top_k, 1, 1)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability reasons,
+ # so the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.kv_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ ).to(input_dtype)
+
+ # output projection
+ attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size)
+ attn_output = self.experts.reduce(attn_output, topo_info)
+ attn_output = attn_output.view(bsz, q_len, hidden_size) # re-assemble all head outputs side by side
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value, router_logits
+
+
JETMOE_ATTENTION_CLASSES = {
"eager": JetMoeAttention,
"flash_attention_2": JetMoeFlashAttention2,
+ "flash_attention_3": JetMoeFlashAttention3,
"sdpa": JetMoeSdpaAttention,
}
@@ -879,6 +990,7 @@ class JetMoePreTrainedModel(PreTrainedModel):
_no_split_modules = ["JetMoeBlock"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
@@ -1048,7 +1160,11 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+ if (
+ attention_mask is not None
+ and (self._attn_implementation == "flash_attention_2" or self._attn_implementation == "flash_attention_3")
+ and use_cache
+ ):
batch_size = inputs_embeds.shape[0]
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
@@ -1136,7 +1252,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index c7017832b9324c..24b02ec4ea7697 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -29,6 +29,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
@@ -569,6 +570,117 @@ def forward(
return attn_output, attn_weights, past_key_value
+class LlamaFlashAttention3(LlamaAttention):
+ """
+ Llama flash attention module. This module inherits from `LlamaAttention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim,
+ # but the rotary embedding and cache update below expect batch_size x num_heads x seq_length x head_dim
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability reasons,
+ # so the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # TODO: get `use_fp8` to here, add attention_kwargs or something
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class LlamaSdpaAttention(LlamaAttention):
"""
Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -671,6 +783,7 @@ def forward(
LLAMA_ATTENTION_CLASSES = {
"eager": LlamaAttention,
"flash_attention_2": LlamaFlashAttention2,
+ "flash_attention_3": LlamaFlashAttention3,
"sdpa": LlamaSdpaAttention,
}
@@ -783,6 +896,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -1039,7 +1153,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index eb1c55341b0784..99682d7818fde1 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -125,6 +125,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
def _init_weights(self, module):
diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py
index c1d1ca8c276d7a..6d6bd7ae356ac3 100644
--- a/src/transformers/models/llava_next/modeling_llava_next.py
+++ b/src/transformers/models/llava_next/modeling_llava_next.py
@@ -228,6 +228,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlavaNextVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
def _init_weights(self, module):
diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
index 589bf346ceeb9e..9c3f296b38b0ec 100644
--- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -272,6 +272,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlavaNextVideoVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
def _init_weights(self, module):
diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
index 697ea84fea5040..7c7ad78c86dbc6 100644
--- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -233,6 +233,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlavaOnevisionVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
_supports_static_cache = False # Qwen2 doesn't but llava has no reasons to not support
_supports_quantized_cache = True
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index 23a855fff25672..057edcd61bcd3d 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -37,6 +37,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -47,6 +48,8 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
logger = logging.get_logger(__name__)
@@ -427,6 +430,92 @@ def forward(
return attn_output, None, past_key_value
+class M2M100FlashAttention3(M2M100Attention):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
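+ # unlike the eager path's `_shape`, this keeps the [batch, seq_len, num_heads, head_dim] layout that flash attention expects (no transpose)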
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2)
+ value_states = past_key_value[1].transpose(1, 2)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
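+ # the cache stores tensors as [batch, num_heads, seq_len, head_dim], so transpose back from the flash attention layout before saving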
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ softmax_scale=None,
+ is_causal=self.is_causal,
+ )
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, None, past_key_value
+
+
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100
class M2M100EncoderLayer(nn.Module):
def __init__(self, config: M2M100Config):
@@ -501,6 +590,7 @@ def forward(
M2M100_ATTENTION_CLASSES = {
"eager": M2M100Attention,
"flash_attention_2": M2M100FlashAttention2,
+ "flash_attention_3": M2M100FlashAttention3,
}
@@ -631,6 +721,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
def _init_weights(self, module):
std = self.config.init_std
@@ -804,6 +895,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] =
self.layers = nn.ModuleList([M2M100EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layer_norm = nn.LayerNorm(config.d_model)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -884,7 +976,7 @@ def forward(
# expand attention_mask
if attention_mask is not None:
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -980,6 +1072,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] =
)
self.layers = nn.ModuleList([M2M100DecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self.layer_norm = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
@@ -1090,7 +1183,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
combined_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1101,7 +1194,7 @@ def forward(
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
else:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
@@ -1232,7 +1325,7 @@ def __init__(self, config: M2M100Config):
self.encoder = M2M100Encoder(config, self.shared)
self.decoder = M2M100Decoder(config, self.shared)
- if config._attn_implementation == "flash_attention_2":
+ if config._attn_implementation == "flash_attention_2" or config._attn_implementation == "flash_attention_3":
logger.warning_once(
"Attention with Flash Attention 2 does not support `layer_head_mask`. If you need this feature, please use standard attention."
)
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 9455f21b2073ff..7597ea1ad693f8 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -46,6 +46,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -56,6 +57,8 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
logger = logging.get_logger(__name__)
@@ -410,6 +413,125 @@ def forward(
return attn_output, attn_weights, past_key_value
+class MBartFlashAttention3(MBartAttention):
+ """
+ MBart flash attention module. This module inherits from `MBartAttention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # MBartFlashAttention3 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("MBartFlashAttention3 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2)
+ value_states = past_key_value[1].transpose(1, 2)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability reasons,
+ # so the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MBart
class MBartSdpaAttention(MBartAttention):
def forward(
@@ -521,6 +643,7 @@ def forward(
"eager": MBartAttention,
"sdpa": MBartSdpaAttention,
"flash_attention_2": MBartFlashAttention2,
+ "flash_attention_3": MBartFlashAttention3,
}
@@ -745,6 +868,7 @@ class MBartPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
@@ -1043,7 +1167,10 @@ def forward(
# expand attention_mask
if attention_mask is not None:
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
attention_mask = attention_mask if 0 in attention_mask else None
elif self.config._attn_implementation == "sdpa" and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
@@ -1260,7 +1387,10 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self.config._attn_implementation == "sdpa" and not output_attentions and cross_attn_head_mask is None:
@@ -1280,7 +1410,10 @@ def forward(
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self.config._attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index ffe16b27203301..c14438329827f5 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -41,6 +41,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
@@ -52,6 +53,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MistralConfig"
@@ -402,6 +406,131 @@ def forward(
return attn_output, attn_weights, past_key_value
+class MistralFlashAttention3(MistralAttention):
+ """
+ Mistral flash attention module. This module inherits from `MistralAttention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ):
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
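+ # cache_position[0] equals the number of tokens already stored in the cache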
+ if past_key_value is not None:
+ kv_seq_len += cache_position[0]
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability reasons,
+ # so the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral
# TODO(joao): add me back asap :)
class MistralSdpaAttention(MistralAttention):
@@ -495,6 +624,7 @@ def forward(
MISTRAL_ATTENTION_CLASSES = {
"eager": MistralAttention,
"flash_attention_2": MistralFlashAttention2,
+ "flash_attention_3": MistralFlashAttention3,
"sdpa": MistralSdpaAttention,
}
@@ -604,6 +734,7 @@ class MistralPreTrainedModel(PreTrainedModel):
_no_split_modules = ["MistralDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
@@ -853,7 +984,7 @@ def _update_causal_mask(
use_cache: bool,
output_attentions: bool,
):
- if self._attn_implementation == "flash_attention_2":
+ if self._attn_implementation == "flash_attention_2" or self._attn_implementation == "flash_attention_3":
if attention_mask is not None and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index c7062e75b1085c..167054fbe61bd6 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -43,6 +43,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_torchdynamo_compiling,
logging,
replace_return_docstrings,
@@ -54,6 +55,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
@@ -561,6 +565,133 @@ def forward(
return attn_output, attn_weights, past_key_value
+# TODO @longjie no longer copied from Mistral after static cache
+class MixtralFlashAttention3(MixtralAttention):
+ """
+ Mixtral flash attention module. This module inherits from `MixtralAttention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
+ rotary_seq_len = (
+ max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
+ )
+
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
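+ # slicing_tokens is negative, so the slice below keeps only the most recent `sliding_window - 1` cached tokens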
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # In PEFT, we usually cast the layer norms to float32 for training stability reasons,
+ # so the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back to float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
# copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral
# TODO @longjie no longer copied from Mistral after static cache
class MixtralSdpaAttention(MixtralAttention):
@@ -656,6 +787,7 @@ def forward(
MIXTRAL_ATTENTION_CLASSES = {
"eager": MixtralAttention,
"flash_attention_2": MixtralFlashAttention2,
+ "flash_attention_3": MixtralFlashAttention3,
"sdpa": MixtralSdpaAttention,
}
@@ -857,6 +989,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
_no_split_modules = ["MixtralDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
@@ -1122,7 +1255,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index f720faac038e51..26cf599a627033 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -47,6 +47,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -59,6 +60,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
if TYPE_CHECKING:
from ...generation.streamers import BaseStreamer
@@ -439,6 +443,125 @@ def forward(
return attn_output, attn_weights, past_key_value
+class MusicgenFlashAttention3(MusicgenAttention):
+ """
+ Musicgen flash attention module. This module inherits from `MusicgenAttention`, as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # MusicgenFlashAttention3 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("MusicgenFlashAttention3 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2)
+ value_states = past_key_value[1].transpose(1, 2)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states may be silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the
+ # LayerNorms to fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class MusicgenSdpaAttention(MusicgenAttention):
def forward(
self,
@@ -566,6 +689,7 @@ def forward(
"eager": MusicgenAttention,
"sdpa": MusicgenSdpaAttention,
"flash_attention_2": MusicgenFlashAttention2,
+ "flash_attention_3": MusicgenFlashAttention3,
}
@@ -703,6 +827,7 @@ class MusicgenPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
@@ -998,7 +1123,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
- if self.attn_implementation == "flash_attention_2":
+ if self.attn_implementation == "flash_attention_2" or self.attn_implementation == "flash_attention_3":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
@@ -1016,7 +1141,7 @@ def forward(
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self.attn_implementation == "flash_attention_2":
+ if self.attn_implementation == "flash_attention_2" or self.attn_implementation == "flash_attention_3":
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
@@ -1664,6 +1789,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel):
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def __init__(
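For reference, a minimal usage sketch of the new attention path for MusicGen, assuming the FlashAttention-3 kernels are installed and a Hopper GPU is available; the checkpoint name and prompt below are illustrative:

```py
import torch

from transformers import AutoProcessor, MusicgenForConditionalGeneration

# Load MusicGen with the FlashAttention-3 attention classes registered above.
model = MusicgenForConditionalGeneration.from_pretrained(
    "facebook/musicgen-small",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_3",
).to("cuda")
processor = AutoProcessor.from_pretrained("facebook/musicgen-small")

inputs = processor(text=["lo-fi beat with a calm piano"], padding=True, return_tensors="pt").to("cuda")
# Generation now runs through MusicgenFlashAttention3 in both self- and cross-attention layers.
audio_values = model.generate(**inputs, max_new_tokens=256)
```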
diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
index a8a8fe96098952..5d83a44adf99b0 100644
--- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -39,6 +39,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -51,6 +52,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
if TYPE_CHECKING:
from ...generation.streamers import BaseStreamer
@@ -455,6 +459,125 @@ def forward(
return attn_output, attn_weights, past_key_value
+class MusicgenMelodyFlashAttention3(MusicgenMelodyAttention):
+ """
+ MusicgenMelody flash attention module. This module inherits from `MusicgenMelodyAttention` as the weights of the module stay
+ untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # MusicgenMelodyFlashAttention3 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("MusicgenMelodyFlashAttention3 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2)
+ value_states = past_key_value[1].transpose(1, 2)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states may be silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the
+ # LayerNorms to fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MusicgenMelody
class MusicgenMelodySdpaAttention(MusicgenMelodyAttention):
def forward(
@@ -566,6 +689,7 @@ def forward(
"eager": MusicgenMelodyAttention,
"sdpa": MusicgenMelodySdpaAttention,
"flash_attention_2": MusicgenMelodyFlashAttention2,
+ "flash_attention_3": MusicgenMelodyFlashAttention3,
}
@@ -662,6 +786,7 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
@@ -945,7 +1070,7 @@ def forward(
input_shape = inputs_embeds.size()[:-1]
- if self.attn_implementation == "flash_attention_2":
+ if self.attn_implementation == "flash_attention_2" or self.attn_implementation == "flash_attention_3":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self.attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
@@ -1590,6 +1715,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def __init__(
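The `*_ATTENTION_CLASSES` mappings added in this file (and in the files below) are plain dictionary dispatch keyed by `config._attn_implementation`. A standalone sketch of the pattern, with illustrative class names rather than the library's:

```py
# Minimal sketch of the attention-class dispatch behind the *_ATTENTION_CLASSES mappings.
class EagerAttention: ...
class FlashAttention2: ...
class FlashAttention3: ...
class SdpaAttention: ...

ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "flash_attention_2": FlashAttention2,
    "flash_attention_3": FlashAttention3,
    "sdpa": SdpaAttention,
}

def build_attention(attn_implementation: str):
    # `attn_implementation` plays the role of config._attn_implementation here.
    return ATTENTION_CLASSES[attn_implementation]()

print(type(build_attention("flash_attention_3")).__name__)  # FlashAttention3
```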
diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py
index 4d079b4dde104d..63afdf58cc68c3 100644
--- a/src/transformers/models/nemotron/modeling_nemotron.py
+++ b/src/transformers/models/nemotron/modeling_nemotron.py
@@ -27,6 +27,7 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, StaticCache
from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
from ...modeling_flash_attention_utils import _flash_attention_forward
from ...modeling_outputs import (
BaseModelOutputWithPast,
@@ -468,6 +469,108 @@ def forward(
return attn_output, attn_weights, past_key_value
+class NemotronFlashAttention3(NemotronAttention):
+ """
+ Nemotron flash attention module. This module inherits from `NemotronAttention` as the weights of the module stay
+ untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` "
+ "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
+ )
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is not None:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states may be silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the
+ # LayerNorms to fp32. (NemotronRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
class NemotronSdpaAttention(NemotronAttention):
"""
@@ -563,6 +666,7 @@ def forward(
NEMOTRON_ATTENTION_CLASSES = {
"eager": NemotronAttention,
"flash_attention_2": NemotronFlashAttention2,
+ "flash_attention_3": NemotronFlashAttention3,
"sdpa": NemotronSdpaAttention,
}
@@ -677,6 +781,7 @@ class NemotronPreTrainedModel(PreTrainedModel):
_no_split_modules = ["NemotronDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -920,7 +1025,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
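The guard at the top of `NemotronFlashAttention3.forward` rejects a static cache up front, mirroring the FlashAttention-2 path. A minimal sketch of that check, assuming only that `StaticCache` is importable from `transformers` (constructing one needs a model config, hence the commented-out call):

```py
from transformers import StaticCache

def check_cache_compat(past_key_value) -> None:
    # The FlashAttention-3 forward refuses a compiled/static cache, same as the FA2 forward.
    if isinstance(past_key_value, StaticCache):
        raise ValueError(
            "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3`; "
            "use `sdpa` in the meantime."
        )

check_cache_compat(past_key_value=None)        # fine: dynamic or no cache
# check_cache_compat(StaticCache(config, ...))  # would raise ValueError
```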
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index b4bda8e2db5251..4f1ebb3cf690c1 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -41,6 +41,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
@@ -52,6 +53,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -498,6 +502,106 @@ def forward(
return attn_output, attn_weights, past_key_value
+class OlmoFlashAttention3(OlmoAttention):
+ """
+ OLMo flash attention module. This module inherits from `OlmoAttention` as the weights of the module stay
+ untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ if self.config.clip_qkv is not None:
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states may be silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the
+ # LayerNorms to fp32. (OlmoRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class OlmoSdpaAttention(OlmoAttention):
"""
OLMo attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -594,6 +698,7 @@ def forward(
OLMO_ATTENTION_CLASSES = {
"eager": OlmoAttention,
"flash_attention_2": OlmoFlashAttention2,
+ "flash_attention_3": OlmoFlashAttention3,
"sdpa": OlmoSdpaAttention,
}
@@ -704,6 +809,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
_no_split_modules = ["OlmoDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -959,7 +1065,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
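OLMo's `clip_qkv` handling is the only model-specific step in `OlmoFlashAttention3` before the shared dtype and flash-attention logic. A small sketch of the in-place clamp it applies; the tensor shapes and clip value are illustrative:

```py
import torch

clip_qkv = 8.0  # illustrative; the model reads this from config.clip_qkv
query_states = torch.randn(2, 16, 1024) * 20
key_states = torch.randn(2, 16, 1024) * 20

# In-place clamp, the same torch.Tensor.clamp_ pattern used in the forward pass.
query_states.clamp_(min=-clip_qkv, max=clip_qkv)
key_states.clamp_(min=-clip_qkv, max=clip_qkv)

assert query_states.abs().max() <= clip_qkv
```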
diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py
index a33338365312db..8efab00cc86cbf 100644
--- a/src/transformers/models/olmoe/modeling_olmoe.py
+++ b/src/transformers/models/olmoe/modeling_olmoe.py
@@ -34,6 +34,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -44,6 +45,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -572,6 +576,105 @@ def forward(
return attn_output, attn_weights, past_key_value
+class OlmoeFlashAttention3(OlmoeAttention):
+ """
+ OLMoE flash attention module. This module inherits from `OlmoeAttention` as the weights of the module stay
+ untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_norm(self.q_proj(hidden_states))
+ key_states = self.k_norm(self.k_proj(hidden_states))
+ value_states = self.v_proj(hidden_states)
+ if self.config.clip_qkv is not None:
+ query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+ value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # sin and cos are specific to RoPE models; cache_position needed for the static cache
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states may be silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the
+ # LayerNorms to fp32. (OlmoeRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class OlmoeSdpaAttention(OlmoeAttention):
"""
OLMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -670,6 +773,7 @@ def forward(
OLMOE_ATTENTION_CLASSES = {
"eager": OlmoeAttention,
"flash_attention_2": OlmoeFlashAttention2,
+ "flash_attention_3": OlmoeFlashAttention3,
"sdpa": OlmoeSdpaAttention,
}
@@ -838,6 +942,7 @@ class OlmoePreTrainedModel(PreTrainedModel):
_no_split_modules = ["OlmoeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -1109,7 +1214,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
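Every FA3 forward repeats the same dtype-recovery block: if the hidden states were silently upcast to float32, a target dtype is resolved in priority order (autocast dtype, then the pre-quantization dtype, then the projection weight's dtype). A standalone sketch of that resolution, with a dummy config standing in for the model's:

```py
import torch
import torch.nn as nn

def resolve_target_dtype(config, proj: nn.Linear) -> torch.dtype:
    # Same priority order as the FA3 forward passes: autocast > pre-quantization dtype > weight dtype.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    return proj.weight.dtype

class DummyConfig:  # stand-in for the model config
    pass

proj = nn.Linear(8, 8).to(torch.bfloat16)
hidden = torch.randn(1, 4, 8, dtype=torch.float32)  # e.g. upcast by a LayerNorm kept in fp32
hidden = hidden.to(resolve_target_dtype(DummyConfig(), proj))
print(hidden.dtype)  # torch.bfloat16
```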
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index 8f058171778efc..26a1499a7699e1 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -35,6 +35,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -45,6 +46,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -359,9 +363,119 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
+class OptFlashAttention3(OPTAttention):
+ """
+ OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stay untouched.
+ The only required change is in the forward pass, which needs to correctly call the public API of flash
+ attention and deal with padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, _, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self.q_proj(hidden_states)
+ # get key, value proj
+ if is_cross_attention and past_key_value is not None:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ query_length = query_states.shape[1]
+ tgt_len = key_states.shape[-2]
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim)
+ key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
+ value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim)
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states may be silently cast to float32. Hence, we need to
+ # cast them back to float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ query_length,
+ is_causal=self.is_causal,
+ )
+
+ attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
+ attn_output = self.out_proj(attn_weights_reshaped)
+
+ if not output_attentions:
+ attn_weights_reshaped = None
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
+
OPT_ATTENTION_CLASSES = {
"eager": OPTAttention,
"flash_attention_2": OptFlashAttention2,
+ "flash_attention_3": OptFlashAttention3,
}
@@ -488,6 +602,7 @@ class OPTPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["OPTDecoderLayer"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
def _init_weights(self, module):
std = self.config.init_std
@@ -604,6 +719,7 @@ def __init__(self, config: OPTConfig):
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -702,7 +818,7 @@ def forward(
mask_seq_length = past_key_values_length + seq_length
# embed positions
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
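The `_use_flash_attention_2 or _use_flash_attention_3` branch keeps the raw 2D padding mask only when there is real padding; otherwise the mask is dropped so the kernel can take its fully-causal fast path. A toy illustration of that condition:

```py
import torch

def fa_padding_mask(attention_mask):
    # Same condition as the decoder above: keep the 2D mask only if it contains a 0 (i.e. padding).
    return attention_mask if (attention_mask is not None and 0 in attention_mask) else None

no_padding = torch.ones(2, 6, dtype=torch.long)
with_padding = torch.tensor([[1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1]])

print(fa_padding_mask(no_padding))    # None -> flash attention runs without an explicit mask
print(fa_padding_mask(with_padding))  # the 2D mask is passed through to the unpad logic
```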
diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py
index 39ee57d70bb5d8..b335e4c211f40a 100644
--- a/src/transformers/models/paligemma/modeling_paligemma.py
+++ b/src/transformers/models/paligemma/modeling_paligemma.py
@@ -124,6 +124,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["PaliGemmaMultiModalProjector"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = False
+ _supports_flash_attn_3 = False
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py
index ccaa2c7fd29aae..02989bc9bc5471 100644
--- a/src/transformers/models/persimmon/modeling_persimmon.py
+++ b/src/transformers/models/persimmon/modeling_persimmon.py
@@ -783,7 +783,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 648d1653a3b503..93450596f97049 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -41,6 +41,7 @@
add_start_docstrings_to_model_forward,
get_torch_version,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
@@ -52,6 +53,8 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
logger = logging.get_logger(__name__)
@@ -575,6 +578,136 @@ def forward(
return attn_output, attn_weights, past_key_value
+class PhiFlashAttention3(PhiAttention):
+ """
+ Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stay
+ untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # PhiFlashAttention3 attention does not support output_attentions
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ if self.qk_layernorm:
+ query_states = self.q_layernorm(query_states)
+ key_states = self.k_layernorm(key_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+
+ # Partial rotary embedding
+ query_rot, query_pass = (
+ query_states[..., : self.rotary_ndims],
+ query_states[..., self.rotary_ndims :],
+ )
+ key_rot, key_pass = (
+ key_states[..., : self.rotary_ndims],
+ key_states[..., self.rotary_ndims :],
+ )
+ # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+ # [batch_size, seq_length, num_heads, head_dim]
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
+
+ if past_key_value is not None:
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "partial_rotation_size": self.rotary_ndims,
+ "cache_position": cache_position,
+ }
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states may be silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the
+ # LayerNorms to fp32.
+
+ if query_states.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ softmax_scale=None,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.dense(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class PhiSdpaAttention(PhiAttention):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -704,6 +837,7 @@ def forward(
PHI_ATTENTION_CLASSES = {
"eager": PhiAttention,
"flash_attention_2": PhiFlashAttention2,
+ "flash_attention_3": PhiFlashAttention3,
"sdpa": PhiSdpaAttention,
}
@@ -812,6 +946,7 @@ class PhiPreTrainedModel(PreTrainedModel):
_no_split_modules = ["PhiDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
@@ -930,6 +1065,7 @@ def __init__(self, config: PhiConfig):
self.rotary_emb = PhiRotaryEmbedding(config=config)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
self.gradient_checkpointing = False
@@ -1074,7 +1210,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
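Phi applies RoPE only to the first `rotary_ndims` channels of each head, and the FA3 forward keeps that partial-rotary split. A small sketch of the slice-rotate-concatenate pattern with made-up dimensions and a stand-in rotation:

```py
import torch

bsz, num_heads, seq_len, head_dim = 1, 4, 8, 32
rotary_ndims = 16  # head_dim * partial_rotary_factor in the real model

query_states = torch.randn(bsz, num_heads, seq_len, head_dim)

# Split the head dimension into the rotary part and the pass-through part.
query_rot, query_pass = query_states[..., :rotary_ndims], query_states[..., rotary_ndims:]

# Stand-in for apply_rotary_pos_emb: any rotation of the rotary slice would do here.
query_rot = torch.roll(query_rot, shifts=1, dims=-1)

# Re-concatenate so downstream code sees the full head_dim again.
query_states = torch.cat((query_rot, query_pass), dim=-1)
assert query_states.shape == (bsz, num_heads, seq_len, head_dim)
```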
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index ec395679ae6207..350070ed31e613 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -39,6 +39,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
@@ -50,6 +51,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
@@ -632,6 +636,146 @@ def forward(
return attn_output, attn_weights, past_key_value
+class Phi3FlashAttention3(Phi3Attention):
+ """
+ Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stay
+ untouched. The only required change is in the forward pass, which needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # Phi3FlashAttention3 attention does not support output_attentions
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ qkv = self.qkv_proj(hidden_states)
+ query_pos = self.num_heads * self.head_dim
+ query_states = qkv[..., :query_pos]
+ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
+ rotary_seq_len = (
+ max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
+ )
+
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability,
+ # so the input hidden states may be silently cast to float32. Hence, we need to
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the
+ # LayerNorms to fp32.
+
+ if query_states.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.qkv_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3
# TODO @Arthur no longer copied from LLama after static cache
class Phi3SdpaAttention(Phi3Attention):
@@ -729,6 +873,7 @@ def forward(
PHI3_ATTENTION_CLASSES = {
"eager": Phi3Attention,
"flash_attention_2": Phi3FlashAttention2,
+ "flash_attention_3": Phi3FlashAttention3,
"sdpa": Phi3SdpaAttention,
}
@@ -842,6 +987,7 @@ class Phi3PreTrainedModel(PreTrainedModel):
_no_split_modules = ["Phi3DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
@@ -1095,7 +1241,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
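The sliding-window branch in `Phi3FlashAttention3` trims the cached keys/values to the last `sliding_window - 1` positions before the new tokens are appended. A toy sketch of that slicing on a raw tensor, without a `Cache` object:

```py
import torch

sliding_window = 4
# Pretend cache: (batch, num_heads, cached_len, head_dim) with cached_len >= sliding_window
past_key = torch.arange(6.0).view(1, 1, 6, 1)

# Keep only the last sliding_window - 1 cached positions, as in the FA2/FA3 forward.
slicing_tokens = 1 - sliding_window  # negative index, here -3
past_key = past_key[:, :, slicing_tokens:, :].contiguous()

print(past_key.shape)      # torch.Size([1, 1, 3, 1])
print(past_key.flatten())  # tensor([3., 4., 5.])
```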
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index 93d91e160089e1..03a24a4676261f 100644
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -685,6 +685,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] =
)
self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
@@ -773,7 +774,7 @@ def forward(
# expand attention_mask
if attention_mask is not None:
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
attention_mask = attention_mask if 0 in attention_mask else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
@@ -870,6 +871,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] =
)
self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
@@ -990,7 +992,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input)
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
@@ -1010,7 +1012,7 @@ def forward(
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
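PLBart (like OPT above) stores the selected implementation as booleans on the encoder and decoder. A compact sketch of how those flags are derived from the config value and combined for mask handling; `DummyConfig` is illustrative:

```py
class DummyConfig:  # stand-in for PLBartConfig
    _attn_implementation = "flash_attention_3"

config = DummyConfig()
use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
use_flash_attention_3 = config._attn_implementation == "flash_attention_3"

# Either flag routes the raw 2D padding mask straight to the flash kernels
# instead of expanding it into a 4D causal mask.
print(use_flash_attention_2 or use_flash_attention_3)  # True
```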
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index d0ea8ef0e376e0..97dd3f12a9f6b0 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -42,6 +42,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
@@ -53,6 +54,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -539,6 +543,132 @@ def forward(
return attn_output, attn_weights, past_key_value
+class Qwen2FlashAttention3(Qwen2Attention):
+ """
+ Qwen2 flash attention module, following the Qwen2 attention module. This module inherits from `Qwen2Attention`
+ as the weights of the module stay untouched. The only required change is in the forward pass,
+ which needs to correctly call the public API of flash attention and deal with padding tokens
+ in case the input contains any. Additionally, for sliding window attention, we apply SWA only to the bottom
+ config.max_window_layers layers.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ kv_seq_len = key_states.shape[-2] + cache_position[0]
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class Qwen2SdpaAttention(Qwen2Attention):
"""
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -638,6 +768,7 @@ def forward(
QWEN2_ATTENTION_CLASSES = {
"eager": Qwen2Attention,
"flash_attention_2": Qwen2FlashAttention2,
+ "flash_attention_3": Qwen2FlashAttention3,
"sdpa": Qwen2SdpaAttention,
}
@@ -647,7 +778,11 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- if config.sliding_window and config._attn_implementation != "flash_attention_2":
+ if (
+ config.sliding_window
+ and config._attn_implementation != "flash_attention_2"
+ and config._attn_implementation != "flash_attention_3"
+ ):
logger.warning_once(
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
"unexpected results may be encountered."
@@ -754,6 +889,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
_no_split_modules = ["Qwen2DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -1013,7 +1149,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
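
For reference, a minimal usage sketch of the Qwen2 changes above: requesting `attn_implementation="flash_attention_3"` dispatches to the new `Qwen2FlashAttention3` class. A sketch only, assuming the FA3 kernels (`flash_attn_interface`) are installed and a Hopper GPU is available; the checkpoint name is just an example.

```py
# Sketch: route Qwen2 through Qwen2FlashAttention3 via the attn_implementation flag.
# Assumes FA3 kernels + Hopper GPU; the checkpoint below is an arbitrary example.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Qwen/Qwen2-7B-Instruct"  # example checkpoint, not prescribed by this diff
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_3",
    device_map="cuda",
)
inputs = tokenizer("FlashAttention-3 smoke test:", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))
```
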
diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
index 14235bf0aaf64a..7e3e4062a7c7cd 100644
--- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -30,6 +30,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -41,6 +42,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -338,6 +342,121 @@ def forward(
return attn_output, attn_weights, past_key_value
+class Qwen2AudioFlashAttention3(Qwen2AudioAttention):
+ """
+ Qwen2Audio flash attention module. This module inherits from `Qwen2AudioAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[EncoderDecoderCache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_3'`. "
+ "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
+ )
+ # Qwen2AudioFlashAttention3 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("Qwen2AudioFlashAttention3 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
+
+ if past_key_value is not None:
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ past_key_value.is_updated[self.layer_idx] = True
+ past_key_value = past_key_value.cross_attention_cache
+ else:
+ past_key_value = past_key_value.self_attention_cache
+
+ # use key_value_states if cross attention
+ current_states = key_value_states if key_value_states is not None else hidden_states
+ if is_cross_attention and past_key_value and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value.key_cache[self.layer_idx]
+ value_states = past_key_value.value_cache[self.layer_idx]
+ else:
+ key_states = self._shape(self.k_proj(current_states), -1, bsz)
+ value_states = self._shape(self.v_proj(current_states), -1, bsz)
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
+ # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ causal_mask = attention_mask
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ causal_mask,
+ tgt_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, tgt_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
# Copied from transformers.models.whisper.modeling_whisper.WhisperSdpaAttention with Whisper->Qwen2Audio
class Qwen2AudioSdpaAttention(Qwen2AudioAttention):
def forward(
@@ -440,6 +559,7 @@ def forward(
QWEN2AUDIO_ATTENTION_CLASSES = {
"eager": Qwen2AudioAttention,
"flash_attention_2": Qwen2AudioFlashAttention2,
+ "flash_attention_3": Qwen2AudioFlashAttention3,
"sdpa": Qwen2AudioSdpaAttention,
}
@@ -543,6 +663,7 @@ class Qwen2AudioPreTrainedModel(PreTrainedModel):
_no_split_modules = ["Qwen2AudioAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
def _init_weights(self, module):
# important: this ported version of Qwen2Audio isn't meant for training from scratch - only
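
The new `Qwen2AudioFlashAttention3` rejects configurations the fused kernel cannot serve (static caches, `output_attentions=True`). A standalone sketch of those guards, with a hypothetical helper name:

```py
# Sketch of the compatibility checks performed before dispatching to the FA3 kernel.
# The helper name is hypothetical; the checks mirror the ones added above.
from transformers.cache_utils import StaticCache

def check_fa3_compat(past_key_value, output_attentions: bool) -> None:
    if isinstance(past_key_value, StaticCache):
        raise ValueError(
            "The `static` cache implementation is not compatible with "
            "`attn_implementation='flash_attention_3'`; use `sdpa` in the meantime."
        )
    if output_attentions:
        raise ValueError("FlashAttention-3 kernels cannot return attention weights.")
```
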
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 6f483e50cde065..de9aa69bc620b3 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -43,6 +43,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
@@ -54,6 +55,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "Qwen/Qwen1.5-MoE-A2.7B"
@@ -621,6 +625,132 @@ def forward(
return attn_output, attn_weights, past_key_value
+class Qwen2MoeFlashAttention3(Qwen2MoeAttention):
+ """
+ Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention`
+ as the weights of the module stay untouched. The only required change would be on the forward pass
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
+ config.max_window_layers layers.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ kv_seq_len = key_states.shape[-2] + cache_position[0]
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe
class Qwen2MoeSdpaAttention(Qwen2MoeAttention):
"""
@@ -721,6 +851,7 @@ def forward(
QWEN2MOE_ATTENTION_CLASSES = {
"eager": Qwen2MoeAttention,
"flash_attention_2": Qwen2MoeFlashAttention2,
+ "flash_attention_3": Qwen2MoeFlashAttention3,
"sdpa": Qwen2MoeSdpaAttention,
}
@@ -913,6 +1044,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
_no_split_modules = ["Qwen2MoeDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
@@ -1188,7 +1320,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
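
Both `flash_attention_2` and `flash_attention_3` now share the same `_update_causal_mask` shortcut: the raw 2D padding mask is forwarded only when it actually contains padded positions. A standalone sketch of that branch:

```py
# Keep the 2D mask only if some position is padded out; otherwise return None so the
# kernel takes its purely causal fast path.
from typing import Optional
import torch

def fa_mask_shortcut(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    if attention_mask is not None and 0.0 in attention_mask:
        return attention_mask
    return None

print(fa_mask_shortcut(torch.ones(1, 4)))                      # None (no padding)
print(fa_mask_shortcut(torch.tensor([[1.0, 1.0, 0.0, 0.0]])))  # the mask itself
```
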
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 4e4e04198c0bf1..ba71a944ea2492 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -44,6 +44,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -58,6 +59,12 @@
else:
flash_attn_varlen_func = None
+if is_flash_attn_3_available():
+ from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+else:
+ flash_attn_3_varlen_func = None
logger = logging.get_logger(__name__)
@@ -379,6 +386,29 @@ def forward(
return attn_output
+class VisionFlashAttention3(nn.Module):
+ def __init__(self, dim: int, num_heads: int = 16) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
+ self.proj = nn.Linear(dim, dim)
+
+ def forward(
+ self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
+ ) -> torch.Tensor:
+ seq_length = hidden_states.shape[0]
+ q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
+ q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0)
+ k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
+
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
+ attn_output = flash_attn_3_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
+ seq_length, -1
+ )
+ attn_output = self.proj(attn_output)
+ return attn_output
+
+
class VisionSdpaAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 16) -> None:
super().__init__()
@@ -410,6 +440,7 @@ def forward(
QWEN2_VL_VISION_ATTENTION_CLASSES = {
"eager": VisionAttention,
"flash_attention_2": VisionFlashAttention2,
+ "flash_attention_3": VisionFlashAttention3,
"sdpa": VisionSdpaAttention,
}
@@ -812,6 +843,144 @@ def forward(
return attn_output, attn_weights, past_key_value
+class Qwen2VLFlashAttention3(Qwen2VLAttention):
+ """
+ Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
+ as the weights of the module stay untouched. The only required change would be on the forward pass
+ where it needs to correctly call the public API of flash attention and deal with padding tokens
+ in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
+ config.max_window_layers layers.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+
+ query_states, key_states = apply_multimodal_rotary_pos_emb(
+ query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
+ )
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class Qwen2VLSdpaAttention(Qwen2VLAttention):
"""
Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -917,6 +1086,7 @@ def forward(
QWEN2_VL_ATTENTION_CLASSES = {
"eager": Qwen2VLAttention,
"flash_attention_2": Qwen2VLFlashAttention2,
+ "flash_attention_3": Qwen2VLFlashAttention3,
"sdpa": Qwen2VLSdpaAttention,
}
@@ -926,7 +1096,11 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
+ if (
+ config.use_sliding_window
+ and config._attn_implementation != "flash_attention_2"
+ and config._attn_implementation != "flash_attention_3"
+ ):
logger.warning_once(
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
"unexpected results may be encountered."
@@ -1033,6 +1207,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
_no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
@@ -1274,7 +1449,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
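
The vision tower's `VisionFlashAttention3` consumes variable-length inputs: `cu_seqlens` holds the cumulative boundaries of the packed patch sequences and `max_seqlen` is derived exactly as in the forward above. A sketch with made-up patch counts (the cumulative construction itself is an illustrative assumption):

```py
# Build the cu_seqlens / max_seqlen pair expected by flash_attn_3_varlen_func for a
# packed batch of three images (patch counts are illustrative).
import torch

patches_per_image = torch.tensor([64, 128, 32])
cu_seqlens = torch.nn.functional.pad(patches_per_image.cumsum(0), (1, 0)).to(torch.int32)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
print(cu_seqlens.tolist(), max_seqlen)  # [0, 64, 192, 224] 128
```
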
diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index a8f076fad79c76..313471a599651f 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -540,6 +540,7 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["RecurrentGemmaDecoderLayer"]
_skip_keys_device_placement = ["cache"]
_supports_flash_attn_2 = False
+ _supports_flash_attn_3 = False
_supports_sdpa = False # we can't compare with eager for now
_supports_cache_class = True
_supports_quantized_cache = True
diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py
index c9a3494b88b486..d5eb67a7fb376e 100644
--- a/src/transformers/models/sew/modeling_sew.py
+++ b/src/transformers/models/sew/modeling_sew.py
@@ -33,6 +33,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
)
@@ -43,6 +44,9 @@
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -682,6 +686,125 @@ def forward(
return attn_output, attn_weights, past_key_value
+class SEWFlashAttention3(SEWAttention):
+ """
+ SEW flash attention module. This module inherits from `SEWAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # SEWFlashAttention3 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("SEWFlashAttention3 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2)
+ value_states = past_key_value[1].transpose(1, 2)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference so it is recommended to not cast the LayerNorms
+ # in fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class SEWSdpaAttention(SEWAttention):
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->SEW
def forward(
@@ -793,6 +916,7 @@ def forward(
"eager": SEWAttention,
"sdpa": SEWSdpaAttention,
"flash_attention_2": SEWFlashAttention2,
+ "flash_attention_3": SEWFlashAttention3,
}
@@ -869,6 +993,7 @@ def __init__(self, config):
self.upsample = SEWUpsampling(config)
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -882,7 +1007,7 @@ def forward(
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# make sure padded tokens output 0
hidden_states[~attention_mask] = 0.0
# 2d mask is passed through the layers
@@ -979,6 +1104,7 @@ class SEWPreTrainedModel(PreTrainedModel):
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
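
For the SEW encoder, the FA2/FA3 branches zero padded frames in place with the boolean mask and hand the unexpanded 2D mask to the kernel. A minimal sketch of that padding handling, with illustrative shapes:

```py
# Make sure padded tokens output 0 before attention, as the encoder branch above does.
import torch

hidden_states = torch.randn(2, 4, 8)
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [1, 1, 0, 0]], dtype=torch.bool)
hidden_states[~attention_mask] = 0.0
print(hidden_states[1, 2:].abs().sum().item())  # 0.0
```
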
diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py
index 1d35d1d44cfd97..fe4c9eb243c689 100644
--- a/src/transformers/models/siglip/modeling_siglip.py
+++ b/src/transformers/models/siglip/modeling_siglip.py
@@ -35,6 +35,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -46,6 +47,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -522,6 +526,89 @@ def forward(
return attn_output, attn_weights
+class SiglipFlashAttention3(SiglipAttention):
+ """
+ SiglipAttention flash attention module. This module inherits from `SiglipAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ is_causal = False
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ output_attentions = False
+
+ batch_size, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference so it is recommended to not cast the LayerNorms
+ # in fp32.
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights
+
+
class SiglipSdpaAttention(SiglipAttention):
"""
Siglip attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
@@ -591,6 +678,7 @@ def forward(
SIGLIP_ATTENTION_CLASSES = {
"eager": SiglipAttention,
"flash_attention_2": SiglipFlashAttention2,
+ "flash_attention_3": SiglipFlashAttention3,
"sdpa": SiglipSdpaAttention,
}
@@ -677,6 +765,7 @@ class SiglipPreTrainedModel(PreTrainedModel):
"SiglipMultiheadAttentionPoolingHead",
]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
@@ -930,6 +1019,7 @@ def __init__(self, config: SiglipTextConfig):
self.head = nn.Linear(embed_dim, embed_dim)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
@@ -962,7 +1052,7 @@ def forward(
# note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
# expand attention_mask
- if attention_mask is not None and not self._use_flash_attention_2:
+ if attention_mask is not None and not self._use_flash_attention_2 and not self._use_flash_attention_3:
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
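
In the SigLIP text model the eager/SDPA paths expand the `[batch_size, seq_len]` mask into a 4D additive mask, while the FA2/FA3 paths keep the 2D mask untouched. A small sketch of the expansion that is skipped:

```py
# Illustration only: shape of the 4D mask produced for the non-flash branches.
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

mask_2d = torch.tensor([[1, 1, 1, 0]])
mask_4d = _prepare_4d_attention_mask(mask_2d, dtype=torch.float32)
print(mask_2d.shape, mask_4d.shape)  # torch.Size([1, 4]) torch.Size([1, 1, 4, 4])
```
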
diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py
index d91c0832ed33da..11348590ce0cc4 100755
--- a/src/transformers/models/stablelm/modeling_stablelm.py
+++ b/src/transformers/models/stablelm/modeling_stablelm.py
@@ -42,6 +42,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -52,6 +53,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -672,10 +676,114 @@ def forward(
return attn_output, attn_weights, past_key_value
+class StableLmFlashAttention3(StableLmAttention):
+ """
+ StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.LongTensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # StableLmFlashAttention3 attention does not support output_attentions
+
+ output_attentions = False
+
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ # Flash attention requires the input to have the shape
+ # batch_size x seq_length x num_heads x head_dim
+ # therefore we just need to keep the original shape
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if self.qk_layernorm:
+ query_states = self.q_layernorm(query_states)
+ key_states = self.k_layernorm(key_states)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+
+ # Partial rotary embedding
+ query_rot, query_pass = (
+ query_states[..., : self.rotary_ndims],
+ query_states[..., self.rotary_ndims :],
+ )
+ key_rot, key_pass = (
+ key_states[..., : self.rotary_ndims],
+ key_states[..., self.rotary_ndims :],
+ )
+ query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin)
+
+ # [batch_size, seq_length, num_heads, head_dim]
+ query_states = torch.cat((query_rot, query_pass), dim=-1)
+ key_states = torch.cat((key_rot, key_pass), dim=-1)
+
+ if past_key_value is not None:
+ cache_kwargs = {
+ "sin": sin,
+ "cos": cos,
+ "partial_rotation_size": self.rotary_ndims,
+ "cache_position": cache_position,
+ }
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+ # to be able to avoid many of these transpose/reshape/view.
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
ATTENTION_CLASSES = {
"eager": StableLmAttention,
"sdpa": StableLmSdpaAttention,
"flash_attention_2": StableLmFlashAttention2,
+ "flash_attention_3": StableLmFlashAttention3,
}
@@ -799,6 +907,7 @@ class StableLmPreTrainedModel(PreTrainedModel):
_no_split_modules = ["StableLmDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
_supports_sdpa = True
_supports_quantized_cache = True
@@ -1058,7 +1167,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
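
StableLM applies RoPE only to the first `rotary_ndims` channels of each head; the remaining channels pass through unchanged and are concatenated back afterwards. A minimal sketch of that partial-rotary split, with illustrative shapes:

```py
# Split channels into a rotary part and a pass-through part, then concatenate back.
# apply_rotary_pos_emb(q_rot, k_rot, cos, sin) would rotate the first slice here.
import torch

head_dim, rotary_ndims = 64, 16
q = torch.randn(1, 8, 4, head_dim)               # [batch, num_heads, seq_len, head_dim]
q_rot, q_pass = q[..., :rotary_ndims], q[..., rotary_ndims:]
q = torch.cat((q_rot, q_pass), dim=-1)
print(q.shape)                                   # torch.Size([1, 8, 4, 64])
```
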
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index 0be37c4e1fb91c..bb0311477be828 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -42,6 +42,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_torchdynamo_compiling,
logging,
@@ -53,6 +54,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -512,6 +516,131 @@ def forward(
return attn_output, attn_weights, past_key_value
+class Starcoder2FlashAttention3(Starcoder2Attention):
+ """
+ Starcoder2 flash attention module. This module inherits from `Starcoder2Attention` as the weights of the module stay
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: bool = False,
+ use_cache: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46
+ ):
+ bsz, q_len, _ = hidden_states.size()
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+ if position_embeddings is None:
+ logger.warning_once(
+ "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
+ "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
+ "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
+ "removed and `position_embeddings` will be mandatory."
+ )
+ cos, sin = self.rotary_emb(value_states, position_ids)
+ else:
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ kv_seq_len = key_states.shape[-2] + cache_position[0]
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons
+ # therefore the input hidden states get silently cast to float32. Hence, we need to
+ # cast them back in float16 just to be sure everything works as expected.
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ # Reshape to the expected shape for Flash Attention
+ query_states = query_states.transpose(1, 2)
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ position_ids=position_ids,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+ attn_output = self.o_proj(attn_output)
+ attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
# Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Starcoder2
class Starcoder2SdpaAttention(Starcoder2Attention):
"""
@@ -614,6 +743,7 @@ def forward(
STARCODER2_ATTENTION_CLASSES = {
"eager": Starcoder2Attention,
"flash_attention_2": Starcoder2FlashAttention2,
+ "flash_attention_3": Starcoder2FlashAttention3,
"sdpa": Starcoder2SdpaAttention,
}
@@ -728,6 +858,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
_no_split_modules = ["Starcoder2DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -988,7 +1119,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
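
Every FA3 forward above falls back to a lower-precision dtype when activations arrive in fp32 (for example after fp32 layer norms under PEFT). A standalone sketch of that selection logic, with a hypothetical helper name:

```py
# Pick the dtype to cast q/k/v back to before calling the FA3 kernel: the autocast
# dtype if autocast is active, the pre-quantization dtype if present, else the
# q_proj weight dtype.
import torch

def pick_fa3_dtype(config, q_proj_weight: torch.Tensor) -> torch.dtype:
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    return q_proj_weight.dtype
```
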
diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py
index 4202f680437c53..9e71d7c4f853fe 100755
--- a/src/transformers/models/unispeech/modeling_unispeech.py
+++ b/src/transformers/models/unispeech/modeling_unispeech.py
@@ -35,6 +35,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -45,6 +46,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -714,6 +718,125 @@ def forward(
return attn_output, attn_weights, past_key_value
+class UniSpeechFlashAttention3(UniSpeechAttention):
+ """
+ UniSpeech flash attention module. This module inherits from `UniSpeechAttention` as the weights of the module stays
+ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
+ flash attention and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # UniSpeechFlashAttention3 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("UniSpeechFlashAttention3 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2)
+ value_states = past_key_value[1].transpose(1, 2)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability reasons,
+ # so the input hidden states may have been silently cast to float32. Hence, we
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class UniSpeechSdpaAttention(UniSpeechAttention):
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->UniSpeech
def forward(
@@ -825,6 +948,7 @@ def forward(
"eager": UniSpeechAttention,
"sdpa": UniSpeechSdpaAttention,
"flash_attention_2": UniSpeechFlashAttention2,
+ "flash_attention_3": UniSpeechFlashAttention3,
}
@@ -972,6 +1096,7 @@ def __init__(self, config):
self.layers = nn.ModuleList([UniSpeechEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -988,7 +1113,7 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1060,6 +1185,7 @@ def __init__(self, config):
)
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -1076,7 +1202,7 @@ def forward(
# make sure padded tokens are not attended to
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1218,6 +1344,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
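
The `UNISPEECH_ATTENTION_CLASSES`-style mappings extended throughout this diff all follow the same dispatch pattern: the config's attention-implementation string selects the attention class from a module-level dict. A toy sketch with placeholder classes (these are not the real transformers classes):

```py
class EagerAttention: ...
class SdpaAttention: ...
class FlashAttention2: ...
class FlashAttention3: ...

ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "sdpa": SdpaAttention,
    "flash_attention_2": FlashAttention2,
    "flash_attention_3": FlashAttention3,  # the entry each modeling file gains in this PR
}

def build_attention(attn_implementation: str):
    # Unknown keys raise KeyError rather than silently falling back to eager.
    return ATTENTION_CLASSES[attn_implementation]()

print(type(build_attention("flash_attention_3")).__name__)  # FlashAttention3
```
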
diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
index bfb2cbfa4f55da..0ec5b1c9836f9d 100755
--- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -42,6 +42,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_peft_available,
logging,
@@ -53,6 +54,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -731,6 +735,125 @@ def forward(
return attn_output, attn_weights, past_key_value
+class UniSpeechSatFlashAttention3(UniSpeechSatAttention):
+ """
+ UniSpeechSat Flash Attention 3 module. This module inherits from `UniSpeechSatAttention` as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ Flash Attention 3 and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # UniSpeechSatFlashAttention3 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("UniSpeechSatFlashAttention3 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2)
+ value_states = past_key_value[1].transpose(1, 2)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability reasons,
+ # so the input hidden states may have been silently cast to float32. Hence, we
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class UniSpeechSatSdpaAttention(UniSpeechSatAttention):
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->UniSpeechSat
def forward(
@@ -842,6 +965,7 @@ def forward(
"eager": UniSpeechSatAttention,
"sdpa": UniSpeechSatSdpaAttention,
"flash_attention_2": UniSpeechSatFlashAttention2,
+ "flash_attention_3": UniSpeechSatFlashAttention3,
}
@@ -989,6 +1113,7 @@ def __init__(self, config):
self.layers = nn.ModuleList([UniSpeechSatEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -1005,7 +1130,7 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1077,6 +1202,7 @@ def __init__(self, config):
)
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -1093,7 +1219,7 @@ def forward(
# make sure padded tokens are not attended to
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1235,6 +1361,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
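
A standalone sketch of the encoder-side handling that the new `_use_flash_attention_3` flag feeds into (an illustrative helper, not the actual transformers code): padded positions are zeroed in the hidden states, and for FA2/FA3 the 2D mask is only passed through when it actually contains padding.

```py
import torch

def prepare_encoder_inputs(hidden_states, attention_mask, use_flash_attention):
    if attention_mask is not None:
        # make sure padded tokens output 0
        expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
        hidden_states[~expand_attention_mask.bool()] = 0.0
        if use_flash_attention:
            # the 2D mask is passed through the layers only when real padding is present
            attention_mask = attention_mask if 0 in attention_mask else None
    return hidden_states, attention_mask

hidden = torch.randn(1, 4, 8)
mask = torch.tensor([[1, 1, 1, 0]])
hidden, mask = prepare_encoder_inputs(hidden, mask, use_flash_attention=True)
print(hidden[0, -1].abs().sum().item(), mask)  # 0.0 and the original 2D mask
```
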
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index 9ae80be65ae4b6..7fbc36945d9a0a 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -126,6 +126,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["VideoLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
def _init_weights(self, module):
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index 53a3213697193e..f42e9333c39bec 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -132,6 +132,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["VipLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
def _init_weights(self, module):
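
The `_supports_flash_attn_3 = True` flags added across the pretrained-model classes are what the common tests later in this diff check before running; user code can apply the same gate before requesting the implementation. A hedged sketch (the checkpoint id and the SDPA fallback are illustrative choices; `is_flash_attn_3_available` is the helper added in `import_utils.py` further down):

```py
import torch
from transformers import GPT2LMHeadModel
from transformers.utils import is_flash_attn_3_available

# Fall back to SDPA when the FA3 kernels are not installed or the class opts out.
attn_impl = (
    "flash_attention_3"
    if is_flash_attn_3_available() and getattr(GPT2LMHeadModel, "_supports_flash_attn_3", False)
    else "sdpa"
)
model = GPT2LMHeadModel.from_pretrained(
    "gpt2", torch_dtype=torch.float16, attn_implementation=attn_impl
)
```
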
diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index f1d021b58ee538..1333f3ff15ff60 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -45,6 +45,7 @@
add_start_docstrings_to_model_forward,
cached_file,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
is_peft_available,
is_safetensors_available,
@@ -64,6 +65,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -778,6 +782,125 @@ def forward(
return attn_output, attn_weights, past_key_value
+class Wav2Vec2FlashAttention3(Wav2Vec2Attention):
+ """
+ Wav2Vec2 Flash Attention 3 module. This module inherits from `Wav2Vec2Attention` as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ Flash Attention 3 and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ # Wav2Vec2FlashAttention3 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("Wav2Vec2FlashAttention3 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, q_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0].transpose(1, 2)
+ value_states = past_key_value[1].transpose(1, 2)
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+ key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
+ value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
+ else:
+ # self_attention
+ key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
+ value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ kv_seq_len += past_key_value[0].shape[-2]
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability reasons,
+ # so the input hidden states may have been silently cast to float32. Hence, we
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ q_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class Wav2Vec2SdpaAttention(Wav2Vec2Attention):
# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Wav2Vec2
def forward(
@@ -889,6 +1012,7 @@ def forward(
"eager": Wav2Vec2Attention,
"sdpa": Wav2Vec2SdpaAttention,
"flash_attention_2": Wav2Vec2FlashAttention2,
+ "flash_attention_3": Wav2Vec2FlashAttention3,
}
@@ -1006,6 +1130,7 @@ def __init__(self, config):
self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -1022,7 +1147,7 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1093,6 +1218,7 @@ def __init__(self, config):
)
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -1109,7 +1235,7 @@ def forward(
# make sure padded tokens are not attended to
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1331,6 +1457,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
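
The float32-upcast handling repeated in every FA2/FA3 attention forward boils down to picking a target dtype in a fixed priority order. A standalone sketch (the function name is illustrative, not transformers API):

```py
import torch

def resolve_target_dtype(query_states, config, proj_weight):
    """Pick the dtype to cast q/k/v back to when they were silently upcast to fp32."""
    if query_states.dtype != torch.float32:
        return query_states.dtype                 # already in the compute dtype
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()     # honour the active autocast dtype
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype     # quantized model: use the original dtype
    return proj_weight.dtype                      # fall back to the projection weight dtype

class DummyConfig:                                # stand-in for a model config
    pass

weight = torch.nn.Linear(4, 4).to(torch.float16).weight
query = torch.randn(1, 2, 4)                      # fp32, as if upcast by a LayerNorm
print(resolve_target_dtype(query, DummyConfig(), weight))  # torch.float16
```
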
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index b82b978e5e6d95..b78e3ea6925e79 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -39,6 +39,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -50,6 +51,9 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_flash_attn_3_available():
+ from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
logger = logging.get_logger(__name__)
@@ -523,6 +527,121 @@ def forward(
return attn_output, attn_weights, past_key_value
+class WhisperFlashAttention3(WhisperAttention):
+ """
+ Whisper Flash Attention 3 module. This module inherits from `WhisperAttention` as the weights of the module stay
+ untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+ Flash Attention 3 and deal with padding tokens in case the input contains any of them.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[EncoderDecoderCache] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ if isinstance(past_key_value, StaticCache):
+ raise ValueError(
+ "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_3'`. "
+ "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
+ )
+ # WhisperFlashAttention3 attention does not support output_attentions
+ if output_attentions:
+ raise ValueError("WhisperFlashAttention3 attention does not support output_attentions")
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+ bsz, tgt_len, _ = hidden_states.size()
+
+ # get query proj
+ query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
+
+ if past_key_value is not None:
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ past_key_value.is_updated[self.layer_idx] = True
+ past_key_value = past_key_value.cross_attention_cache
+ else:
+ past_key_value = past_key_value.self_attention_cache
+
+ # use key_value_states if cross attention
+ current_states = key_value_states if key_value_states is not None else hidden_states
+ if is_cross_attention and past_key_value and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = past_key_value.key_cache[self.layer_idx]
+ value_states = past_key_value.value_cache[self.layer_idx]
+ else:
+ key_states = self._shape(self.k_proj(current_states), -1, bsz)
+ value_states = self._shape(self.v_proj(current_states), -1, bsz)
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+
+ # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
+ # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
+ key_states = key_states.transpose(1, 2)
+ value_states = value_states.transpose(1, 2)
+
+ causal_mask = attention_mask
+ if attention_mask is not None: # no matter the length, we just slice it
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+ # In PEFT, the layer norms are usually cast to float32 for training stability reasons,
+ # so the input hidden states may have been silently cast to float32. Hence, we
+ # cast them back to the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # to fp32. (LlamaRMSNorm handles it correctly)
+
+ input_dtype = query_states.dtype
+ if input_dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.q_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seems to be silently casted in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ )
+
+ query_states = query_states.to(target_dtype)
+ key_states = key_states.to(target_dtype)
+ value_states = value_states.to(target_dtype)
+
+ attn_output = _flash_attention_3_forward(
+ query_states,
+ key_states,
+ value_states,
+ causal_mask,
+ tgt_len,
+ is_causal=self.is_causal,
+ )
+
+ attn_output = attn_output.reshape(bsz, tgt_len, -1)
+ attn_output = self.out_proj(attn_output)
+
+ if not output_attentions:
+ attn_weights = None
+
+ return attn_output, attn_weights, past_key_value
+
+
class WhisperSdpaAttention(WhisperAttention):
def forward(
self,
@@ -624,6 +743,7 @@ def forward(
WHISPER_ATTENTION_CLASSES = {
"eager": WhisperAttention,
"flash_attention_2": WhisperFlashAttention2,
+ "flash_attention_3": WhisperFlashAttention3,
"sdpa": WhisperSdpaAttention,
}
@@ -823,6 +943,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
@@ -1166,6 +1287,7 @@ def __init__(self, config: WhisperConfig):
[WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)]
)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layer_norm = nn.LayerNorm(config.d_model)
@@ -1428,7 +1550,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
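
The TODO in the Whisper forward above refers to a pure layout change: the KV cache stores tensors as `[batch, num_heads, seq_len, head_dim]`, while the flash-attention entry points expect `[batch, seq_len, num_heads, head_dim]`, hence the `transpose(1, 2)` before the kernel call. A minimal illustration:

```py
import torch

bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64
key_cache = torch.randn(bsz, num_heads, seq_len, head_dim)  # cache layout
key_states = key_cache.transpose(1, 2)                      # flash-attention layout
print(key_states.shape)  # torch.Size([2, 16, 8, 64])
```

Note that, as the guard at the top of the forward shows, Whisper rejects `flash_attention_3` together with `StaticCache`, matching the existing FA2 behaviour.
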
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 3306f76249fe9f..1b20c2355557ba 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -71,6 +71,7 @@
is_faiss_available,
is_fbgemm_gpu_available,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flax_available,
is_fsdp_available,
is_ftfy_available,
@@ -519,6 +520,16 @@ def require_flash_attn(test_case):
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
+def require_flash_attn_3(test_case):
+ """
+ Decorator marking a test that requires Flash Attention 3.
+
+ These tests are skipped when Flash Attention 3 isn't installed.
+
+ """
+ return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case)
+
+
def require_torch_sdpa(test_case):
"""
Decorator marking a test that requires PyTorch's SDPA.
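
A hedged sketch of how the new decorator is intended to be used, mirroring the model tests later in this diff: combine it with the GPU requirement and the `flash_attn_3_test` pytest marker that the CI workflow selects via `-m` (the test class and body here are illustrative):

```py
import unittest

import pytest
import torch

from transformers.testing_utils import require_flash_attn_3, require_torch_gpu, slow, torch_device


class MyModelFlashAttention3Test(unittest.TestCase):
    @require_flash_attn_3
    @require_torch_gpu
    @pytest.mark.flash_attn_3_test
    @slow
    def test_flash_attn_3_smoke(self):
        # Skipped automatically when flash_attn_interface or a CUDA device is missing.
        x = torch.ones(2, 2, device=torch_device)
        self.assertTrue(torch.allclose(x, x))
```
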
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index eee350349f5565..92a9b1fb675623 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -130,6 +130,7 @@
is_faiss_available,
is_fbgemm_gpu_available,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_flax_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index ad8b649aaa4e84..4f4e70b122be6f 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -62,6 +62,14 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
except ImportError:
# If the package can't be imported, it's not available
package_exists = False
+ elif pkg_name == "flash_attn_interface":
+ try:
+ package = importlib.import_module(pkg_name)
+ package_version = getattr(package, "__version__", "N/A")
+ package_exists = True
+ except ImportError:
+ # If the package can't be imported, it's not available
+ package_exists = False
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False
@@ -899,6 +907,16 @@ def is_flash_attn_greater_or_equal(library_version: str):
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
+def is_flash_attn_3_available():
+ if not is_flash_attn_2_available():
+ return False
+
+ if not _is_package_available("flash_attn_interface"):
+ return False
+
+ return True
+
+
def is_torchdistx_available():
return _torchdistx_available
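
With `is_flash_attn_3_available` in place, the modeling files shown earlier all guard the kernel import the same way, so the modules still import cleanly on machines without the FA3 package. The pattern in isolation (the `None` fallback is an illustrative choice; `modeling_flash_attention_3_utils` is the module those files import from):

```py
from transformers.utils import is_flash_attn_3_available

if is_flash_attn_3_available():
    from transformers.modeling_flash_attention_3_utils import _flash_attention_3_forward
else:
    _flash_attention_3_forward = None  # callers are expected to check availability first
```
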
diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index 9bb8ef33d75998..9c6064cf55a294 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -35,6 +35,7 @@
)
from transformers.testing_utils import (
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_fp16,
require_torch_gpu,
@@ -981,6 +982,63 @@ def test_flash_attn_2_inference_equivalence(self):
model.train()
_ = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(reason="Model does not support flash_attention_3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ dummy_input = inputs_dict["input_ids"][:1]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ dummy_attention_mask = dummy_attention_mask[:1]
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+
+ outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+
+ logits = outputs.hidden_states[-1]
+ logits_fa = outputs_fa.hidden_states[-1]
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ other_inputs = {"output_hidden_states": True}
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+ outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+
+ logits = outputs.hidden_states[-1]
+ logits_fa = outputs_fa.hidden_states[-1]
+
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -1039,6 +1097,64 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(reason="Model does not support flash_attention_3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.bfloat16,
+ )
+ model.to(torch_device)
+
+ dummy_input = inputs_dict["input_ids"][:1]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ dummy_attention_mask = dummy_attention_mask[:1]
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+
+ logits = outputs.hidden_states[-1]
+ logits_fa = outputs_fa.hidden_states[-1]
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+ outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+
+ logits = outputs.hidden_states[-1]
+ logits_fa = outputs_fa.hidden_states[-1]
+
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
@require_torch
class BarkModelIntegrationTests(unittest.TestCase):
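
The two Bark equivalence tests above follow the same recipe used for the other models in this diff: take a batch of size one, mask out either the first or the last position, and compare eager vs. FA3 hidden states on the unpadded positions with a loose bfloat16 tolerance. A condensed sketch of those ingredients (the tensors are random placeholders, not model outputs):

```py
import torch

seq_len, hidden = 8, 16

# padding patterns exercised by the two tests
left_pad_mask = torch.ones(1, seq_len, dtype=torch.long)
left_pad_mask[:, :1] = 0    # first position is padding
right_pad_mask = torch.ones(1, seq_len, dtype=torch.long)
right_pad_mask[:, -1:] = 0  # last position is padding

# tolerance used when comparing eager vs. flash-attention outputs in bfloat16
logits = torch.randn(1, seq_len, hidden, dtype=torch.bfloat16)
logits_fa = logits + 1e-3 * torch.randn_like(logits)
print(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))  # True
```
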
diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py
index 16e0a548e6dc47..6635c1f91e9aaf 100644
--- a/tests/models/chameleon/test_modeling_chameleon.py
+++ b/tests/models/chameleon/test_modeling_chameleon.py
@@ -24,6 +24,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_gpu,
@@ -366,6 +367,43 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertListEqual(output_native, output_fa_2)
+ @require_flash_attn_3
+ @require_read_token
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ """
+ Overwriting the common test as it is flaky on tiny models
+ """
+ model = ChameleonForConditionalGeneration.from_pretrained(
+ "facebook/chameleon-7b",
+ load_in_4bit=True,
+ device_map={"": 0},
+ )
+
+ processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
+ texts = ["hi", "Hello this is a very long sentence"]
+
+ processor.tokenizer.padding_side = "right"
+
+ inputs = processor(texts, return_tensors="pt", padding=True).to(0)
+
+ output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_native = processor.tokenizer.batch_decode(output_native)
+
+ model = ChameleonForConditionalGeneration.from_pretrained(
+ "facebook/chameleon-7b",
+ load_in_4bit=True,
+ attn_implementation="flash_attention_3",
+ )
+
+ output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_fa_3 = processor.tokenizer.batch_decode(output_fa_3)
+
+ self.assertListEqual(output_native, output_fa_3)
+
@unittest.skip("Chameleon forces some token ids to be -inf!")
def test_batching_equivalence(self):
pass
diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py
index 3b6994428088a2..a63a8236a087c3 100644
--- a/tests/models/clip/test_modeling_clip.py
+++ b/tests/models/clip/test_modeling_clip.py
@@ -31,6 +31,7 @@
is_flax_available,
is_pt_flax_cross_test,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_sdpa,
@@ -1019,6 +1020,45 @@ def test_flash_attn_2_inference_equivalence(self):
f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
+ dummy_input_ids = inputs_dict["input_ids"]
+
+ outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(
+ pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
+ )
+
+ self.assertTrue(
+ torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
+ f"Image logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
+ )
+ self.assertTrue(
+ torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
+ f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
+ )
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -1070,6 +1110,57 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
f"Text logits max diff: {torch.max(torch.abs(logits_per_text_eager - logits_per_text_sdpa))}",
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
+ )
+ model.to(torch_device)
+
+ dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
+ dummy_input_ids = inputs_dict["input_ids"]
+ dummy_pixel_mask = inputs_dict["attention_mask"]
+
+ # right padding
+ dummy_pixel_mask[:] = 1
+ dummy_pixel_mask[:, -1:] = 0
+
+ outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(
+ pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
+ )
+
+ logits_per_image_eager = outputs.logits_per_image[:, :-1]
+ logits_per_text_eager = outputs.logits_per_text[:, :-1]
+
+ logits_per_image_sdpa = outputs_fa.logits_per_image[:, :-1]
+ logits_per_text_sdpa = outputs_fa.logits_per_text[:, :-1]
+
+ self.assertTrue(
+ torch.allclose(logits_per_image_eager, logits_per_image_sdpa, atol=4e-2, rtol=4e-2),
+ f"Image logits max diff: {torch.max(torch.abs(logits_per_image_eager - logits_per_image_sdpa))}",
+ )
+ self.assertTrue(
+ torch.allclose(logits_per_text_eager, logits_per_text_sdpa, atol=4e-2, rtol=4e-2),
+ f"Text logits max diff: {torch.max(torch.abs(logits_per_text_eager - logits_per_text_sdpa))}",
+ )
+
class CLIPForImageClassificationModelTester(CLIPModelTester):
def __init__(self, parent):
diff --git a/tests/models/distilbert/test_modeling_distilbert.py b/tests/models/distilbert/test_modeling_distilbert.py
index 3a74a1557cf9ba..00c5eb2a0e1ec0 100644
--- a/tests/models/distilbert/test_modeling_distilbert.py
+++ b/tests/models/distilbert/test_modeling_distilbert.py
@@ -19,7 +19,14 @@
import pytest
from transformers import DistilBertConfig, is_torch_available
-from transformers.testing_utils import require_flash_attn, require_torch, require_torch_accelerator, slow, torch_device
+from transformers.testing_utils import (
+ require_flash_attn,
+ require_flash_attn_3,
+ require_torch,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -348,6 +355,58 @@ def test_flash_attn_2_inference_equivalence(self):
self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
+ # Because DistilBertForMultipleChoice requires inputs with different shapes, we need to override this test.
+ @require_flash_attn_3
+ @require_torch_accelerator
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence(self):
+ import torch
+
+ for model_class in self.all_model_classes:
+ dummy_input = torch.LongTensor(
+ [
+ [1, 2, 3, 4],
+ [1, 2, 8, 9],
+ [1, 2, 11, 12],
+ [1, 2, 13, 14],
+ ]
+ ).to(torch_device)
+ dummy_attention_mask = torch.LongTensor(
+ [
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ ]
+ ).to(torch_device)
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
+ logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
+
+ self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
+
+ output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
+ logits_fa = output_fa.hidden_states[-1]
+
+ output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
+ logits = output.hidden_states[-1]
+
+ self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
+
# Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
@require_flash_attn
@require_torch_accelerator
@@ -403,6 +462,61 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2))
+ # Because DistilBertForMultipleChoice requires inputs with different shapes, we need to override this test.
+ @require_flash_attn_3
+ @require_torch_accelerator
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ import torch
+
+ for model_class in self.all_model_classes:
+ dummy_input = torch.LongTensor(
+ [
+ [1, 2, 3, 4],
+ [1, 2, 8, 9],
+ [1, 2, 11, 12],
+ [1, 2, 13, 14],
+ ]
+ ).to(torch_device)
+ dummy_attention_mask = torch.LongTensor(
+ [
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ ]
+ ).to(torch_device)
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.bfloat16,
+ )
+ model.to(torch_device)
+
+ logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
+ logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
+
+ self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
+
+ output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
+ logits_fa = output_fa.hidden_states[-1]
+
+ output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
+ logits = output.hidden_states[-1]
+
+ self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2))
+
@require_torch
class DistilBertModelIntergrationTest(unittest.TestCase):
diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py
index b564d51216d29d..f7dfe7acf7a360 100644
--- a/tests/models/gemma/test_modeling_gemma.py
+++ b/tests/models/gemma/test_modeling_gemma.py
@@ -25,6 +25,7 @@
is_flaky,
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_accelerator,
@@ -453,6 +454,51 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Gemma apparently does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -460,6 +506,13 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="Gemma flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest(reason="Gemma flash attention 3 does not support right padding")
+
@require_torch_sdpa
@require_torch_accelerator
@slow
@@ -526,6 +579,40 @@ def test_flash_attn_2_equivalence(self):
# gemma flash attention 2 needs a high tolerance
assert torch.allclose(logits_fa, logits, atol=3e-3)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @is_flaky()
+ @slow
+ def test_flash_attn_3_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(reason="Model does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
+ model.to(torch_device)
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ dummy_input = dummy_input.to(torch_device)
+ outputs = model(dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = outputs.hidden_states[-1]
+ logits_fa = outputs_fa.hidden_states[-1]
+
+ # gemma flash attention 3 needs a high tolerance
+ assert torch.allclose(logits_fa, logits, atol=3e-3)
+
@slow
@require_torch_gpu
@@ -652,6 +739,29 @@ def test_model_2b_flash_attn(self):
self.assertEqual(output_text, EXPECTED_TEXTS)
+ @require_flash_attn_3
+ @require_read_token
+ @pytest.mark.flash_attn_3_test
+ def test_model_2b_flash_attn_fa3(self):
+ model_id = "google/gemma-2b"
+ EXPECTED_TEXTS = [
+ "Hello I am doing a project on the 1990s and I need to know what the most popular music",
+ "Hi today I am going to share with you a very easy and simple recipe of Kaju Kat",
+ ]
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model.to(torch_device)
+
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+
+ output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+
+ self.assertEqual(output_text, EXPECTED_TEXTS)
+
@require_bitsandbytes
@require_read_token
def test_model_2b_4bit(self):
diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index 918ed847f83d9e..22b308300e5e48 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -22,6 +22,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline
from transformers.testing_utils import (
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_gpu,
@@ -301,3 +302,27 @@ def test_model_9b_flash_attn(self):
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
self.assertEqual(output_text, EXPECTED_TEXTS)
+
+ @require_read_token
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_model_9b_flash_attn_3(self):
+ # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
+ model_id = "google/gemma-2-9b"
+ EXPECTED_TEXTS = [
+ 'Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
+ "Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the"
+ ] # fmt: skip
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id, attn_implementation="flash_attention_3", torch_dtype="float16"
+ ).to(torch_device)
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+
+ output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+ output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
+
+ self.assertEqual(output_text, EXPECTED_TEXTS)
diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py
index 3f96c20ab2dbd9..30d3c4d1b9bab3 100644
--- a/tests/models/gpt2/test_modeling_gpt2.py
+++ b/tests/models/gpt2/test_modeling_gpt2.py
@@ -25,6 +25,7 @@
from transformers.testing_utils import (
backend_empty_cache,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
slow,
@@ -939,3 +940,40 @@ def test_flash_attn_2_generate_padding_left(self):
self.assertListEqual(output_native, output_fa_2)
self.assertListEqual(output_native, expected_output)
+
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_left(self):
+ """
+ Overwriting the common test as it is flaky on tiny models
+ """
+ model = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch.float16).to(0)
+
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+ texts = ["hi", "Hello this is a very long sentence"]
+
+ tokenizer.padding_side = "left"
+ tokenizer.pad_token = tokenizer.eos_token
+
+ inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
+
+ output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_native = tokenizer.batch_decode(output_native)
+
+ model = GPT2LMHeadModel.from_pretrained(
+ "gpt2", device_map={"": 0}, attn_implementation="flash_attention_3", torch_dtype=torch.float16
+ )
+
+ output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_fa_3 = tokenizer.batch_decode(output_fa_3)
+
+ expected_output = [
+ "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>hi, who was born in the city of Kolkata, was a member of the Kolkata",
+ "Hello this is a very long sentence. I'm sorry. I'm sorry. I'm sorry. I'm sorry. I'm sorry",
+ ]
+
+ self.assertListEqual(output_native, output_fa_3)
+ self.assertListEqual(output_native, expected_output)
diff --git a/tests/models/gptj/test_modeling_gptj.py b/tests/models/gptj/test_modeling_gptj.py
index 2ef2e391215e7b..02fbb3703d8403 100644
--- a/tests/models/gptj/test_modeling_gptj.py
+++ b/tests/models/gptj/test_modeling_gptj.py
@@ -23,6 +23,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
slow,
@@ -565,6 +566,44 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertListEqual(expected_outputs, output_fa_2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ """
+ Overwriting the common test as the test is flaky on tiny models
+ """
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
+
+ texts = ["hi", "Hello this is a very long sentence"]
+ expected_outputs = [
+ "hi<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>Q: I have a question about the new version of the game. I have a question about the",
+ "Hello this is a very long sentence.\n\nA:\n\nI think the best way to understand this is to think of it",
+ ]
+
+ tokenizer.padding_side = "right"
+ tokenizer.pad_token = tokenizer.eos_token
+
+ inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
+
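+ # Quantize to 4-bit so the 6B checkpoint fits on a single GPU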
+ quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+
+ model = GPTJForCausalLM.from_pretrained(
+ "EleutherAI/gpt-j-6b",
+ device_map={"": 0},
+ attn_implementation="flash_attention_3",
+ revision="float16",
+ torch_dtype=torch.float16,
+ quantization_config=quantization_config,
+ )
+
+ output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_fa_3 = tokenizer.batch_decode(output_fa_3)
+
+ self.assertListEqual(expected_outputs, output_fa_3)
+
@require_torch
class GPTJModelLanguageGenerationTest(unittest.TestCase):
diff --git a/tests/models/granite/test_modeling_granite.py b/tests/models/granite/test_modeling_granite.py
index 8771cd50978a7f..933e3b798c7232 100644
--- a/tests/models/granite/test_modeling_granite.py
+++ b/tests/models/granite/test_modeling_granite.py
@@ -24,6 +24,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_gpu,
@@ -518,6 +519,46 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertListEqual(output_native, output_fa_2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_3_test
+ @require_read_token
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ """
+ Overwriting the common test as the test is flaky on tiny models
+ """
+ model = GraniteForCausalLM.from_pretrained(
+ "ibm/PowerLM-3b",
+ load_in_4bit=True,
+ device_map={"": 0},
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained("ibm/PowerLM-3b")
+
+ texts = ["hi", "Hello this is a very long sentence"]
+
+ tokenizer.padding_side = "right"
+ tokenizer.pad_token = tokenizer.eos_token
+
+ inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
+
+ output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_native = tokenizer.batch_decode(output_native)
+
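+ # Reload the same checkpoint with FlashAttention-3 and check that generations match the default attention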
+ model = GraniteForCausalLM.from_pretrained(
+ "ibm/PowerLM-3b",
+ load_in_4bit=True,
+ device_map={"": 0},
+ attn_implementation="flash_attention_3",
+ )
+
+ output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_fa_3 = tokenizer.batch_decode(output_fa_3)
+
+ self.assertListEqual(output_native, output_fa_3)
+
@require_flash_attn
@require_torch_gpu
@slow
diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py
index e02c5b4c9f09c6..51f8604aad17b2 100644
--- a/tests/models/idefics2/test_modeling_idefics2.py
+++ b/tests/models/idefics2/test_modeling_idefics2.py
@@ -32,6 +32,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
@@ -608,3 +609,37 @@ def test_flash_attn_2_eager_equivalence(self):
)
self.assertEqual(generated_texts_eager[0], generated_texts_flash_attention_2[0])
+
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ def test_flash_attn_3_eager_equivalence(self):
+ # Create inputs
+ text = "In this image, we see"
+ images = self.image1
+ inputs = self.processor(text=text, images=images, return_tensors="pt", padding=True)
+ inputs.to(torch_device)
+
+ # Eager model
+ model_eager = Idefics2ForConditionalGeneration.from_pretrained(
+ "HuggingFaceM4/idefics2-8b-base",
+ attn_implementation="eager",
+ load_in_4bit=True,
+ )
+ generated_ids_eager = model_eager.generate(**inputs, max_new_tokens=10)
+ generated_texts_eager = self.processor.batch_decode(generated_ids_eager, skip_special_tokens=True)
+
+ del model_eager
+
+ # Flash Attention 3 model
+ model_flash_attention_3 = Idefics2ForConditionalGeneration.from_pretrained(
+ "HuggingFaceM4/idefics2-8b-base",
+ attn_implementation="flash_attention_3",
+ load_in_4bit=True,
+ )
+ generated_ids_flash_attention_3 = model_flash_attention_3.generate(**inputs, max_new_tokens=10)
+ generated_texts_flash_attention_3 = self.processor.batch_decode(
+ generated_ids_flash_attention_3, skip_special_tokens=True
+ )
+
+ self.assertEqual(generated_texts_eager[0], generated_texts_flash_attention_3[0])
diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py
index 6e1a2cf2cf9c44..28fa0e8749c827 100644
--- a/tests/models/jamba/test_modeling_jamba.py
+++ b/tests/models/jamba/test_modeling_jamba.py
@@ -25,6 +25,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
slow,
@@ -539,6 +540,45 @@ def test_flash_attn_2_fp32_ln(self):
# with attention mask
_ = model(dummy_input, attention_mask=dummy_attention_mask)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_fp32_ln(self):
+ r"""
+ Overriding the test_flash_attn_3_fp32_ln test as the Jamba model, like Mixtral, doesn't support
+ right padding + use cache with FA3
+ """
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_input = inputs_dict[model.main_input_name]
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Jamba does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ load_in_4bit=True,
+ )
+
+ for _, param in model.named_parameters():
+ # upcast only layer norms
+ if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+ param.data = param.data.to(torch.float32)
+
+ _ = model(dummy_input)
+ # with attention mask
+ _ = model(dummy_input, attention_mask=dummy_attention_mask)
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -577,6 +617,44 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ r"""
+ Overriding the test_flash_attn_3_generate_padding_right test as the Jamba model, like Mixtral, doesn't support
+ right padding + use cache with FA3
+ """
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
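+ # Right padding with use_cache is unsupported with FA3, so generation is expected to raise a ValueError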
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -626,6 +704,55 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ r"""
+ Overriding the test_flash_attn_3_generate_use_cache test as the Jamba model, like Mixtral, doesn't support
+ right padding + use cache with FA3
+ """
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Jamba does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -637,6 +764,17 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
"""
self.skipTest(reason="Jamba flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ r"""
+ Overriding the test_flash_attn_3_inference_equivalence_right_padding test as the Jamba model, like Mixtral, doesn't support
+ right padding + use cache with FA3
+ """
+ self.skipTest(reason="Jamba flash attention does not support right padding")
+
@unittest.skip(reason="Jamba has its own special cache type")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
diff --git a/tests/models/jetmoe/test_modeling_jetmoe.py b/tests/models/jetmoe/test_modeling_jetmoe.py
index 50fd7a27e1e6d1..238d2d7219312f 100644
--- a/tests/models/jetmoe/test_modeling_jetmoe.py
+++ b/tests/models/jetmoe/test_modeling_jetmoe.py
@@ -26,6 +26,7 @@
backend_empty_cache,
is_flaky,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_sdpa,
@@ -420,6 +421,40 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -465,6 +500,51 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: JetMoe apparently does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -472,6 +552,13 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="JetMoe flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest(reason="JetMoe flash attention does not support right padding")
+
@require_torch
class JetMoeIntegrationTest(unittest.TestCase):
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index c99357ff99b257..f03352b10d4f09 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -26,6 +26,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_gpu,
@@ -617,6 +618,43 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertListEqual(output_native, output_fa_2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_3_test
+ @require_read_token
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ """
+ Overwriting the common test as the test is flaky on tiny models
+ """
+ model = LlamaForCausalLM.from_pretrained(
+ "meta-llama/Llama-2-7b-hf",
+ load_in_4bit=True,
+ device_map={"": 0},
+ )
+
+ tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+
+ texts = ["hi", "Hello this is a very long sentence"]
+
+ tokenizer.padding_side = "right"
+ tokenizer.pad_token = tokenizer.eos_token
+
+ inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
+
+ output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_native = tokenizer.batch_decode(output_native)
+
+ model = LlamaForCausalLM.from_pretrained(
+ "meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_3"
+ )
+
+ output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_fa_3 = tokenizer.batch_decode(output_fa_3)
+
+ self.assertListEqual(output_native, output_fa_3)
+
@require_flash_attn
@require_torch_gpu
@slow
diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py
index a29a9c8a9ec0dc..35701f0c0eb9d8 100644
--- a/tests/models/m2m_100/test_modeling_m2m_100.py
+++ b/tests/models/m2m_100/test_modeling_m2m_100.py
@@ -23,6 +23,7 @@
from transformers import M2M100Config, is_torch_available
from transformers.testing_utils import (
require_flash_attn,
+ require_flash_attn_3,
require_sentencepiece,
require_tokenizers,
require_torch,
@@ -465,3 +466,48 @@ def test_flash_attn_2_seq_to_seq_generation(self):
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
)
assert generated == expected_en
+
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_seq_to_seq_generation(self):
+ """
+ Overwriting the common test as the test is flaky on tiny models
+ """
+ model = M2M100ForConditionalGeneration.from_pretrained(
+ "facebook/m2m100_418M", attn_implementation="flash_attention_3"
+ ).to(torch_device)
+
+ tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en")
+
+ src_fr = [
+ "L'affaire NSA souligne l'absence totale de débat sur le renseignement",
+ "Selon moi, il y a deux niveaux de réponse de la part du gouvernement français.",
+ "Lorsque François Hollande téléphone à Barack Obama ou quand le ministre des affaires étrangères Laurent"
+ " Fabius convoque l'ambassadeur des Etats-Unis, ils réagissent à une vraie découverte, qui est celle de"
+ " l'ampleur de la surveillance américaine sur l'ensemble des communications en France.",
+ ]
+
+ # The below article tests that we don't add any hypotheses outside of the top n_beams
+ dct = tokenizer(src_fr, padding=True, return_tensors="pt")
+
+ hypotheses_batch = model.generate(
+ input_ids=dct["input_ids"].to(torch_device),
+ attention_mask=dct["attention_mask"].to(torch_device),
+ num_beams=5,
+ forced_bos_token_id=tokenizer.get_lang_id("en"),
+ )
+
+ expected_en = [
+ "The NSA case highlights the total absence of intelligence debate",
+ "I think there are two levels of response from the French government.",
+ "When François Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S."
+ " Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all"
+ " communications in France.",
+ ]
+
+ generated = tokenizer.batch_decode(
+ hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
+ )
+ assert generated == expected_en
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index 0da7ae72add7bd..c1d088da723369 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -27,6 +27,7 @@
is_flaky,
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_gpu,
@@ -439,6 +440,40 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -484,6 +519,51 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Mistral apparently does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -491,6 +571,13 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="Mistral flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest(reason="Mistral flash attention does not support right padding")
+
@require_torch_gpu
class MistralIntegrationTest(unittest.TestCase):
@@ -601,6 +688,31 @@ def test_model_7b_long_prompt(self):
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+ @require_flash_attn_3
+ @require_bitsandbytes
+ @slow
+ @pytest.mark.flash_attn_3_test
+ def test_model_7b_long_prompt_fa3(self):
+ EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
+ # An input with 4097 tokens that is above the size of the sliding window
+ input_ids = [1] + [306, 338] * 2048
+ model = MistralForCausalLM.from_pretrained(
+ "mistralai/Mistral-7B-v0.1",
+ device_map={"": torch_device},
+ load_in_4bit=True,
+ attn_implementation="flash_attention_3",
+ )
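+ # Move the inputs to the device holding the embedding weights (the model is dispatched via device_map)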
+ input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+ generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+ self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
+ # Assisted generation
+ assistant_model = model
+ assistant_model.generation_config.num_assistant_tokens = 2
+ assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
+ generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+ self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
@slow
@require_torch_sdpa
def test_model_7b_long_prompt_sdpa(self):
diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py
index db9641e3dcb2a9..6e074dedbe134e 100644
--- a/tests/models/mixtral/test_modeling_mixtral.py
+++ b/tests/models/mixtral/test_modeling_mixtral.py
@@ -23,6 +23,7 @@
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_sdpa,
@@ -440,6 +441,40 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -485,6 +520,51 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Mixtral apparently does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -492,6 +572,13 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="Mixtral flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest(reason="Mixtral flash attention does not support right padding")
+
# Ignore copy
def test_load_balancing_loss(self):
r"""
diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py
index 870a4c92767b9b..71a27d248aa004 100644
--- a/tests/models/musicgen/test_modeling_musicgen.py
+++ b/tests/models/musicgen/test_modeling_musicgen.py
@@ -35,6 +35,7 @@
from transformers.testing_utils import (
is_torch_available,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_accelerator,
require_torch_fp16,
@@ -396,6 +397,86 @@ def test_flash_attn_2_inference_equivalence(self):
model.train()
_ = model_fa(dummy_input, **other_inputs)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence
+ def test_flash_attn_3_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+
+ # Ignore copy
+ outputs = model(dummy_input, output_hidden_states=True)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ # Ignore copy
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(dummy_input, **other_inputs)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -475,6 +556,85 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence_right_padding
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ if model.config.is_encoder_decoder:
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
+
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ else:
+ outputs = model(dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+ # Ignore copy
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -521,6 +681,52 @@ def test_flash_attn_2_generate_left_padding(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_left_padding
+ def test_flash_attn_3_generate_left_padding(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # make sure we do left padding
+ dummy_attention_mask[:, :-1] = 0
+ dummy_attention_mask[:, -1:] = 1
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -567,6 +773,52 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_padding_right
+ def test_flash_attn_3_generate_padding_right(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # make sure we do right padding
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -613,6 +865,52 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_use_cache
+ def test_flash_attn_3_generate_use_cache(self):
+ max_new_tokens = 30
+
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@@ -1591,36 +1889,121 @@ def test_greedy_generate_stereo_outputs(self):
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.audio_channels = 2
- model = model_class(config).to(torch_device).eval()
- output_generate = self._greedy_generate(
- model=model,
- input_ids=input_ids.to(torch_device),
- attention_mask=attention_mask.to(torch_device),
- output_scores=True,
- output_hidden_states=True,
- output_attentions=True,
- return_dict_in_generate=True,
- )
+ model = model_class(config).to(torch_device).eval()
+ output_generate = self._greedy_generate(
+ model=model,
+ input_ids=input_ids.to(torch_device),
+ attention_mask=attention_mask.to(torch_device),
+ output_scores=True,
+ output_hidden_states=True,
+ output_attentions=True,
+ return_dict_in_generate=True,
+ )
+
+ self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
+
+ self.assertNotIn(config.pad_token_id, output_generate)
+
+ @unittest.skip(
+ reason="MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model"
+ )
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ @require_flash_attn
+ @require_torch_gpu
+ @mark.flash_attn_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
+ def test_flash_attn_2_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_2:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+
+ # Ignore copy
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
+ # Ignore copy
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+ # Ignore copy
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ # Ignore copy
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+ # Ignore copy
+ outputs = model(dummy_input, **other_inputs)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, **other_inputs)
- self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
- self.assertNotIn(config.pad_token_id, output_generate)
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
- @unittest.skip(
- reason="MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model"
- )
- def test_save_load_fast_init_from_base(self):
- pass
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(dummy_input, **other_inputs)
- @require_flash_attn
+ @require_flash_attn_3
@require_torch_gpu
- @mark.flash_attn_test
+ @mark.flash_attn_3_test
@slow
- # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
- def test_flash_attn_2_inference_equivalence(self):
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence
+ def test_flash_attn_3_inference_equivalence(self):
for model_class in self.all_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -1628,7 +2011,7 @@ def test_flash_attn_2_inference_equivalence(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
)
model_fa.to(torch_device)
@@ -1779,6 +2162,88 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence_right_padding
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ # Ignore copy
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
+ # Ignore copy
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ # Ignore copy
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ # Ignore copy
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+ # Ignore copy
+ outputs = model(dummy_input, **other_inputs)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -1828,6 +2293,55 @@ def test_flash_attn_2_generate_left_padding(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_left_padding
+ def test_flash_attn_3_generate_left_padding(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask")
+ if dummy_attention_mask is None:
+ dummy_attention_mask = torch.ones_like(dummy_input)
+
+ # make sure we do left padding
+ dummy_attention_mask[:, :-1] = 0
+ dummy_attention_mask[:, -1:] = 1
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -1876,6 +2390,54 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_padding_right
+ def test_flash_attn_3_generate_padding_right(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask")
+ if dummy_attention_mask is None:
+ dummy_attention_mask = torch.ones_like(dummy_input)
+ # make sure we do right padding
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -1922,6 +2484,52 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_use_cache
+ def test_flash_attn_3_generate_use_cache(self):
+ max_new_tokens = 30
+
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
index 9b34f4dde6594f..a88db2a7578924 100644
--- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
+++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
@@ -35,6 +35,7 @@
is_torch_available,
is_torchaudio_available,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_accelerator,
require_torch_fp16,
@@ -398,6 +399,86 @@ def test_flash_attn_2_inference_equivalence(self):
model.train()
_ = model_fa(dummy_input, **other_inputs)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_3_inference_equivalence
+ def test_flash_attn_3_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+
+ # Ignore copy
+ outputs = model(dummy_input, output_hidden_states=True)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ # Ignore copy
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(dummy_input, **other_inputs)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -477,6 +558,85 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_3_inference_equivalence_right_padding
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ if model.config.is_encoder_decoder:
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
+
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ else:
+ outputs = model(dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+ # Ignore copy
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -523,6 +683,52 @@ def test_flash_attn_2_generate_left_padding(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_left_padding
+ def test_flash_attn_3_generate_left_padding(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # make sure we do left padding
+ dummy_attention_mask[:, :-1] = 0
+ dummy_attention_mask[:, -1:] = 1
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -569,6 +775,52 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_padding_right
+ def test_flash_attn_3_generate_padding_right(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # make sure we do right padding
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -615,6 +867,52 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_3_generate_use_cache
+ def test_flash_attn_3_generate_use_cache(self):
+ max_new_tokens = 30
+
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@@ -1575,36 +1873,121 @@ def test_greedy_generate_stereo_outputs(self):
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.audio_channels = 2
- model = model_class(config).to(torch_device).eval()
- output_generate = self._greedy_generate(
- model=model,
- input_ids=input_ids.to(torch_device),
- attention_mask=attention_mask.to(torch_device),
- output_scores=True,
- output_hidden_states=True,
- output_attentions=True,
- return_dict_in_generate=True,
- )
+ model = model_class(config).to(torch_device).eval()
+ output_generate = self._greedy_generate(
+ model=model,
+ input_ids=input_ids.to(torch_device),
+ attention_mask=attention_mask.to(torch_device),
+ output_scores=True,
+ output_hidden_states=True,
+ output_attentions=True,
+ return_dict_in_generate=True,
+ )
+
+ self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
+
+ self.assertNotIn(config.pad_token_id, output_generate)
+
+ @unittest.skip(
+ reason="MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composit model"
+ )
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ @require_flash_attn
+ @require_torch_gpu
+ @mark.flash_attn_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
+ def test_flash_attn_2_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_2:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+
+ # Ignore copy
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
+ # Ignore copy
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+ # Ignore copy
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ # Ignore copy
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+ # Ignore copy
+ outputs = model(dummy_input, **other_inputs)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, **other_inputs)
- self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
- self.assertNotIn(config.pad_token_id, output_generate)
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
- @unittest.skip(
- reason="MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composit model"
- )
- def test_save_load_fast_init_from_base(self):
- pass
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(dummy_input, **other_inputs)
- @require_flash_attn
+ @require_flash_attn_3
@require_torch_gpu
- @mark.flash_attn_test
+ @mark.flash_attn_3_test
@slow
- # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
- def test_flash_attn_2_inference_equivalence(self):
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence
+ def test_flash_attn_3_inference_equivalence(self):
for model_class in self.all_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -1612,7 +1995,7 @@ def test_flash_attn_2_inference_equivalence(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
)
model_fa.to(torch_device)
@@ -1763,6 +2146,88 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence_right_padding
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ # Ignore copy
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
+ # Ignore copy
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ # Ignore copy
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ # Ignore copy
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+ # Ignore copy
+ outputs = model(dummy_input, **other_inputs)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -1812,6 +2277,55 @@ def test_flash_attn_2_generate_left_padding(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_left_padding
+ def test_flash_attn_3_generate_left_padding(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask")
+ if dummy_attention_mask is None:
+ dummy_attention_mask = torch.ones_like(dummy_input)
+
+ # make sure we do left padding
+ dummy_attention_mask[:, :-1] = 0
+ dummy_attention_mask[:, -1:] = 1
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -1860,6 +2374,54 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_padding_right
+ def test_flash_attn_3_generate_padding_right(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask")
+ if dummy_attention_mask is None:
+ dummy_attention_mask = torch.ones_like(dummy_input)
+ # make sure we do right padding
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -1906,6 +2468,52 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_use_cache
+ def test_flash_attn_3_generate_use_cache(self):
+ max_new_tokens = 30
+
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
diff --git a/tests/models/nemotron/test_modeling_nemotron.py b/tests/models/nemotron/test_modeling_nemotron.py
index 4f8f4cc77fe8d0..450939bd0d36a1 100644
--- a/tests/models/nemotron/test_modeling_nemotron.py
+++ b/tests/models/nemotron/test_modeling_nemotron.py
@@ -25,6 +25,7 @@
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_gpu,
@@ -178,6 +179,40 @@ def test_flash_attn_2_equivalence(self):
# nemotron flash attention 2 needs a high tolerance
assert torch.allclose(logits_fa, logits, atol=1e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @is_flaky()
+ @slow
+ def test_flash_attn_3_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(reason="Model does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
+ model.to(torch_device)
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ dummy_input = dummy_input.to(torch_device)
+ outputs = model(dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = outputs.hidden_states[-1]
+ logits_fa = outputs_fa.hidden_states[-1]
+
+ # nemotron flash attention 3 needs a high tolerance
+ assert torch.allclose(logits_fa, logits, atol=1e-2)
+
@require_torch_gpu
class NemotronIntegrationTest(unittest.TestCase):
diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py
index 95b0b01c0a23d9..fca9c563361edd 100644
--- a/tests/models/phi/test_modeling_phi.py
+++ b/tests/models/phi/test_modeling_phi.py
@@ -24,6 +24,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
slow,
@@ -498,6 +499,43 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertListEqual(output_native, output_fa_2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_3_test
+ @slow
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_3_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->microsoft/phi-1
+ def test_flash_attn_3_generate_padding_right(self):
+ """
+ Overwriting the common test as the test is flaky on tiny models
+ """
+ model = PhiForCausalLM.from_pretrained(
+ "microsoft/phi-1",
+ load_in_4bit=True,
+ device_map={"": 0},
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
+
+ texts = ["hi", "Hello this is a very long sentence"]
+
+ tokenizer.padding_side = "right"
+ tokenizer.pad_token = tokenizer.eos_token
+
+ inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
+
+ output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_native = tokenizer.batch_decode(output_native)
+
+ model = PhiForCausalLM.from_pretrained(
+ "microsoft/phi-1", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_3"
+ )
+
+ output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_fa_3 = tokenizer.batch_decode(output_fa_3)
+
+ self.assertListEqual(output_native, output_fa_3)
+
@slow
@require_torch
diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py
index 4d6c432f20424d..59745222cec592 100644
--- a/tests/models/qwen2/test_modeling_qwen2.py
+++ b/tests/models/qwen2/test_modeling_qwen2.py
@@ -25,6 +25,7 @@
backend_empty_cache,
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_sdpa,
@@ -450,6 +451,40 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
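+ # right padding together with flash_attention_3 is expected to raise a ValueError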
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -495,6 +530,51 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Qwen2 apparently does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -502,6 +582,13 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="Qwen2 flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest(reason="Qwen2 flash attention does not support right padding")
+
@require_torch
class Qwen2IntegrationTest(unittest.TestCase):
@@ -571,6 +658,36 @@ def test_model_450m_long_prompt(self):
backend_empty_cache(torch_device)
gc.collect()
+ @require_bitsandbytes
+ @slow
+ @require_flash_attn_3
+ @pytest.mark.flash_attn_3_test
+ def test_model_450m_long_prompt_fav3(self):
+ EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
+ # An input with 4097 tokens that is above the size of the sliding window
+ input_ids = [1] + [306, 338] * 2048
+ model = Qwen2ForCausalLM.from_pretrained(
+ "Qwen/Qwen2-450m-beta",
+ device_map="auto",
+ load_in_4bit=True,
+ attn_implementation="flash_attention_3",
+ )
+ input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+ generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+ self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
+ # Assisted generation
+ assistant_model = model
+ assistant_model.generation_config.num_assistant_tokens = 2
+ assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
+ generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+ self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
+ del assistant_model
+ del model
+ backend_empty_cache(torch_device)
+ gc.collect()
+
@slow
@require_torch_sdpa
def test_model_450m_long_prompt_sdpa(self):
diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
index 0425172a6fba4d..15eb85b5bcc351 100644
--- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
+++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
@@ -25,6 +25,7 @@
backend_empty_cache,
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_sdpa,
@@ -475,6 +476,40 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
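+ # right padding together with flash_attention_3 is expected to raise a ValueError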
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -520,6 +555,51 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Qwen2Moe apparently does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -527,6 +607,13 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="Qwen2Moe flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest(reason="Qwen2Moe flash attention does not support right padding")
+
# Ignore copy
def test_load_balancing_loss(self):
r"""
@@ -633,6 +720,36 @@ def test_model_a2_7b_long_prompt(self):
backend_empty_cache(torch_device)
gc.collect()
+ @require_bitsandbytes
+ @slow
+ @require_flash_attn_3
+ @pytest.mark.flash_attn_3_test
+ def test_model_a2_7b_long_prompt_fav3(self):
+ EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
+ # An input with 4097 tokens that is above the size of the sliding window
+ input_ids = [1] + [306, 338] * 2048
+ model = Qwen2MoeForCausalLM.from_pretrained(
+ "Qwen/Qwen1.5-MoE-A2.7B",
+ device_map="auto",
+ load_in_4bit=True,
+ attn_implementation="flash_attention_3",
+ )
+ input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+ generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+ self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
+ # Assisted generation
+ assistant_model = model
+ assistant_model.generation_config.num_assistant_tokens = 2
+ assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
+ generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+ self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
+ del assistant_model
+ del model
+ backend_empty_cache(torch_device)
+ gc.collect()
+
@slow
@require_torch_sdpa
def test_model_a2_7b_long_prompt_sdpa(self):
diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
index 6b8072fc184f73..fa4750956d728b 100644
--- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
+++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
@@ -28,6 +28,7 @@
)
from transformers.testing_utils import (
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
slow,
@@ -475,6 +476,38 @@ def test_small_model_integration_test_batch_flashatt2(self):
self.processor.batch_decode(output, skip_special_tokens=True)[1],
)
+ @slow
+ @require_flash_attn_3
+ @require_torch_gpu
+ def test_small_model_integration_test_batch_flashatt3(self):
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
+ "Qwen/Qwen2-VL-7B-Instruct",
+ torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_3",
+ device_map="auto",
+ )
+ text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
+ inputs = self.processor(text=[text, text], images=[self.image, self.image], return_tensors="pt").to(
+ torch_device
+ )
+
+ # it should not matter whether two images are the same size or not
+ output = model.generate(**inputs, max_new_tokens=30)
+
+ EXPECTED_DECODED_TEXT = [
+ "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets",
+ "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets",
+ ]
+
+ self.assertEqual(
+ self.processor.batch_decode(output, skip_special_tokens=True),
+ EXPECTED_DECODED_TEXT,
+ )
+ self.assertEqual(
+ self.processor.batch_decode(output, skip_special_tokens=True)[0],
+ self.processor.batch_decode(output, skip_special_tokens=True)[1],
+ )
+
@slow
@require_flash_attn
@require_torch_gpu
@@ -507,3 +540,36 @@ def test_small_model_integration_test_batch_wo_image_flashatt2(self):
self.processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
+
+ @slow
+ @require_flash_attn_3
+ @require_torch_gpu
+ def test_small_model_integration_test_batch_wo_image_flashatt3(self):
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
+ "Qwen/Qwen2-VL-7B-Instruct",
+ torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_3",
+ device_map="auto",
+ )
+ text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
+ messages2 = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Who are you?"},
+ ]
+ text2 = self.processor.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
+ inputs = self.processor(text=[text, text2], images=[self.image], padding=True, return_tensors="pt").to(
+ torch_device
+ )
+
+ # generation should work even though only one of the two prompts has an image
+ output = model.generate(**inputs, max_new_tokens=30)
+
+ EXPECTED_DECODED_TEXT = [
+ "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets",
+ "system\nYou are a helpful assistant.\nuser\nWho are you?\nassistant\nI am Qwen, a large language model created by Alibaba Cloud. I am designed to answer a wide range of questions and provide information on various topics",
+ ]
+
+ self.assertEqual(
+ self.processor.batch_decode(output, skip_special_tokens=True),
+ EXPECTED_DECODED_TEXT,
+ )
diff --git a/tests/models/siglip/test_modeling_siglip.py b/tests/models/siglip/test_modeling_siglip.py
index 9d1e3109b313c3..a3c1cfcb66e40f 100644
--- a/tests/models/siglip/test_modeling_siglip.py
+++ b/tests/models/siglip/test_modeling_siglip.py
@@ -28,6 +28,7 @@
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from transformers.testing_utils import (
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_sdpa,
@@ -834,12 +835,95 @@ def test_flash_attn_2_inference_equivalence(self):
output_hidden_states=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
+ dummy_input_ids = inputs_dict["input_ids"]
+
+ outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(
+ pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
+ )
+
+ self.assertTrue(
+ torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
+ f"Image logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
+ )
+ self.assertTrue(
+ torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
+ f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
+ )
+
+ # Test with attention mask
+ dummy_attention_mask = inputs_dict["attention_mask"]
+
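+ # mask out the first token to emulate left padding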
+ if dummy_attention_mask is not None:
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+
+ outputs = model(
+ pixel_values=dummy_pixel_values,
+ input_ids=dummy_input_ids,
+ attention_mask=dummy_attention_mask,
+ output_hidden_states=True,
+ )
+ outputs_fa = model_fa(
+ pixel_values=dummy_pixel_values,
+ input_ids=dummy_input_ids,
+ attention_mask=dummy_attention_mask,
+ output_hidden_states=True,
+ )
+
+ self.assertTrue(
+ torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
+ f"Logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
+ )
+ self.assertTrue(
+ torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
+ f"Logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
+ )
+
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(
+ pixel_values=dummy_pixel_values,
+ input_ids=dummy_input_ids,
+ attention_mask=dummy_attention_mask,
+ output_hidden_states=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest("SigLIP does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest("SigLIP does not support right padding")
+
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py
index 36cad89bcfdf06..5437330bc96a59 100644
--- a/tests/models/stablelm/test_modeling_stablelm.py
+++ b/tests/models/stablelm/test_modeling_stablelm.py
@@ -24,6 +24,7 @@
is_flaky,
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_sdpa,
slow,
@@ -559,6 +560,24 @@ def test_model_3b_long_prompt(self):
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-3:].tolist())
+ @require_bitsandbytes
+ @slow
+ @require_flash_attn_3
+ @pytest.mark.flash_attn_3_test
+ def test_model_3b_long_prompt_fav3(self):
+ EXPECTED_OUTPUT_TOKEN_IDS = [3, 3, 3]
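+ # a 4094-token prompt to exercise long-prompt generation with FA3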
+ input_ids = [306, 338] * 2047
+ model = StableLmForCausalLM.from_pretrained(
+ "stabilityai/stablelm-3b-4e1t",
+ device_map="auto",
+ torch_dtype="auto",
+ load_in_4bit=True,
+ attn_implementation="flash_attention_3",
+ )
+ input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+ generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+ self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-3:].tolist())
+
# Copied from transformers.tests.models.llama.test_modeling_llama.LlamaModelTest.test_eager_matches_sdpa_generate with Llama->StableLm,saibo/llama-1B->stabilityai/stablelm-3b-4e1t
# TODO: @Fxmarty
@is_flaky(max_attempts=3, description="flaky on some models.")
diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py
index c1c7d45d4f18d7..04a5f657f833b0 100644
--- a/tests/models/starcoder2/test_modeling_starcoder2.py
+++ b/tests/models/starcoder2/test_modeling_starcoder2.py
@@ -23,6 +23,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
slow,
@@ -431,6 +432,40 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
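+ # right padding together with flash_attention_3 is expected to raise a ValueError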
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -476,6 +511,51 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Starcoder2 apparently does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -483,6 +563,13 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="Starcoder2 flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest(reason="Starcoder2 flash attention does not support right padding")
+
@slow
@require_torch_gpu
@@ -549,6 +636,28 @@ def test_starcoder2_batched_generation_fa2(self):
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT, output_text)
+ @require_flash_attn_3
+ @pytest.mark.flash_attn_3_test
+ def test_starcoder2_batched_generation_fa3(self):
+ EXPECTED_TEXT = [
+ "Hello my name is Younes and I am a student at the University of Liverpool. I am currently studying for my MSc in Computer Science. I am interested in the field of Machine Learning and I am currently working on",
+ "def hello_world():\n\treturn 'Hello World!'\n\n@app.route('/hello/')\ndef hello_name(name):\n\treturn 'Hello %s!' % name\n\n@app",
+ ]
+ model_id = "bigcode/starcoder2-7b"
+
+ model = Starcoder2ForCausalLM.from_pretrained(
+ model_id, torch_dtype=torch.float16, device_map="auto", attn_implementation="flash_attention_3"
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ text = ["Hello my name is Younes and", "def hello_world():"]
+ inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch_device)
+
+ output = model.generate(**inputs, max_new_tokens=40, do_sample=False)
+ output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+ self.assertEqual(EXPECTED_TEXT, output_text)
+
@require_bitsandbytes
def test_starcoder2_batched_generation_4bit(self):
EXPECTED_TEXT = [
diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py
index ff7a85218d3a00..1bce68348cc91b 100644
--- a/tests/models/wav2vec2/test_modeling_wav2vec2.py
+++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py
@@ -35,6 +35,7 @@
is_pyctcdecode_available,
is_torchaudio_available,
require_flash_attn,
+ require_flash_attn_3,
require_pyctcdecode,
require_soundfile,
require_torch,
@@ -2023,6 +2024,28 @@ def test_inference_ctc_fa2(self):
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ def test_inference_ctc_fa3(self):
+ model_fa = Wav2Vec2ForCTC.from_pretrained(
+ "facebook/wav2vec2-base-960h", attn_implementation="flash_attention_3", torch_dtype=torch.bfloat16
+ )
+ model_fa.to(torch_device)
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
+ input_speech = self._load_datasamples(1)
+
+ input_values = processor(input_speech, return_tensors="pt").input_values.to(torch_device)
+
+ with torch.no_grad():
+ logits = model_fa(input_values.to(torch.bfloat16)).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_trans = processor.batch_decode(predicted_ids)
+
+ EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
+ self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -2049,3 +2072,30 @@ def test_inference_ctc_fa2_batched(self):
"sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ def test_inference_ctc_fa3_batched(self):
+ model_fa = Wav2Vec2ForCTC.from_pretrained(
+ "facebook/wav2vec2-base-960h", attn_implementation="flash_attention_3", torch_dtype=torch.bfloat16
+ )
+ model_fa.to(torch_device)
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
+
+ input_speech = self._load_datasamples(2)
+
+ inputs = processor(input_speech, return_tensors="pt", padding=True, return_attention_mask=True)
+ inputs = inputs.to(torch_device)
+
+ with torch.no_grad():
+ logits = model_fa(inputs.input_values.to(torch.bfloat16), attention_mask=inputs.attention_mask).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_trans = processor.batch_decode(predicted_ids)
+
+ EXPECTED_TRANSCRIPTIONS = [
+ "a man said to the universe sir i exist",
+ "sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore",
+ ]
+ self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index f7ac2bc12eb252..ef28a11dd23b32 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -34,6 +34,7 @@
is_flaky,
is_pt_flax_cross_test,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_fp16,
require_torch_gpu,
@@ -987,6 +988,52 @@ def test_flash_attn_2_inference_equivalence(self):
model.train()
_ = model_fa(dummy_input, decoder_input_ids=decoder_input_ids)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence(self):
+ import torch
+
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(reason="Model does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.bfloat16,
+ )
+ model.to(torch_device)
+
+ dummy_input = inputs_dict[model.main_input_name][:1]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
+
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+
+ logits = outputs.decoder_hidden_states[-1]
+ logits_fa = outputs_fa.decoder_hidden_states[-1]
+
+ # whisper FA3 needs very high tolerance
+ assert torch.allclose(logits_fa, logits, atol=4e-1)
+
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(dummy_input, decoder_input_ids=decoder_input_ids)
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -1043,6 +1090,62 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
# whisper FA2 needs very high tolerance
assert torch.allclose(logits_fa[:, -2:], logits[:, -2:], atol=4e-1)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ import torch
+
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(reason="Model does not support flash_attention_3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16)
+ model.to(torch_device)
+
+ dummy_input = inputs_dict[model.main_input_name][:1]
+ dummy_input = dummy_input.to(torch.float16)
+
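+ # pad out the first three decoder positions so the equivalence check covers padded tokens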
+ decoder_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=dummy_input.device, dtype=torch.long)
+ decoder_attention_mask = torch.tensor(
+ [[0, 0, 0, 1, 1, 1]], device=dummy_input.device, dtype=torch.long
+ )
+
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+
+ logits = outputs.decoder_hidden_states[-1]
+ logits_fa = outputs_fa.decoder_hidden_states[-1]
+
+ # whisper FA3 needs very high tolerance
+ assert torch.allclose(logits_fa, logits, atol=4e-1)
+
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": decoder_attention_mask,
+ "output_hidden_states": True,
+ }
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = outputs.decoder_hidden_states[-1]
+ logits_fa = outputs_fa.decoder_hidden_states[-1]
+
+ # whisper FA3 needs very high tolerance
+ assert torch.allclose(logits_fa[:, -2:], logits[:, -2:], atol=4e-1)
+
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
self.skipTest(reason="test_torchscript is set to False")
@@ -1695,6 +1798,59 @@ def test_flash_attn_2_generate_reuse_cache(self):
past_key_values=past_key_values,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_reuse_cache(self):
+ max_new_tokens = 2
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name][..., :10]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # run generate once to get filled cache
+ output = model.generate(
+ dummy_input,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ return_dict_in_generate=True,
+ )
+ past_key_values = output.past_key_values
+
+                # Try to continue generation from where we left off, given that we have more than 1 new token to process
+                # e.g. this can happen in speculative decoding when feeding candidate tokens back to the target model
+ _ = model.generate(
+ dummy_input,
+ decoder_input_ids=output.sequences,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ past_key_values=past_key_values,
+ )
+
def test_labels_sequence_max_length_correct(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index c7af0b1c9f5b60..5e3f198872e7fe 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -76,6 +76,7 @@
require_accelerate,
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_safetensors,
require_torch,
@@ -3485,6 +3486,34 @@ def test_flash_attn_2_conversion(self):
self.assertTrue(False, "FlashAttention2 modules not found in model")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_conversion(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_3"
+ ).to(torch_device)
+
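+                # loading with attn_implementation="flash_attention_3" should have swapped in FlashAttention3 attention modules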
+ for _, module in model.named_modules():
+ if "FlashAttention" in module.__class__.__name__:
+ return
+
+ self.assertTrue(False, "FlashAttention3 modules not found in model")
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -3582,6 +3611,103 @@ def test_flash_attn_2_inference_equivalence(self):
model.train()
_ = model_fa(dummy_input, **other_inputs)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ @is_flaky()
+ def test_flash_attn_3_inference_equivalence(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ dummy_input = inputs_dict[model.main_input_name][:1]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
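+                    # keep a single sequence and mask out its first token to simulate left padding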
+ dummy_attention_mask = dummy_attention_mask[:1]
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+
+ if model.config.is_encoder_decoder:
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
+
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ else:
+ outputs = model(dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ if model.config.is_encoder_decoder:
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+ else:
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(dummy_input, **other_inputs)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -3675,6 +3801,99 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ @is_flaky()
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ for model_class in self.all_model_classes:
+            if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ dummy_input = inputs_dict[model.main_input_name][:1]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
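+                    # keep a single sequence and mask out its last token to simulate right padding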
+ dummy_attention_mask = dummy_attention_mask[:1]
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ if model.config.is_encoder_decoder:
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
+
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ else:
+ outputs = model(dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ if model.config.is_encoder_decoder:
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+ else:
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -3723,18 +3942,114 @@ def test_flash_attn_2_generate_left_padding(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ @is_flaky()
+ def test_flash_attn_3_generate_left_padding(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ for model_class in self.all_generative_model_classes:
+            if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # make sure we do left padding
+ dummy_attention_mask[:, :-1] = 0
+ dummy_attention_mask[:, -1:] = 1
+
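+                # baseline generation with the default attention implementation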
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@is_flaky()
@slow
- def test_flash_attn_2_generate_padding_right(self):
+ def test_flash_attn_2_generate_padding_right(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_flash_attn_2:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # make sure we do right padding
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_2",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @is_flaky()
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
for model_class in self.all_generative_model_classes:
-            if not model_class._supports_flash_attn_2:
+            if not model_class._supports_flash_attn_3:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -3761,7 +4076,7 @@ def test_flash_attn_2_generate_padding_right(self):
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
+ attn_implementation="flash_attention_3",
low_cpu_mem_usage=True,
).to(torch_device)
@@ -4349,6 +4664,65 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+            if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
+            # Generate with batch size 1 to test generation when the attention mask will be None
+            # when real inputs are used, because there is no padding. See issue #32237 for more details
+ dummy_input = dummy_input[:1, ...]
+ dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...])
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -4405,6 +4779,62 @@ def test_flash_attn_2_generate_reuse_cache(self):
past_key_values=past_key_values,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_reuse_cache(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ max_new_tokens = 2
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # run generate once to get filled cache
+ output = model.generate(
+ dummy_input,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ return_dict_in_generate=True,
+ )
+ past_key_values = output.past_key_values
+
+                # Try to continue generation from where we left off, given that we have more than 1 new token to process
+                # e.g. this can happen in speculative decoding when feeding candidate tokens back to the target model
+ dummy_input_updated = torch.cat([dummy_input, output.sequences], dim=-1)
+ _ = model.generate(
+ dummy_input_updated,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ past_key_values=past_key_values,
+ )
+
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@@ -4462,6 +4892,63 @@ def test_flash_attn_2_fp32_ln(self):
# with attention mask
_ = model(dummy_input, attention_mask=dummy_attention_mask)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_fp32_ln(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_input = inputs_dict[model.main_input_name]
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ batch_size = dummy_attention_mask.shape[0]
+
+ is_padding_right = dummy_attention_mask[:, -1].sum().item() != batch_size
+
+ # To avoid errors with padding_side=="right"
+ if is_padding_right:
+ dummy_attention_mask = torch.ones_like(dummy_input)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ load_in_4bit=True,
+ )
+
+ for _, param in model.named_parameters():
+ # upcast only layer norms
+ if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+ param.data = param.data.to(torch.float32)
+
+ if model.config.is_encoder_decoder:
+ dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
+ dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
+
+ _ = model(dummy_input, decoder_input_ids=dummy_decoder_input_ids)
+ # with attention mask
+ _ = model(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ decoder_input_ids=dummy_decoder_input_ids,
+ decoder_attention_mask=dummy_decoder_attention_mask,
+ )
+ else:
+ _ = model(dummy_input)
+ # with attention mask
+ _ = model(dummy_input, attention_mask=dummy_attention_mask)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -4535,6 +5022,79 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, atol=tol, rtol=tol)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
+ self.skipTest("Model dummy inputs should contain padding in their attention mask")
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+                # ensure left padding, which some models require
+ if 0 in inputs_dict["attention_mask"][:, -1]:
+ inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
+ dummy_attention_mask = inputs_dict["attention_mask"]
+ inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.pad_token_id
+
+ model = (
+ model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ )
+ .to(torch_device)
+ .eval()
+ )
+
+ # flatten
+ padfree_inputs_dict = {
+ k: v[dummy_attention_mask.bool()].unsqueeze(0)
+ for k, v in inputs_dict.items()
+ if not k == "attention_mask"
+ }
+ # add position_ids
+ padfree_inputs_dict["position_ids"] = (
+ torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
+ .long()
+ .unsqueeze(0)
+ .to(torch_device)
+ )
+
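+                # the padded batch and the packed (padding-free) batch should produce matching logits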
+ res_padded = model(**inputs_dict)
+ res_padfree = model(**padfree_inputs_dict)
+
+ logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
+ logits_padfree = res_padfree.logits[0]
+
+ torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), atol=0, rtol=0)
+ # acceptable numerical instability
+ tol = torch.finfo(torch.float16).eps
+ torch.testing.assert_close(logits_padded, logits_padfree, atol=tol, rtol=tol)
+
@is_pt_tf_cross_test
def test_tf_from_pt_safetensors(self):
for model_class in self.all_model_classes:
@@ -4630,6 +5190,54 @@ def test_flash_attn_2_from_config(self):
self.assertFalse(fa2_correctly_converted)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_from_config(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+            # TODO: change this in the future to cover other relevant auto classes
+ fa3_model = AutoModelForCausalLM.from_config(
+ config, attn_implementation="flash_attention_3", torch_dtype=torch.bfloat16
+ ).to(torch_device)
+
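+            # the second sequence is left-padded (first position masked) to exercise the attention-mask path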
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
+
+ fa3_correctly_converted = False
+
+ for _, module in fa3_model.named_modules():
+ if "FlashAttention" in module.__class__.__name__:
+ fa3_correctly_converted = True
+ break
+
+ self.assertTrue(fa3_correctly_converted)
+
+ _ = fa3_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ fa3_model.save_pretrained(tmpdirname)
+
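+                # the attn_implementation choice is not serialized in the saved config, so the reload falls back to the default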
+ model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname)
+
+ self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_3")
+
+ fa3_correctly_converted = False
+
+ for _, module in model_from_pretrained.named_modules():
+ if "FlashAttention" in module.__class__.__name__:
+ fa3_correctly_converted = True
+ break
+
+ self.assertFalse(fa3_correctly_converted)
+
def _get_custom_4d_mask_test_data(self):
# Sequence in which all but the last token is the same
input_ids = torch.tensor(
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index f78285fdb90d90..3799d7c15298ca 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -67,6 +67,7 @@
)
from transformers.utils.import_utils import (
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flax_available,
is_tf_available,
is_torch_sdpa_available,
@@ -549,10 +550,14 @@ def test_model_from_pretrained_attn_implementation(self):
if is_flash_attn_2_available():
attn_implementation_available.append("flash_attention_2")
+ if is_flash_attn_3_available():
+ attn_implementation_available.append("flash_attention_3")
+
mistral_attention_classes = {
"eager": "MistralAttention",
"sdpa": "MistralSdpaAttention",
"flash_attention_2": "MistralFlashAttention2",
+ "flash_attention_3": "MistralFlashAttention3",
}
for requested_attn_implementation in attn_implementation_available:
model = AutoModelForCausalLM.from_pretrained(
@@ -588,10 +593,14 @@ def test_model_from_config_attn_implementation(self):
if is_flash_attn_2_available():
attn_implementation_available.append("flash_attention_2")
+ if is_flash_attn_3_available():
+ attn_implementation_available.append("flash_attention_3")
+
mistral_attention_classes = {
"eager": "MistralAttention",
"sdpa": "MistralSdpaAttention",
"flash_attention_2": "MistralFlashAttention2",
+ "flash_attention_3": "MistralFlashAttention3",
}
for requested_attn_implementation in attn_implementation_available:
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
@@ -2439,6 +2448,14 @@ def test_error_no_flash_available(self):
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
+ def test_error_no_flash_3_available(self):
+ with self.assertRaises(ValueError) as cm:
+ _ = AutoModel.from_pretrained(
+ "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="flash_attention_3"
+ )
+
+ self.assertTrue("does not support Flash Attention 3.0" in str(cm.exception))
+
def test_error_no_flash_available_with_config(self):
with self.assertRaises(ValueError) as cm:
config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")
@@ -2449,6 +2466,16 @@ def test_error_no_flash_available_with_config(self):
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
+ def test_error_no_flash_3_available_with_config(self):
+ with self.assertRaises(ValueError) as cm:
+ config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")
+
+ _ = AutoModel.from_pretrained(
+ "hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_3"
+ )
+
+ self.assertTrue("does not support Flash Attention 3.0" in str(cm.exception))
+
def test_error_wrong_attn_implementation(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")
@@ -2465,6 +2492,16 @@ def test_not_available_flash(self):
)
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
+ def test_not_available_flash_3(self):
+ if is_flash_attn_3_available():
+ self.skipTest(reason="Please uninstall flash_attn_interface package to run test_not_available_flash_3")
+
+ with self.assertRaises(ImportError) as cm:
+ _ = AutoModel.from_pretrained(
+ "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_3"
+ )
+ self.assertTrue("the package flash_attn_interface seems to be not installed" in str(cm.exception))
+
def test_not_available_flash_with_config(self):
if is_flash_attn_2_available():
self.skipTest(reason="Please uninstall flash-attn package to run test_not_available_flash")
@@ -2480,6 +2517,23 @@ def test_not_available_flash_with_config(self):
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
+ def test_not_available_flash_3_with_config(self):
+ if is_flash_attn_3_available():
+ self.skipTest(
+ reason="Please uninstall flash_attn_interface package to run test_not_available_flash_3_with_config"
+ )
+
+ config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")
+
+ with self.assertRaises(ImportError) as cm:
+ _ = AutoModel.from_pretrained(
+ "hf-internal-testing/tiny-random-GPTBigCodeModel",
+ config=config,
+ attn_implementation="flash_attention_3",
+ )
+
+ self.assertTrue("the package flash_attn_interface seems to be not installed" in str(cm.exception))
+
def test_not_available_sdpa(self):
if is_torch_sdpa_available():
self.skipTest(reason="This test requires torch<=2.0")