diff --git a/.github/workflows/push-important-models.yml b/.github/workflows/push-important-models.yml
index 41bcd43fcc6fc2..17c1b6c86fb066 100644
--- a/.github/workflows/push-important-models.yml
+++ b/.github/workflows/push-important-models.yml
@@ -87,6 +87,11 @@ jobs:
run:
pytest -rsfE -m "flash_attn_test" --make-reports=${{ matrix.model-name }}_fa2_tests/ tests/${{ matrix.model-name }}/test_modeling_*
+ - name: Run FA3 tests
+ id: run_fa3_tests
+ run:
+ pytest -rsfE -m "flash_attn_3_test" --make-reports=${{ matrix.model-name }}_fa3_tests/ tests/${{ matrix.model-name }}/test_modeling_*
+
- name: "Test suite reports artifacts: ${{ matrix.model-name }}_fa2_tests"
if: ${{ always() }}
uses: actions/upload-artifact@v4
diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md
index 16be638498dfd4..4357e96735cd6d 100644
--- a/docs/source/en/llm_optims.md
+++ b/docs/source/en/llm_optims.md
@@ -348,6 +348,24 @@ model = AutoModelForCausalLM.from_pretrained(
)
```
+### FlashAttention-3
+
+FlashAttention and [FlashAttention-3](./perf_infer_gpu_one#flashattention-3) break up the attention computation into smaller chunks and reduce the number of intermediate read/write operations to GPU memory to speed up inference. FlashAttention-3 improves on the FlashAttention-2 algorithm by taking advantage of new features on Hopper GPUs to maximize performance.
+
+To use FlashAttention-3, set `attn_implementation="flash_attention_3"` in the [`~PreTrainedModel.from_pretrained`] method.
+
+```py
+import torch
+from transformers import AutoModelForCausalLM, BitsAndBytesConfig
+
+quant_config = BitsAndBytesConfig(load_in_8bit=True)
+model = AutoModelForCausalLM.from_pretrained(
+ "google/gemma-2b",
+ quantization_config=quant_config,
+ torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_3",
+)
+```
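+
+To confirm which backend was selected, you can inspect the model config. The `_attn_implementation` attribute is internal to Transformers (it is what gets set under the hood), so treat this as a quick sanity check rather than a public API:
+
+```py
+# Should print "flash_attention_3" if the availability checks passed in from_pretrained.
+print(model.config._attn_implementation)
+```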
+
### PyTorch scaled dot product attention
Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation.
diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md
index ed3b26029d0094..7ad73ac91165af 100644
--- a/docs/source/en/perf_infer_gpu_one.md
+++ b/docs/source/en/perf_infer_gpu_one.md
@@ -201,6 +201,142 @@ FlashAttention is more memory efficient, meaning you can train on much larger se
+
+## FlashAttention-3
+
+<Tip>
+
+FlashAttention-3 is experimental and may change considerably in future versions.
+
+</Tip>
+
+[FlashAttention-3](https://huggingface.co/papers/2407.08608) improves on the FlashAttention-2 algorithm by taking advantage of new features on Hopper GPUs to maximize performance:
+
+1. overlap overall computation and data movement via warp-specialization
+2. interleave block-wise matmul and softmax operations
+3. block quantization and incoherent processing that leverages hardware support for FP8 low-precision
+
+FlashAttention-3 is currently supported for the following architectures:
+* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel)
+* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel)
+* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon)
+* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel)
+* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel)
+* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel)
+* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel)
+* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel)
+* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model)
+* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2)
+* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel)
+* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel)
+* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel)
+* [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel)
+* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel)
+* [GraniteMoe](https://huggingface.co/docs/transformers/model_doc/granitemoe#transformers.GraniteMoeModel)
+* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model)
+* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel)
+* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel)
+* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel)
+* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel)
+* [Llava](https://huggingface.co/docs/transformers/model_doc/llava)
+* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next)
+* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video)
+* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision)
+* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi)
+* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava)
+* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava)
+* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100)
+* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel)
+* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel)
+* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel)
+* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel)
+* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel)
+* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron)
+* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)
+* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel)
+* [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel)
+* [OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel)
+* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel)
+* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model)
+* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel)
+* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model)
+* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model)
+* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder)
+* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel)
+* [Qwen2VL](https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel)
+* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel)
+* [Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model)
+* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel)
+* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel)
+* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel)
+* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip)
+* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel)
+* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel)
+
+You can request to add FlashAttention-3 support for another model by opening a GitHub Issue or Pull Request.
+
+Before you begin, make sure you have FlashAttention-3 installed.
+
+
+
+
+```bash
+git clone https://github.com/Dao-AILab/flash-attention
+cd flash-attention/hopper
+python setup.py install
+```
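+
+After building, you can optionally verify that the kernels are importable. This is a minimal sketch based on the module name this integration imports (`flash_attn_interface`), not an official verification script:
+
+```python
+# Sanity check: the Hopper build of flash-attention exposes a `flash_attn_interface` module,
+# which is what Transformers imports from when attn_implementation="flash_attention_3".
+import flash_attn_interface
+
+print(flash_attn_interface.flash_attn_func)
+print(flash_attn_interface.flash_attn_varlen_func)
+```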
+
+
+
+
+To enable FlashAttention-3, pass the argument `attn_implementation="flash_attention_3"` to [`~AutoModelForCausalLM.from_pretrained`]:
+
+```python
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_3",
+)
+```
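+
+Generation then works as usual. The short follow-up below is illustrative only and assumes the model has been moved to a supported Hopper GPU:
+
+```python
+model = model.to("cuda")
+
+inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
+outputs = model.generate(**inputs, max_new_tokens=20)
+print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+```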
+
+<Tip>
+
+FlashAttention-3 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load it on a supported device before using FlashAttention-3.
+
+</Tip>
+
+
+
+FlashAttention-3 can be combined with other optimization techniques like quantization to further speed up inference. For example, you can combine FlashAttention-3 with 8-bit or 4-bit quantization:
+
+```py
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+# load in 8bit
+model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ load_in_8bit=True,
+ attn_implementation="flash_attention_3",
+)
+
+# load in 4bit
+model = AutoModelForCausalLM.from_pretrained(
+ model_id,
+ load_in_4bit=True,
+ attn_implementation="flash_attention_3",
+)
+```
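+
+This PR also wires up an FP8 forward path (the FP8 support mentioned above). Judging from `modeling_flash_attention_utils.py` in this diff, it is opt-in via an environment variable and only changes the FlashAttention-3 forward pass (queries, keys, and values are cast to `float8_e4m3fn` internally; the model weights stay in `fp16`/`bf16`). Treat the flag name as an implementation detail that may change:
+
+```py
+import os
+import torch
+from transformers import AutoModelForCausalLM
+
+# Assumption based on this PR's code: FLASH_ATTENTION_3_FP8=1 routes the FlashAttention-3
+# forward through the FP8 kernel (forward only; there is no FP8 backward yet).
+os.environ["FLASH_ATTENTION_3_FP8"] = "1"
+
+model = AutoModelForCausalLM.from_pretrained(
+    "tiiuae/falcon-7b",
+    torch_dtype=torch.bfloat16,
+    attn_implementation="flash_attention_3",
+)
+```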
+
## PyTorch scaled dot product attention
PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
diff --git a/pyproject.toml b/pyproject.toml
index bf78e0174394f5..fdcccd96884caa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -49,6 +49,7 @@ addopts = "--doctest-glob='**/*.md'"
doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS"
markers = [
"flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')",
+ "flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')",
"bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests",
"generate: marks tests that use the GenerationTesterMixin"
]
diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py
index 44e61825dd9cd6..a0ee24e3085437 100644
--- a/src/transformers/modeling_flash_attention_utils.py
+++ b/src/transformers/modeling_flash_attention_utils.py
@@ -20,7 +20,7 @@
import torch
import torch.nn.functional as F
-from .utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal
+from .utils import is_flash_attn_2_available, is_flash_attn_3_available, is_flash_attn_greater_or_equal
if is_flash_attn_2_available():
@@ -29,6 +29,11 @@
_flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
+if is_flash_attn_3_available():
+ from flash_attn_interface import _flash_attn_forward as _flash_attn_3_forward
+ from flash_attn_interface import flash_attn_func as flash_attn_3_func
+ from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_3_func
+
def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]:
"""
@@ -194,6 +199,7 @@ def _flash_attention_forward(
use_top_left_mask: bool = False,
softcap: Optional[float] = None,
deterministic: bool = None,
+ use_flash_attn_3: bool = False,
):
"""
Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
@@ -219,18 +225,34 @@ def _flash_attention_forward(
Softcap for the attention logits, used e.g. in gemma2.
deterministic (`bool`, *optional*):
Determines if the deterministic option introduced in flash_attn>=2.4.1 is enabled.
+ use_flash_attn_3 (`bool`, *optional*):
+ Determines if Flash Attention v3 should be used.
"""
if not use_top_left_mask:
causal = is_causal
else:
- # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__.
+ # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__.
causal = is_causal and query_length != 1
+ # The FA3 FP8 path calls `_flash_attn_3_forward`, which doesn't apply a default `softmax_scale`, so compute it here.
+ softmax_scale = softmax_scale or query_states.shape[-1] ** (-0.5)
+
+ use_fp8 = use_flash_attn_3 and os.environ.get("FLASH_ATTENTION_3_FP8", "0") == "1"
+
+ flash_kwargs = {}
# Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length).
- use_sliding_windows = (
- _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
- )
- flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}
+ use_sliding_windows = sliding_window is not None and key_states.shape[1] > sliding_window
+ if not _flash_supports_window_size:
+ flash_kwargs = {}
+ elif use_sliding_windows:
+ flash_kwargs["window_size"] = (sliding_window, sliding_window)
+ else:
+ # Needs default values for FP8 `_flash_attn_forward` path.
+ # Can be removed when FP8 bwd is supported and FP8 path uses `flash_attn_func`.
+ flash_kwargs["window_size"] = (-1, -1)
+
+ if not use_flash_attn_3:
+ flash_kwargs["dropout_p"] = dropout
if is_flash_attn_greater_or_equal("2.4.1"):
if deterministic is None:
@@ -249,7 +271,9 @@ def _flash_attention_forward(
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
- attn_output_unpad = flash_attn_varlen_func(
+ func = flash_attn_varlen_3_func if use_flash_attn_3 else flash_attn_varlen_func
+
+ attn_output_unpad = func(
query_states,
key_states,
value_states,
@@ -257,11 +281,12 @@ def _flash_attention_forward(
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)
+ if use_flash_attn_3:
+ attn_output_unpad = attn_output_unpad[0]
attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
@@ -276,7 +301,9 @@ def _flash_attention_forward(
cu_seqlens_q, cu_seqlens_k = cu_seq_lens
max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
- attn_output = flash_attn_varlen_func(
+ func = flash_attn_varlen_3_func if use_flash_attn_3 else flash_attn_varlen_func
+
+ attn_output = func(
query_states,
key_states,
value_states,
@@ -284,17 +311,33 @@ def _flash_attention_forward(
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_in_batch_q,
max_seqlen_k=max_seqlen_in_batch_k,
- dropout_p=dropout,
softmax_scale=softmax_scale,
causal=causal,
**flash_kwargs,
)
+ if use_flash_attn_3:
+ attn_output = attn_output[0]
attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
else:
- attn_output = flash_attn_func(
- query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
- )
+ if use_fp8:
+ # NOTE: uses `_flash_attn_forward` instead of `flash_attn_func` because no bwd for FP8 yet.
+ # `deterministic` is part of `flash_attn_func`/`FlashAttnFunc`, used in bwd, so not present here.
+ attn_output = _flash_attn_3_forward(
+ query_states.to(torch.float8_e4m3fn),
+ key_states.to(torch.float8_e4m3fn),
+ value_states.to(torch.float8_e4m3fn),
+ softmax_scale=softmax_scale,
+ causal=causal,
+ window_size=flash_kwargs["window_size"],
+ )[0]
+ else:
+ func = flash_attn_3_func if use_flash_attn_3 else flash_attn_func
+ attn_output = func(
+ query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs
+ )
+ if use_flash_attn_3:
+ attn_output = attn_output[0]
return attn_output
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index b3250dbb82b1d8..65989b23492d2b 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -82,6 +82,7 @@
is_accelerate_available,
is_bitsandbytes_available,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_offline_mode,
is_optimum_available,
is_peft_available,
@@ -1382,6 +1383,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
# Flash Attention 2 support
_supports_flash_attn_2 = False
+ # Flash Attention 3 support
+ _supports_flash_attn_3 = False
+
# SDPA support
_supports_sdpa = False
@@ -1565,10 +1569,12 @@ def _autoset_attn_implementation(
' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.'
)
- if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]:
+ if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2", "flash_attention_3"]:
message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)'
if cls._supports_flash_attn_2:
message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)'
+ if cls._supports_flash_attn_3:
+ message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)'
if cls._supports_sdpa:
message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)'
raise ValueError(message + ".")
@@ -1590,6 +1596,14 @@ def _autoset_attn_implementation(
hard_check_only=False,
check_device_map=check_device_map,
)
+ elif config._attn_implementation == "flash_attention_3":
+ cls._check_and_enable_flash_attn_3(
+ config,
+ torch_dtype=torch_dtype,
+ device_map=device_map,
+ hard_check_only=False,
+ check_device_map=check_device_map,
+ )
elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available():
# use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif.
config = cls._check_and_enable_sdpa(
@@ -1778,6 +1792,91 @@ def _check_and_enable_flash_attn_2(
config._attn_implementation = "flash_attention_2"
return config
+ @classmethod
+ def _check_and_enable_flash_attn_3(
+ cls,
+ config,
+ torch_dtype: Optional[torch.dtype] = None,
+ device_map: Optional[Union[str, Dict[str, int]]] = None,
+ check_device_map: bool = True,
+ hard_check_only: bool = False,
+ ) -> PretrainedConfig:
+ """
+ Checks the availability of Flash Attention 3 and compatibility with the current model.
+
+ If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module.
+ """
+ if not cls._supports_flash_attn_3:
+ raise ValueError(
+ f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where"
+ f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new"
+ " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
+ )
+
+ if not is_flash_attn_3_available():
+ preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:"
+ # TODO: docs
+ install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-3 to install Flash Attention 3."
+
+ if importlib.util.find_spec("flash_attn_interface") is None:
+ raise ImportError(
+ f"{preface} the package flash_attn_interface seems to be not installed. {install_message}"
+ )
+
+ if torch.version.cuda:
+ compute_capability = torch.cuda.get_device_capability()
+ major, _ = compute_capability
+ if major < 9:
+ raise ValueError("Flash Attention 3 requires NVIDIA GPU with compute capability >= 9.0")
+ else:
+ raise ValueError("Flash Attention 3 requires NVIDIA GPU with compute capability >= 9.0")
+
+ _is_bettertransformer = getattr(cls, "use_bettertransformer", False)
+
+ if _is_bettertransformer:
+ raise ValueError(
+ "Flash Attention 3 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
+ )
+
+ if torch_dtype is None:
+ logger.warning_once(
+ "You are attempting to use Flash Attention 3.0 without specifying a torch dtype. This might lead to unexpected behaviour"
+ )
+ elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]:
+ logger.warning_once(
+ "Flash Attention 3.0 only supports torch.float16 and torch.bfloat16 dtypes, but"
+ f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
+ ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`'
+ )
+
+ # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called,
+ # or the model may be initialized under the context manager `with torch.device("cuda"):`.
+ if check_device_map and device_map is None and torch.empty(0).device.type != "cuda":
+ if torch.cuda.is_available():
+ logger.warning_once(
+ "You are attempting to use Flash Attention 3.0 with a model not initialized on GPU. Make sure to move the model to GPU"
+ " after initializing it on CPU with `model.to('cuda')`."
+ )
+ else:
+ raise ValueError(
+ "You are attempting to use Flash Attention 3.0 with a model not initialized on GPU and with no GPU available. "
+ "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
+ "or initialising the model on CPU and then moving it to GPU."
+ )
+ elif (
+ check_device_map
+ and device_map is not None
+ and isinstance(device_map, dict)
+ and ("cpu" in device_map.values() or "disk" in device_map.values())
+ ):
+ raise ValueError(
+ "You are attempting to use Flash Attention 3.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to "
+ "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys."
+ )
+ if not hard_check_only:
+ config._attn_implementation = "flash_attention_3"
+ return config
+
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig:
"""
diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py
index 84e368379e292a..e2592de6ee68c6 100755
--- a/src/transformers/models/altclip/modeling_altclip.py
+++ b/src/transformers/models/altclip/modeling_altclip.py
@@ -725,7 +725,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class AltCLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config):
+ def __init__(self, config: AltCLIPConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py
index ea420482379d7a..ff6ff37746c394 100644
--- a/src/transformers/models/bark/modeling_bark.py
+++ b/src/transformers/models/bark/modeling_bark.py
@@ -69,9 +69,9 @@ class BarkSelfAttention(nn.Module):
# adapted from GPTNeoSelfAttention and Bark code
# BarkSelfAttention can have two attention type, i.e full attention or causal attention
- def __init__(self, config, is_causal=False):
+ def __init__(self, config: BarkConfig, is_causal=False):
super().__init__()
-
+ self.config = config
# regularization
self.dropout = config.dropout
self.attn_dropout = nn.Dropout(config.dropout)
@@ -190,14 +190,14 @@ def forward(
return outputs
-class BarkSelfFlashAttention2(BarkSelfAttention):
+class BarkSelfFlashAttention(BarkSelfAttention):
"""
Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -205,6 +205,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def _split_heads(self, tensor, num_heads, attn_head_size):
"""
@@ -266,6 +267,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
@@ -282,7 +284,8 @@ def forward(
BARK_ATTENTION_CLASSES = {
"eager": BarkSelfAttention,
- "flash_attention_2": BarkSelfFlashAttention2,
+ "flash_attention_2": BarkSelfFlashAttention,
+ "flash_attention_3": BarkSelfFlashAttention,
}
@@ -377,6 +380,7 @@ class BarkPreTrainedModel(PreTrainedModel):
config_class = BarkConfig
supports_gradient_checkpointing = False
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
def _init_weights(self, module):
"""Initialize the weights."""
@@ -562,6 +566,7 @@ def __init__(self, config):
self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias)
@@ -703,7 +708,7 @@ def forward(
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
attention_mask = attention_mask if 0 in attention_mask else None
else:
attention_mask = attention_mask.view(batch_size, -1)
@@ -1158,6 +1163,7 @@ def __init__(self, config):
self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self.layernorm_final = nn.LayerNorm(config.hidden_size)
@@ -1342,7 +1348,7 @@ def forward(
if attention_mask is not None:
if batch_size <= 0:
raise ValueError("batch_size has to be defined and > 0")
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
attention_mask = attention_mask if 0 in attention_mask else None
else:
# [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length]
@@ -1818,3 +1824,39 @@ def _check_and_enable_flash_attn_2(
config.coarse_acoustics_config._attn_implementation = config._attn_implementation
config.fine_acoustics_config._attn_implementation = config._attn_implementation
return config
+
+ @classmethod
+ def _check_and_enable_flash_attn_3(
+ cls,
+ config,
+ torch_dtype: Optional[torch.dtype] = None,
+ device_map: Optional[Union[str, Dict[str, int]]] = None,
+ hard_check_only: bool = False,
+ check_device_map: bool = False,
+ ):
+ """
+ `_check_and_enable_flash_attn_3` originally doesn't expand flash attention enabling to the model
+ sub-configurations. We override the original method to make sure that Bark sub-models are using Flash Attention
+ if necessary.
+
+ If you don't know about Flash Attention, check out the official repository of flash attention:
+ https://github.com/Dao-AILab/flash-attention
+
+ For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this
+ specific section of the documentation to learn more about it:
+ https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models
+
+ The method checks if the current setup is compatible with Flash Attention as it requires the model to be in
+ half precision and not run on CPU.
+
+ If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_3" so that the model
+ can initialize the correct attention module.
+ """
+ config = super()._check_and_enable_flash_attn_3(
+ config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map
+ )
+
+ config.semantic_config._attn_implementation = config._attn_implementation
+ config.coarse_acoustics_config._attn_implementation = config._attn_implementation
+ config.fine_acoustics_config._attn_implementation = config._attn_implementation
+ return config
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index ac10189ecf5b58..416e6d31f852a6 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -287,14 +287,14 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-class BartFlashAttention2(BartAttention):
+class BartFlashAttention(BartAttention):
"""
Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -302,6 +302,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
@@ -403,6 +404,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -523,7 +525,8 @@ def forward(
BART_ATTENTION_CLASSES = {
"eager": BartAttention,
"sdpa": BartSdpaAttention,
- "flash_attention_2": BartFlashAttention2,
+ "flash_attention_2": BartFlashAttention,
+ "flash_attention_3": BartFlashAttention,
}
@@ -749,6 +752,7 @@ class BartPreTrainedModel(PreTrainedModel):
_no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
@@ -980,6 +984,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
)
self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
@@ -1068,7 +1073,7 @@ def forward(
# expand attention_mask
if attention_mask is not None:
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
attention_mask = attention_mask if 0 in attention_mask else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
@@ -1164,6 +1169,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
)
self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
@@ -1284,7 +1290,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input)
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
@@ -1304,7 +1310,7 @@ def forward(
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py
index a5ac476c99b3db..118f82f8fe69e9 100644
--- a/src/transformers/models/bloom/modeling_bloom.py
+++ b/src/transformers/models/bloom/modeling_bloom.py
@@ -741,7 +741,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py
index ca02dac7708a7e..4e833792ccf295 100644
--- a/src/transformers/models/chameleon/modeling_chameleon.py
+++ b/src/transformers/models/chameleon/modeling_chameleon.py
@@ -366,9 +366,9 @@ def forward(
return attn_output, attn_weights, past_key_value
-# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon
+# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention with Llama->Chameleon
# TODO(joao): add me back asap :)
-class ChameleonFlashAttention2(ChameleonAttention):
+class ChameleonFlashAttention(ChameleonAttention):
"""
Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -382,6 +382,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
# Ignore copy
def forward(
@@ -474,6 +475,7 @@ def forward(
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -580,7 +582,8 @@ def forward(
CHAMELEON_ATTENTION_CLASSES = {
"eager": ChameleonAttention,
- "flash_attention_2": ChameleonFlashAttention2,
+ "flash_attention_2": ChameleonFlashAttention,
+ "flash_attention_3": ChameleonFlashAttention,
"sdpa": ChameleonSdpaAttention,
}
@@ -1102,6 +1105,7 @@ class ChameleonPreTrainedModel(PreTrainedModel):
_no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"]
_skip_keys_device_placement = ["past_key_values", "causal_mask"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_quantized_cache = True
_supports_cache_class = True
@@ -1380,7 +1384,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py
index 370f17f479650a..71580efd48d712 100644
--- a/src/transformers/models/clip/modeling_clip.py
+++ b/src/transformers/models/clip/modeling_clip.py
@@ -293,7 +293,7 @@ def forward(
class CLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config):
+ def __init__(self, config: CLIPConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -394,14 +394,14 @@ def forward(
return attn_output, attn_weights_reshaped
-class CLIPFlashAttention2(CLIPAttention):
+class CLIPFlashAttention(CLIPAttention):
"""
CLIPAttention flash attention module. This module inherits from `CLIPAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -409,8 +409,9 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
- # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
+ # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
@@ -470,6 +471,7 @@ def forward(
dropout=dropout_rate,
is_causal=causal_attention_mask is not None,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
@@ -557,7 +559,8 @@ def forward(
CLIP_ATTENTION_CLASSES = {
"eager": CLIPAttention,
"sdpa": CLIPSdpaAttention,
- "flash_attention_2": CLIPFlashAttention2,
+ "flash_attention_2": CLIPFlashAttention,
+ "flash_attention_3": CLIPFlashAttention,
}
@@ -637,6 +640,7 @@ class CLIPPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_sdpa = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
def _init_weights(self, module):
"""Initialize the weights"""
@@ -910,6 +914,7 @@ def __init__(self, config: CLIPTextConfig):
# For attention mask, it differs between `flash_attention_2` and other attention implementations
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
@add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
@@ -947,7 +952,7 @@ def forward(
)
# expand attention_mask
- if attention_mask is not None and not self._use_flash_attention_2:
+ if attention_mask is not None and not self._use_flash_attention_2 and not self._use_flash_attention_3:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py
index 90520524fa8843..c0d91ad396a071 100644
--- a/src/transformers/models/clipseg/modeling_clipseg.py
+++ b/src/transformers/models/clipseg/modeling_clipseg.py
@@ -261,7 +261,7 @@ def forward(
class CLIPSegAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config):
+ def __init__(self, config: CLIPSegConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py
index 32c265c421d93e..37b63186f01593 100644
--- a/src/transformers/models/codegen/modeling_codegen.py
+++ b/src/transformers/models/codegen/modeling_codegen.py
@@ -584,7 +584,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 7301f434f7fb29..2d795b824084f4 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -363,8 +363,8 @@ def forward(
return attn_output, attn_weights, past_key_value
-# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere
-class CohereFlashAttention2(CohereAttention):
+# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention with Llama->Cohere
+class CohereFlashAttention(CohereAttention):
"""
Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -378,6 +378,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
# Ignore copy
def forward(
@@ -475,6 +476,7 @@ def forward(
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -591,7 +593,8 @@ def forward(
COHERE_ATTENTION_CLASSES = {
"eager": CohereAttention,
- "flash_attention_2": CohereFlashAttention2,
+ "flash_attention_2": CohereFlashAttention,
+ "flash_attention_3": CohereFlashAttention,
"sdpa": CohereSdpaAttention,
}
@@ -697,6 +700,7 @@ class CoherePreTrainedModel(PreTrainedModel):
_no_split_modules = ["CohereDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -954,7 +958,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py
index b6ad74e8c80f1b..6229f5e1dc3d2a 100755
--- a/src/transformers/models/data2vec/modeling_data2vec_audio.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -50,6 +50,7 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
logger = logging.get_logger(__name__)
@@ -480,15 +481,15 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Data2VecAudio
-class Data2VecAudioFlashAttention2(Data2VecAudioAttention):
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->Data2VecAudio
+class Data2VecAudioFlashAttention(Data2VecAudioAttention):
"""
Data2VecAudio flash attention module. This module inherits from `Data2VecAudioAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -496,6 +497,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
@@ -597,6 +599,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -718,7 +721,8 @@ def forward(
DATA2VEC2AUDIO_ATTENTION_CLASSES = {
"eager": Data2VecAudioAttention,
"sdpa": Data2VecAudioSdpaAttention,
- "flash_attention_2": Data2VecAudioFlashAttention2,
+ "flash_attention_2": Data2VecAudioFlashAttention,
+ "flash_attention_3": Data2VecAudioFlashAttention,
}
@@ -794,6 +798,7 @@ def __init__(self, config):
self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -810,7 +815,7 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -931,6 +936,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py
index d197722f5b18f0..ac1e3f3f358bf8 100644
--- a/src/transformers/models/dbrx/modeling_dbrx.py
+++ b/src/transformers/models/dbrx/modeling_dbrx.py
@@ -41,6 +41,7 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "DbrxConfig"
@@ -310,7 +311,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-class DbrxFlashAttention2(DbrxAttention):
+class DbrxFlashAttention(DbrxAttention):
"""Dbrx flash attention module.
This module inherits from `DbrxAttention` as the weights of the module stays
@@ -318,7 +319,7 @@ class DbrxFlashAttention2(DbrxAttention):
calls the public API of flash attention.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -326,6 +327,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -420,6 +422,7 @@ def forward(
dropout=dropout_rate,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -528,7 +531,8 @@ def forward(
DBRX_ATTENTION_CLASSES = {
"eager": DbrxAttention,
- "flash_attention_2": DbrxFlashAttention2,
+ "flash_attention_2": DbrxFlashAttention,
+ "flash_attention_3": DbrxFlashAttention,
"sdpa": DbrxSdpaAttention,
}
@@ -829,6 +833,7 @@ class DbrxPreTrainedModel(PreTrainedModel):
_no_split_modules = ["DbrxBlock"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -1114,7 +1119,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
index b8eb9f5a8b4222..47c921100c2c94 100755
--- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py
+++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -100,9 +100,9 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
return model
-# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2
+# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2Config->DecisionTransformerConfig,GPT2->DecisionTransformerGPT2
class DecisionTransformerGPT2Attention(nn.Module):
- def __init__(self, config, is_cross_attention=False, layer_idx=None):
+ def __init__(self, config: DecisionTransformerConfig, is_cross_attention=False, layer_idx=None):
super().__init__()
self.config = config
max_positions = config.max_position_embeddings
diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index 36e35594b3d3c6..73f05c37dec80b 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -238,14 +238,14 @@ def unshape(x: torch.Tensor) -> torch.Tensor:
return (context,)
-class DistilBertFlashAttention2(MultiHeadSelfAttention):
+class DistilBertFlashAttention(MultiHeadSelfAttention):
"""
DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module
stays untouched. The only required change would be on the forward pass where it needs to correctly call the public
API of flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -253,6 +253,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -324,6 +325,7 @@ def reshape(x: torch.Tensor) -> torch.Tensor:
dropout=attn_dropout,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head)
@@ -438,8 +440,9 @@ def ff_chunk(self, input: torch.Tensor) -> torch.Tensor:
DISTILBERT_ATTENTION_CLASSES = {
"eager": MultiHeadSelfAttention,
- "flash_attention_2": DistilBertFlashAttention2,
"sdpa": DistilBertSdpaAttention,
+ "flash_attention_2": DistilBertFlashAttention,
+ "flash_attention_3": DistilBertFlashAttention,
}
@@ -590,6 +593,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
base_model_prefix = "distilbert"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module: nn.Module):
@@ -677,6 +681,7 @@ def __init__(self, config: PretrainedConfig):
self.embeddings = Embeddings(config) # Embeddings
self.transformer = Transformer(config) # Encoder
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
# Initialize weights and apply final processing
@@ -784,7 +789,7 @@ def forward(
embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim)
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
if attention_mask is None:
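The DistilBERT hunks above show the pattern this diff repeats for every model: the former `*FlashAttention2` class is renamed and registered under both `"flash_attention_2"` and `"flash_attention_3"`, and the only per-instance difference is a boolean read off the config that gets forwarded to the shared attention helper. A minimal, self-contained sketch of that shape, with `DummyConfig` and `SharedFlashAttention` as hypothetical stand-ins rather than transformers classes:

```py
# Minimal sketch (not the transformers API) of the shared flash-attention class pattern:
# one class serves both backends, distinguished only by a flag derived from the config.
from dataclasses import dataclass


@dataclass
class DummyConfig:
    _attn_implementation: str = "flash_attention_2"


class SharedFlashAttention:
    def __init__(self, config: DummyConfig):
        self.config = config
        # True only when the model was loaded with attn_implementation="flash_attention_3".
        self._flash_attn_3 = config._attn_implementation == "flash_attention_3"

    def forward(self, query, key, value):
        # The real code forwards this flag to a shared attention helper; here we just report it.
        return {"use_flash_attn_3": self._flash_attn_3}


print(SharedFlashAttention(DummyConfig("flash_attention_3")).forward(None, None, None))
# {'use_flash_attn_3': True}
```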
diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py
index 6928831f0187fb..f8c49f4341f158 100644
--- a/src/transformers/models/falcon/modeling_falcon.py
+++ b/src/transformers/models/falcon/modeling_falcon.py
@@ -56,6 +56,7 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
logger = logging.get_logger(__name__)
@@ -521,14 +522,14 @@ def forward(
return attn_output, layer_past
-class FalconFlashAttention2(FalconAttention):
+class FalconFlashAttention(FalconAttention):
"""
Falcon flash attention module. This module inherits from `FalconAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -536,6 +537,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -624,6 +626,7 @@ def forward(
dropout=attn_dropout,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
@@ -654,7 +657,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
FALCON_ATTENTION_CLASSES = {
"eager": FalconAttention,
"sdpa": FalconAttention, # FalconAttention originally implemented both a forward with & without SDPA
- "flash_attention_2": FalconFlashAttention2,
+ "flash_attention_2": FalconFlashAttention,
+ "flash_attention_3": FalconFlashAttention,
}
@@ -856,6 +860,7 @@ class FalconPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["FalconDecoderLayer"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -918,6 +923,7 @@ def __init__(self, config: FalconConfig):
# Transformer blocks
self.h = nn.ModuleList([FalconDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
# Final Layer Norm
@@ -1108,7 +1114,10 @@ def _update_causal_mask(
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
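The `FALCON_ATTENTION_CLASSES` change above is the same dispatch-table idea used throughout the diff: the implementation string stored on the config selects the attention class when each layer is built, and both flash keys resolve to a single class. A toy sketch with placeholder classes (not the transformers ones):

```py
# Hypothetical stand-in for an *_ATTENTION_CLASSES mapping: two flash keys, one class.
class EagerAttention: ...
class SdpaAttention(EagerAttention): ...
class FlashAttention(EagerAttention): ...


ATTENTION_CLASSES = {
    "eager": EagerAttention,
    "sdpa": SdpaAttention,
    "flash_attention_2": FlashAttention,
    "flash_attention_3": FlashAttention,
}


def build_attention(attn_implementation: str) -> EagerAttention:
    # The model constructor does essentially this lookup per decoder layer.
    return ATTENTION_CLASSES[attn_implementation]()


assert type(build_attention("flash_attention_2")) is type(build_attention("flash_attention_3"))
```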
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index c6070a3d96b6d2..03cdcbb7daf670 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -388,7 +388,7 @@ def forward(
return attn_output, None, past_key_value
-class GemmaFlashAttention2(GemmaAttention):
+class GemmaFlashAttention(GemmaAttention):
"""
Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -402,6 +402,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -487,6 +488,7 @@ def forward(
sliding_window=getattr(self, "sliding_window", None),
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
@@ -499,7 +501,8 @@ def forward(
GEMMA_ATTENTION_CLASSES = {
"eager": GemmaAttention,
- "flash_attention_2": GemmaFlashAttention2,
+ "flash_attention_2": GemmaFlashAttention,
+ "flash_attention_3": GemmaFlashAttention,
"sdpa": GemmaSdpaAttention,
}
@@ -605,6 +608,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GemmaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -870,7 +874,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
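The `_update_causal_mask` hunks extend a short-circuit that already existed for FlashAttention-2: with a flash backend no 4D causal mask is materialized, and the 2D padding mask is only forwarded when it actually contains a zero. A small sketch of just that branch (a standalone function, not the Gemma method):

```py
# Sketch of the flash-attention branch of _update_causal_mask: return the 2D padding
# mask only when it contains padding, otherwise return None and skip masking entirely.
from typing import Optional

import torch


def flash_attn_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    if attention_mask is not None and 0.0 in attention_mask:
        return attention_mask
    return None


print(flash_attn_mask(torch.tensor([[1, 1, 1, 1]])))  # None: no padding in the batch
print(flash_attn_mask(torch.tensor([[1, 1, 1, 0]])))  # tensor([[1, 1, 1, 0]])
```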
diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py
index ca89b6cf2a6da8..e3f6977f36bff0 100644
--- a/src/transformers/models/gemma/modular_gemma.py
+++ b/src/transformers/models/gemma/modular_gemma.py
@@ -30,7 +30,7 @@
from ...utils import is_torchdynamo_compiling, logging
from ..llama.modeling_llama import (
LlamaDecoderLayer,
- LlamaFlashAttention2,
+ LlamaFlashAttention,
LlamaForCausalLM,
LlamaForSequenceClassification,
LlamaForTokenClassification,
@@ -456,7 +456,7 @@ def forward(
return attn_output, None, past_key_value
-class GemmaFlashAttention2(LlamaFlashAttention2, GemmaAttention):
+class GemmaFlashAttention(LlamaFlashAttention, GemmaAttention):
"""
Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -547,6 +547,7 @@ def forward(
sliding_window=getattr(self, "sliding_window", None),
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
@@ -559,7 +560,8 @@ def forward(
GEMMA_ATTENTION_CLASSES = {
"eager": GemmaAttention,
- "flash_attention_2": GemmaFlashAttention2,
+ "flash_attention_2": GemmaFlashAttention,
+ "flash_attention_3": GemmaFlashAttention,
"sdpa": GemmaSdpaAttention,
}
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index c52b7b82e13d61..6b8233f33a86f6 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -271,7 +271,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-class Gemma2FlashAttention2(Gemma2Attention):
+class Gemma2FlashAttention(Gemma2Attention):
"""
Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -285,6 +285,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -375,6 +376,7 @@ def forward(
sliding_window=self.sliding_window,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -481,7 +483,8 @@ def forward(
GEMMA2_ATTENTION_CLASSES = {
"eager": Gemma2Attention,
- "flash_attention_2": Gemma2FlashAttention2,
+ "flash_attention_2": Gemma2FlashAttention,
+ "flash_attention_3": Gemma2FlashAttention,
"sdpa": Gemma2SdpaAttention,
}
@@ -531,7 +534,10 @@ def forward(
"""
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# With flash attention, the attention mask is a 2D tensor
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if past_key_value is not None: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
else:
@@ -605,6 +611,7 @@ class Gemma2PreTrainedModel(PreTrainedModel):
_no_split_modules = ["Gemma2DecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = False
@@ -877,7 +884,10 @@ def _update_causal_mask(
# So we will pass in attention mask as is in any case, not only when there's padding. Then we'll use its shape
# to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
# as it doesn't cause dynamic control issues.
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
return attention_mask
dtype, device = input_tensor.dtype, input_tensor.device
@@ -1143,6 +1153,7 @@ def prepare_inputs_for_generation(
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
and not self.config._attn_implementation == "flash_attention_2"
+ and not self.config._attn_implementation == "flash_attention_3"
):
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
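In the Gemma2 sliding-window branch above, a flash backend keeps the mask 2D, so restricting attention during decoding reduces to slicing off all but the last `sliding_window` key positions. A sketch of that decoding-time slice (the prefill branch is not shown in this hunk and is omitted here):

```py
# Decoding-path sketch: with a 2D padding mask, the sliding window is just the
# trailing `sliding_window` key positions.
import torch


def trim_mask_for_decoding(attention_mask: torch.Tensor, sliding_window: int) -> torch.Tensor:
    # Keep only the last `sliding_window` key positions of the 2D padding mask.
    return attention_mask[:, -sliding_window:]


mask = torch.ones(1, 10, dtype=torch.long)
print(trim_mask_for_decoding(mask, sliding_window=4).shape)  # torch.Size([1, 4])
```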
diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py
index ff53955716e69f..14efb981a42fec 100644
--- a/src/transformers/models/gemma2/modular_gemma2.py
+++ b/src/transformers/models/gemma2/modular_gemma2.py
@@ -283,7 +283,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-class Gemma2FlashAttention2(Gemma2Attention):
+class Gemma2FlashAttention(Gemma2Attention):
"""
Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -297,6 +297,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -387,6 +388,7 @@ def forward(
sliding_window=self.sliding_window,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -513,7 +515,10 @@ def forward(
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding
# With flash attention, the attention mask is a 2D tensor
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if past_key_value is not None: # when decoding
attention_mask = attention_mask[:, -self.sliding_window :]
else:
@@ -715,7 +720,10 @@ def _update_causal_mask(
# So we will pass in attention mask as is in any case, not only when there's padding. Then we'll use its shape
# to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible
# as it doesn't cause dynamic control issues.
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
return attention_mask
dtype, device = input_tensor.dtype, input_tensor.device
@@ -885,6 +893,7 @@ def prepare_inputs_for_generation(
isinstance(past_key_values, HybridCache)
and attention_mask.ndim == 2
and not self.config._attn_implementation == "flash_attention_2"
+ and not self.config._attn_implementation == "flash_attention_3"
):
if model_inputs["inputs_embeds"] is not None:
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape
diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py
index cf8edfe474880f..3fe253cc192666 100644
--- a/src/transformers/models/git/modeling_git.py
+++ b/src/transformers/models/git/modeling_git.py
@@ -719,7 +719,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
class GitVisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config):
+ def __init__(self, config: GitVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index e99f4b126246d8..51c70cde62fa2d 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -121,7 +121,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
class GPT2Attention(nn.Module):
- def __init__(self, config, is_cross_attention=False, layer_idx=None):
+ def __init__(self, config: GPT2Config, is_cross_attention=False, layer_idx=None):
super().__init__()
self.config = config
max_positions = config.max_position_embeddings
@@ -342,14 +342,14 @@ def forward(
return outputs # a, present, (attentions)
-class GPT2FlashAttention2(GPT2Attention):
+class GPT2FlashAttention(GPT2Attention):
"""
GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -357,6 +357,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -442,6 +443,7 @@ def forward(
dropout=attn_dropout,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
@@ -579,7 +581,12 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
return hidden_states
-GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention}
+GPT2_ATTENTION_CLASSES = {
+ "eager": GPT2Attention,
+ "flash_attention_2": GPT2FlashAttention,
+ "flash_attention_3": GPT2FlashAttention,
+ "sdpa": GPT2SdpaAttention,
+}
class GPT2Block(nn.Module):
@@ -675,6 +682,7 @@ class GPT2PreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPT2Block"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def __init__(self, *inputs, **kwargs):
@@ -1032,7 +1040,7 @@ def forward(
# Attention mask.
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
- if self._attn_implementation == "flash_attention_2":
+ if self._attn_implementation == "flash_attention_2" or self._attn_implementation == "flash_attention_3":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif _use_sdpa:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
@@ -1069,7 +1077,10 @@ def forward(
encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1]
)
- elif not self._attn_implementation == "flash_attention_2":
+ elif (
+ not self._attn_implementation == "flash_attention_2"
+ and not self._attn_implementation == "flash_attention_3"
+ ):
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
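The GPT-2 hunk above highlights the difference between the two mask formats: flash backends consume the raw 2D padding mask (or `None`), while the eager path inverts it into an additive float bias. A generic illustration of that inversion, not the exact `invert_attention_mask` helper:

```py
# Generic sketch of turning a 2D 0/1 padding mask into the additive bias the eager
# path needs: 0 for visible tokens, a large negative value for padded ones.
import torch


def to_additive_mask(padding_mask: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # (batch, seq) -> (batch, 1, 1, seq), broadcastable over heads and query positions.
    return (1.0 - padding_mask[:, None, None, :].to(dtype)) * torch.finfo(dtype).min


mask = torch.tensor([[1, 1, 0]])
print(to_additive_mask(mask))  # zeros where attended, ~-3.4e38 where padded
```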
diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
index ca1c03fcd9f911..7f5c110144cd03 100644
--- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
+++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py
@@ -84,7 +84,7 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor
class GPTBigCodeAttention(nn.Module):
- def __init__(self, config, is_cross_attention=False, layer_idx=None):
+ def __init__(self, config: GPTBigCodeConfig, is_cross_attention=False, layer_idx=None):
super().__init__()
self.config = config
@@ -271,14 +271,14 @@ def forward(
return outputs # a, present, (attentions)
-class GPTBigCodeFlashAttention2(GPTBigCodeAttention):
+class GPTBigCodeFlashAttention(GPTBigCodeAttention):
"""
GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module
stay untouched. The only required change would be on the forward pass where it needs to correctly call the public
API of flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -286,6 +286,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -377,6 +378,7 @@ def forward(
dropout=attn_dropout,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
@@ -561,7 +563,8 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl
GPTBIGCODE_ATTENTION_CLASSES = {
"eager": GPTBigCodeAttention,
- "flash_attention_2": GPTBigCodeFlashAttention2,
+ "flash_attention_2": GPTBigCodeFlashAttention,
+ "flash_attention_3": GPTBigCodeFlashAttention,
"sdpa": GPTBigCodeSdpaAttention,
}
@@ -667,6 +670,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTBigCodeBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def __init__(self, *inputs, **kwargs):
@@ -811,6 +815,7 @@ def __init__(self, config):
self._use_sdpa = config._attn_implementation == "sdpa"
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
# Initialize weights and apply final processing
self.post_init()
@@ -892,7 +897,7 @@ def forward(
key_length = past_length + query_length
self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length]
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None
encoder_attention_mask = (
diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
index 234f0f6f10dbb3..3158449a45e1fa 100755
--- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py
+++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -151,7 +151,7 @@ def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path):
class GPTNeoSelfAttention(nn.Module):
- def __init__(self, config, attention_type, layer_id=None):
+ def __init__(self, config: GPTNeoConfig, attention_type, layer_id=None):
super().__init__()
self.config = config
@@ -271,14 +271,14 @@ def forward(
return outputs # a, past_kv, (attentions)
-class GPTNeoFlashAttention2(GPTNeoSelfAttention):
+class GPTNeoFlashAttention(GPTNeoSelfAttention):
"""
GPTNeo flash attention module. This module inherits from `GPTNeoSelfAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -286,6 +286,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -360,6 +361,7 @@ def forward(
softmax_scale=1.0,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
@@ -375,7 +377,8 @@ def forward(
GPT_NEO_ATTENTION_CLASSES = {
"eager": GPTNeoSelfAttention,
- "flash_attention_2": GPTNeoFlashAttention2,
+ "flash_attention_2": GPTNeoFlashAttention,
+ "flash_attention_3": GPTNeoFlashAttention,
}
@@ -497,6 +500,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTNeoBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = False # TODO: needs a HybridCache
@@ -798,7 +802,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
index 60552106d61702..daa86d5aaf714d 100755
--- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py
+++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -53,6 +53,7 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM"
@@ -72,6 +73,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTNeoXLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
@@ -93,7 +95,7 @@ def _init_weights(self, module):
class GPTNeoXAttention(nn.Module):
- def __init__(self, config, layer_idx=None):
+ def __init__(self, config: GPTNeoXConfig, layer_idx=None):
super().__init__()
self.config = config
self.num_attention_heads = config.num_attention_heads
@@ -302,14 +304,14 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None):
return attn_output, attn_weights
-class GPTNeoXFlashAttention2(GPTNeoXAttention):
+class GPTNeoXFlashAttention(GPTNeoXAttention):
"""
GPTNeoX flash attention module. This module inherits from `GPTNeoXAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -317,6 +319,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -391,6 +394,7 @@ def forward(
softmax_scale=self.norm_factor,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
# Reshape outputs
@@ -673,7 +677,8 @@ def forward(self, hidden_states):
GPT_NEOX_ATTENTION_CLASSES = {
"eager": GPTNeoXAttention,
- "flash_attention_2": GPTNeoXFlashAttention2,
+ "flash_attention_2": GPTNeoXFlashAttention,
+ "flash_attention_3": GPTNeoXFlashAttention,
"sdpa": GPTNeoXSdpaAttention,
}
@@ -995,7 +1000,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
index 2fdb730e7ca1b3..5ac74fdb1cac2f 100755
--- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
+++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py
@@ -699,7 +699,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py
index f6fe90fc6c5618..20a0bfb0db5b36 100644
--- a/src/transformers/models/gptj/modeling_gptj.py
+++ b/src/transformers/models/gptj/modeling_gptj.py
@@ -83,7 +83,7 @@ def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Ten
class GPTJAttention(nn.Module):
- def __init__(self, config, layer_idx=None):
+ def __init__(self, config: GPTJConfig, layer_idx=None):
super().__init__()
self.config = config
max_positions = config.max_position_embeddings
@@ -259,14 +259,14 @@ def forward(
return outputs # a, present, (attentions)
-class GPTJFlashAttention2(GPTJAttention):
+class GPTJFlashAttention(GPTJAttention):
"""
GPTJ flash attention module. This module inherits from `GPTJAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -274,6 +274,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -388,6 +389,7 @@ def forward(
dropout=attention_dropout,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
# Reshape outputs
@@ -406,7 +408,8 @@ def forward(
GPTJ_ATTENTION_CLASSES = {
"eager": GPTJAttention,
- "flash_attention_2": GPTJFlashAttention2,
+ "flash_attention_2": GPTJFlashAttention,
+ "flash_attention_3": GPTJFlashAttention,
}
@@ -487,6 +490,7 @@ class GPTJPreTrainedModel(PreTrainedModel):
_no_split_modules = ["GPTJBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
@@ -667,6 +671,7 @@ def __init__(self, config):
self.post_init()
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
@add_start_docstrings(PARALLELIZE_DOCSTRING)
def parallelize(self, device_map=None):
@@ -893,7 +898,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
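The `_flash_attn_uses_top_left_mask` flag set in every flash attention `__init__` above comes from a version probe on the installed `flash_attn` package. A rough, hypothetical equivalent of that check (not the actual `is_flash_attn_greater_or_equal_2_10` implementation):

```py
# Rough equivalent of the version probe behind the top-left-mask flag: flash_attn < 2.1
# produced top-left aligned causal masks, so those versions need compensation.
import importlib.metadata

from packaging import version


def flash_attn_uses_top_left_mask() -> bool:
    try:
        installed = version.parse(importlib.metadata.version("flash_attn"))
    except importlib.metadata.PackageNotFoundError:
        return False  # no flash_attn installed; the flag is irrelevant in that case
    return installed < version.parse("2.1.0")
```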
diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py
index ff6cf73cef4e3f..033769bbf95379 100644
--- a/src/transformers/models/granite/modeling_granite.py
+++ b/src/transformers/models/granite/modeling_granite.py
@@ -294,7 +294,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-class GraniteFlashAttention2(GraniteAttention):
+class GraniteFlashAttention(GraniteAttention):
"""
Granite flash attention module. This module inherits from `GraniteAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -308,6 +308,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -389,6 +390,7 @@ def forward(
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -493,7 +495,8 @@ def forward(
GRANITE_ATTENTION_CLASSES = {
"eager": GraniteAttention,
- "flash_attention_2": GraniteFlashAttention2,
+ "flash_attention_2": GraniteFlashAttention,
+ "flash_attention_3": GraniteFlashAttention,
"sdpa": GraniteSdpaAttention,
}
@@ -609,6 +612,7 @@ class GranitePreTrainedModel(PreTrainedModel):
_no_split_modules = ["GraniteDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -886,7 +890,10 @@ def _update_causal_mask(
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py
index ebb74176094a05..648d9690b84bc5 100644
--- a/src/transformers/models/granitemoe/modeling_granitemoe.py
+++ b/src/transformers/models/granitemoe/modeling_granitemoe.py
@@ -510,8 +510,8 @@ def forward(
return attn_output, attn_weights, past_key_value
-# Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention2 with Granite->GraniteMoe
-class GraniteMoeFlashAttention2(GraniteMoeAttention):
+# Copied from transformers.models.granite.modeling_granite.GraniteFlashAttention with Granite->GraniteMoe
+class GraniteMoeFlashAttention(GraniteMoeAttention):
"""
GraniteMoe flash attention module. This module inherits from `GraniteMoeAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -525,6 +525,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -606,6 +607,7 @@ def forward(
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -711,7 +713,8 @@ def forward(
GRANITEMOE_ATTENTION_CLASSES = {
"eager": GraniteMoeAttention,
- "flash_attention_2": GraniteMoeFlashAttention2,
+ "flash_attention_2": GraniteMoeFlashAttention,
+ "flash_attention_3": GraniteMoeFlashAttention,
"sdpa": GraniteMoeSdpaAttention,
}
@@ -835,6 +838,7 @@ class GraniteMoePreTrainedModel(PreTrainedModel):
_no_split_modules = ["GraniteMoeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -1119,7 +1123,10 @@ def _update_causal_mask(
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py
index 08760a3f4ab238..a8f790c27e80a5 100755
--- a/src/transformers/models/hubert/modeling_hubert.py
+++ b/src/transformers/models/hubert/modeling_hubert.py
@@ -550,15 +550,15 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Hubert
-class HubertFlashAttention2(HubertAttention):
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->Hubert
+class HubertFlashAttention(HubertAttention):
"""
Hubert flash attention module. This module inherits from `HubertAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -566,6 +566,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
@@ -667,6 +668,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -788,7 +790,8 @@ def forward(
HUBERT_ATTENTION_CLASSES = {
"eager": HubertAttention,
"sdpa": HubertSdpaAttention,
- "flash_attention_2": HubertFlashAttention2,
+ "flash_attention_2": HubertFlashAttention,
+ "flash_attention_3": HubertFlashAttention,
}
@@ -936,6 +939,7 @@ def __init__(self, config):
self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -952,7 +956,7 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1024,6 +1028,7 @@ def __init__(self, config):
)
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -1040,7 +1045,7 @@ def forward(
# make sure padded tokens are not attended to
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1112,6 +1117,7 @@ class HubertPreTrainedModel(PreTrainedModel):
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
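The Hubert encoder changes follow the encoder-side convention: padded time steps are zeroed in the hidden states, and with a flash backend the 2D mask is then passed through only when the batch actually contains padding. An illustrative sketch (simplified; the real code assigns zeros via boolean indexing rather than multiplying):

```py
# Encoder-side sketch: zero out padded time steps, then keep the 2D mask only when
# padding is present so the flash kernels can handle it directly.
import torch


def prepare_encoder_inputs(hidden_states: torch.Tensor, attention_mask: torch.Tensor):
    # Zero the features of padded time steps so they contribute nothing downstream.
    expand_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
    hidden_states = hidden_states * expand_mask
    # Flash-attention path: pass the raw 2D mask only if it actually contains padding.
    flash_mask = attention_mask if 0 in attention_mask else None
    return hidden_states, flash_mask


h = torch.randn(2, 5, 8)
m = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])
h, flash_mask = prepare_encoder_inputs(h, m)
print(flash_mask is None)  # False: the batch contains padding
```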
diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py
index c43ba3e9a6b74a..6f08361ee7ebe2 100644
--- a/src/transformers/models/idefics/modeling_idefics.py
+++ b/src/transformers/models/idefics/modeling_idefics.py
@@ -1425,7 +1425,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py
index 5339b706924d8f..8ee1c0f7d03bb8 100644
--- a/src/transformers/models/idefics/vision.py
+++ b/src/transformers/models/idefics/vision.py
@@ -164,7 +164,7 @@ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: boo
class IdeficsVisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config):
+ def __init__(self, config: IdeficsVisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py
index 056811138155f3..df7298ea4615b8 100644
--- a/src/transformers/models/idefics2/modeling_idefics2.py
+++ b/src/transformers/models/idefics2/modeling_idefics2.py
@@ -191,7 +191,7 @@ class Idefics2VisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
# Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
- def __init__(self, config):
+ def __init__(self, config: Idefics2VisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -266,14 +266,14 @@ def forward(
return attn_output, attn_weights
-class Idefics2VisionFlashAttention2(Idefics2VisionAttention):
+class Idefics2VisionFlashAttention(Idefics2VisionAttention):
"""
Idefics2Vision flash attention module. This module inherits from `Idefics2VisionAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -281,6 +281,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -353,6 +354,7 @@ def forward(
dropout=dropout_rate,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
@@ -366,7 +368,8 @@ def forward(
IDEFICS_VISION_ATTENTION_CLASSES = {
"eager": Idefics2VisionAttention,
- "flash_attention_2": Idefics2VisionFlashAttention2,
+ "flash_attention_2": Idefics2VisionFlashAttention,
+ "flash_attention_3": Idefics2VisionFlashAttention,
}
@@ -583,6 +586,7 @@ def __init__(self, config: Idefics2VisionConfig):
self.encoder = Idefics2Encoder(config)
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def get_input_embeddings(self):
return self.embeddings
@@ -624,7 +628,7 @@ def forward(
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
if not torch.any(~patch_attention_mask):
patch_attention_mask = None
- elif not self._use_flash_attention_2:
+ elif not self._use_flash_attention_2 and not self._use_flash_attention_3:
patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
@@ -683,7 +687,7 @@ def extra_repr(self):
class Idefics2PerceiverAttention(nn.Module):
- def __init__(self, config, layer_idx: Optional[int] = None) -> None:
+ def __init__(self, config: Idefics2Config, layer_idx: Optional[int] = None) -> None:
"""Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`"""
super().__init__()
@@ -783,15 +787,15 @@ def forward(
return attn_output, attn_weights, past_key_value
-# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2
-class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention):
+# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2
+class Idefics2PerceiverFlashAttention(Idefics2PerceiverAttention):
"""
Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -799,6 +803,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
# Ignore copy
def forward(
@@ -899,6 +904,7 @@ def forward(
sliding_window=None,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
@@ -912,7 +918,8 @@ def forward(
IDEFICS2_PERCEIVER_ATTENTION_CLASSES = {
"eager": Idefics2PerceiverAttention,
- "flash_attention_2": Idefics2PerceiverFlashAttention2,
+ "flash_attention_2": Idefics2PerceiverFlashAttention,
+ "flash_attention_3": Idefics2PerceiverFlashAttention,
}
@@ -1011,6 +1018,7 @@ def __init__(self, config) -> None:
self.norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -1026,7 +1034,7 @@ def forward(
attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1)
attention_mask = (
_prepare_4d_attention_mask(attention_mask, latents.dtype, tgt_len=self.n_latents)
- if not self._use_flash_attention_2
+ if not self._use_flash_attention_2 and not self._use_flash_attention_3
else attention_mask
)
@@ -1094,6 +1102,7 @@ class Idefics2PreTrainedModel(PreTrainedModel):
_no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
def _init_weights(self, module):
@@ -1228,6 +1237,7 @@ def __init__(self, config: Idefics2Config):
self.image_token_id = self.config.image_token_id
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self.post_init()
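The Idefics2 vision hunk above makes a three-way decision about `patch_attention_mask`: drop it when no patch is padded, keep it 2D for flash backends, or expand it to a 4D additive mask otherwise. A sketch of that decision, where `expand_to_4d` is a stand-in for `_prepare_4d_attention_mask`:

```py
# Sketch of the patch-mask decision (assumed shapes; expand_to_4d is a stand-in helper).
import torch


def expand_to_4d(mask_2d: torch.Tensor, dtype=torch.float32) -> torch.Tensor:
    # Broadcastable additive bias: 0 where attended, a large negative number where masked.
    bias = torch.zeros(mask_2d.shape, dtype=dtype)[:, None, None, :]
    return bias.masked_fill(~mask_2d[:, None, None, :], torch.finfo(dtype).min)


def resolve_patch_mask(patch_attention_mask: torch.Tensor, use_flash_attention: bool):
    if not torch.any(~patch_attention_mask):
        return None  # every patch is real: attend to the full sequence
    if use_flash_attention:
        return patch_attention_mask  # flash kernels take the 2D boolean mask directly
    return expand_to_4d(patch_attention_mask)


mask = torch.tensor([[True, True, False]])
print(resolve_patch_mask(mask, use_flash_attention=True).shape)   # torch.Size([1, 3])
print(resolve_patch_mask(mask, use_flash_attention=False).shape)  # torch.Size([1, 1, 1, 3])
```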
diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py
index bd64e5db681b71..9ea1a172f3b089 100644
--- a/src/transformers/models/idefics3/modeling_idefics3.py
+++ b/src/transformers/models/idefics3/modeling_idefics3.py
@@ -264,15 +264,15 @@ def forward(
return attn_output, attn_weights
-# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionFlashAttention2 with Idefics2->Idefics3
-class Idefics3VisionFlashAttention2(Idefics3VisionAttention):
+# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionFlashAttention with Idefics2->Idefics3
+class Idefics3VisionFlashAttention(Idefics3VisionAttention):
"""
Idefics3Vision flash attention module. This module inherits from `Idefics3VisionAttention` as the weights of the module stay
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -280,6 +280,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -352,6 +353,7 @@ def forward(
dropout=dropout_rate,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous()
@@ -365,7 +367,8 @@ def forward(
IDEFICS_VISION_ATTENTION_CLASSES = {
"eager": Idefics3VisionAttention,
- "flash_attention_2": Idefics3VisionFlashAttention2,
+ "flash_attention_2": Idefics3VisionFlashAttention,
+ "flash_attention_3": Idefics3VisionFlashAttention,
}
@@ -620,6 +623,7 @@ class Idefics3PreTrainedModel(PreTrainedModel):
_no_split_modules = ["Idefics3VisionAttention", "Idefics3DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2PreTrainedModel._init_weights
@@ -676,6 +680,7 @@ def __init__(self, config: Idefics3VisionConfig):
self.patch_size = config.patch_size
self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.get_input_embeddings
def get_input_embeddings(self):
@@ -719,7 +724,7 @@ def forward(
# avoiding passing the attention_mask, which is equivalent to attending to the full sequence
if not torch.any(~patch_attention_mask):
patch_attention_mask = None
- elif not self._use_flash_attention_2:
+ elif not self._use_flash_attention_2 and not self._use_flash_attention_3:
patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
@@ -835,6 +840,7 @@ def __init__(self, config: Idefics3Config):
self.image_token_id = self.config.image_token_id
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self.post_init()
diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py
index 07f84b362eee7a..804a1009456c1e 100755
--- a/src/transformers/models/jamba/modeling_jamba.py
+++ b/src/transformers/models/jamba/modeling_jamba.py
@@ -378,15 +378,15 @@ def forward(
return attn_output, attn_weights, past_key_value
-# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba
-class JambaFlashAttention2(JambaAttention):
+# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention with Mistral->Jamba
+class JambaFlashAttention(JambaAttention):
"""
Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -394,6 +394,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -491,6 +492,7 @@ def forward(
sliding_window=getattr(self.config, "sliding_window", None),
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -587,7 +589,8 @@ def forward(
JAMBA_ATTENTION_CLASSES = {
"eager": JambaAttention,
- "flash_attention_2": JambaFlashAttention2,
+ "flash_attention_2": JambaFlashAttention,
+ "flash_attention_3": JambaFlashAttention,
"sdpa": JambaSdpaAttention,
}
@@ -1125,6 +1128,7 @@ class JambaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache
_is_stateful = True
@@ -1381,7 +1385,10 @@ def forward(
)
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
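
The Jamba hunks above show the pattern that repeats across the remaining model files: the flash attention module records once in `__init__` whether the config requested `"flash_attention_3"`, then threads that flag into the shared kernel wrapper as `use_flash_attn_3`. A minimal runnable sketch of that wiring follows; `DummyConfig`, `DummyFlashAttention`, and the `flash_attention_forward` stub are illustrative stand-ins (the stub falls back to plain SDPA so the sketch runs anywhere), not the transformers helpers.

```py
from dataclasses import dataclass

import torch
from torch import nn


@dataclass
class DummyConfig:
    hidden_size: int = 64
    num_attention_heads: int = 4
    _attn_implementation: str = "flash_attention_3"


def flash_attention_forward(q, k, v, use_flash_attn_3=False):
    # Stand-in for the shared kernel wrapper: this is where FA3 vs FA2 would be chosen.
    # Plain SDPA is used here so the sketch has no flash-attn dependency.
    return torch.nn.functional.scaled_dot_product_attention(q, k, v)


class DummyFlashAttention(nn.Module):
    def __init__(self, config: DummyConfig):
        super().__init__()
        self.config = config
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.qkv = nn.Linear(config.hidden_size, 3 * config.hidden_size)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size)
        # Decided once at construction time, exactly like the added `self._flash_attn_3` lines.
        self._flash_attn_3 = config._attn_implementation == "flash_attention_3"

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        bsz, q_len, _ = hidden_states.shape
        q, k, v = self.qkv(hidden_states).chunk(3, dim=-1)
        shape = (bsz, q_len, self.config.num_attention_heads, self.head_dim)
        q, k, v = (t.view(shape).transpose(1, 2) for t in (q, k, v))
        # The flag is passed through as a keyword, mirroring `use_flash_attn_3=...` in the diff.
        attn = flash_attention_forward(q, k, v, use_flash_attn_3=self._flash_attn_3)
        attn = attn.transpose(1, 2).reshape(bsz, q_len, -1)
        return self.o_proj(attn)


if __name__ == "__main__":
    module = DummyFlashAttention(DummyConfig())
    print(module(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])
```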
diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py
index 8b39183b8fc6a6..6f0ec0136bcc8c 100644
--- a/src/transformers/models/jetmoe/modeling_jetmoe.py
+++ b/src/transformers/models/jetmoe/modeling_jetmoe.py
@@ -48,6 +48,7 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "jetmoe"
@@ -641,8 +642,8 @@ def forward(
return attn_output, None, past_key_value, router_logits
-class JetMoeFlashAttention2(JetMoeAttention):
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+class JetMoeFlashAttention(JetMoeAttention):
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -650,6 +651,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -744,6 +746,7 @@ def forward(
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
).to(input_dtype)
# output projection
@@ -759,7 +762,8 @@ def forward(
JETMOE_ATTENTION_CLASSES = {
"eager": JetMoeAttention,
- "flash_attention_2": JetMoeFlashAttention2,
+ "flash_attention_2": JetMoeFlashAttention,
+ "flash_attention_3": JetMoeFlashAttention,
"sdpa": JetMoeSdpaAttention,
}
@@ -832,6 +836,7 @@ class JetMoePreTrainedModel(PreTrainedModel):
_no_split_modules = ["JetMoeBlock"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
@@ -1008,7 +1013,11 @@ def forward(
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
- if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache:
+ if (
+ attention_mask is not None
+ and (self._attn_implementation == "flash_attention_2" or self._attn_implementation == "flash_attention_3")
+ and use_cache
+ ):
batch_size = inputs_embeds.shape[0]
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
if is_padding_right:
@@ -1096,7 +1105,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
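
JetMoe also extends the padding guard for cached generation: flash attention kernels (2 and 3 alike) expect left-padded batches, so a right-padded batch is rejected when `use_cache` is on. The check below is a small sketch of that condition; the error message is illustrative rather than the library's exact wording.

```py
import torch


def check_no_right_padding(attention_mask: torch.Tensor) -> None:
    """Reject right-padded batches when a flash-attention path is used with a KV cache.

    If the last column of the 2D padding mask is not all ones, at least one
    sequence in the batch is right-padded.
    """
    batch_size = attention_mask.shape[0]
    is_padding_right = attention_mask[:, -1].sum().item() != batch_size
    if is_padding_right:
        raise ValueError(
            "Right padding detected; use left padding (padding_side='left') with flash attention."
        )


# A left-padded batch passes, a right-padded batch raises.
left_padded = torch.tensor([[0, 1, 1, 1], [1, 1, 1, 1]])
right_padded = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
check_no_right_padding(left_padded)
try:
    check_no_right_padding(right_padded)
except ValueError as err:
    print(err)
```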
diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py
index 5adc48a3a2ef59..68d6d8e2ac978d 100644
--- a/src/transformers/models/kosmos2/modeling_kosmos2.py
+++ b/src/transformers/models/kosmos2/modeling_kosmos2.py
@@ -470,7 +470,7 @@ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=Fals
class Kosmos2VisionAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config):
+ def __init__(self, config: Kosmos2VisionConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index 99edee6a92a838..8dff4225a70c81 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -397,7 +397,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-class LlamaFlashAttention2(LlamaAttention):
+class LlamaFlashAttention(LlamaAttention):
"""
Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -411,6 +411,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -506,6 +507,7 @@ def forward(
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -618,7 +620,8 @@ def forward(
LLAMA_ATTENTION_CLASSES = {
"eager": LlamaAttention,
- "flash_attention_2": LlamaFlashAttention2,
+ "flash_attention_2": LlamaFlashAttention,
+ "flash_attention_3": LlamaFlashAttention,
"sdpa": LlamaSdpaAttention,
}
@@ -731,6 +734,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlamaDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -990,7 +994,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
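
As in the other decoder-only models, `LLAMA_ATTENTION_CLASSES` maps both `"flash_attention_2"` and `"flash_attention_3"` to the single renamed `LlamaFlashAttention` class; the two implementations differ only in the flag the module reads from the config. The toy dispatch below mirrors that shape; `TinyConfig`, `TinyEagerAttention`, and `TinyFlashAttention` are made-up stand-ins.

```py
from dataclasses import dataclass

from torch import nn


@dataclass
class TinyConfig:
    _attn_implementation: str = "flash_attention_3"


class TinyEagerAttention(nn.Module):
    def __init__(self, config: TinyConfig):
        super().__init__()
        self.config = config


class TinyFlashAttention(TinyEagerAttention):
    def __init__(self, config: TinyConfig):
        super().__init__(config)
        # One class serves both dictionary entries; FA3 is just a flag read from the config.
        self._flash_attn_3 = config._attn_implementation == "flash_attention_3"


TINY_ATTENTION_CLASSES = {
    "eager": TinyEagerAttention,
    "flash_attention_2": TinyFlashAttention,
    "flash_attention_3": TinyFlashAttention,
}

config = TinyConfig()
attn = TINY_ATTENTION_CLASSES[config._attn_implementation](config)
print(type(attn).__name__, attn._flash_attn_3)  # TinyFlashAttention True
```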
diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py
index 9cb4d1f5a9aadb..6158fe87f05438 100644
--- a/src/transformers/models/llava/modeling_llava.py
+++ b/src/transformers/models/llava/modeling_llava.py
@@ -126,6 +126,7 @@ class LlavaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
def _init_weights(self, module):
diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py
index 6ece93b6f7a860..f5dd454f2c9484 100644
--- a/src/transformers/models/llava_next/modeling_llava_next.py
+++ b/src/transformers/models/llava_next/modeling_llava_next.py
@@ -235,6 +235,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlavaNextVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
def _init_weights(self, module):
diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
index 95a69826f6a02e..47307d5afa1331 100644
--- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py
+++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py
@@ -280,6 +280,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlavaNextVideoVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
def _init_weights(self, module):
diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
index c378ff09f1e4ad..a30dbd694d7ddc 100644
--- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py
+++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py
@@ -241,6 +241,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel):
_no_split_modules = ["LlavaOnevisionVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
_supports_static_cache = False # Qwen2 doesn't but llava has no reasons to not support
_supports_quantized_cache = True
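
The Llava family changes are limited to the `_supports_flash_attn_3 = True` class attribute, the per-architecture opt-in that loading code can consult before honoring `attn_implementation="flash_attention_3"`. The gate below is only a sketch of that idea, with hypothetical class and function names, not the actual `PreTrainedModel` validation logic.

```py
class TinyPreTrainedModel:
    # Per-architecture capability flags, mirroring the class attributes added in the diff.
    _supports_flash_attn_2 = False
    _supports_flash_attn_3 = False


class TinyLlavaModel(TinyPreTrainedModel):
    _supports_flash_attn_2 = True
    _supports_flash_attn_3 = True


def check_attn_implementation(model_cls, attn_implementation: str) -> str:
    """Illustrative gate: accept the requested implementation only if the class opts in."""
    supported = {
        "eager": True,
        "flash_attention_2": model_cls._supports_flash_attn_2,
        "flash_attention_3": model_cls._supports_flash_attn_3,
    }
    if not supported.get(attn_implementation, False):
        raise ValueError(f"{model_cls.__name__} does not support {attn_implementation}")
    return attn_implementation


print(check_attn_implementation(TinyLlavaModel, "flash_attention_3"))  # flash_attention_3
```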
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index 9856eec0c229fa..3bc8eb2219760c 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -339,15 +339,15 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->M2M100
-class M2M100FlashAttention2(M2M100Attention):
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->M2M100
+class M2M100FlashAttention(M2M100Attention):
"""
M2M100 flash attention module. This module inherits from `M2M100Attention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -355,6 +355,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
@@ -456,6 +457,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -647,8 +649,9 @@ def forward(
M2M100_ATTENTION_CLASSES = {
"eager": M2M100Attention,
- "flash_attention_2": M2M100FlashAttention2,
"sdpa": M2M100SdpaAttention,
+ "flash_attention_2": M2M100FlashAttention,
+ "flash_attention_3": M2M100FlashAttention,
}
@@ -779,6 +782,7 @@ class M2M100PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
@@ -953,6 +957,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] =
self.layers = nn.ModuleList([M2M100EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layer_norm = nn.LayerNorm(config.d_model)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
self.gradient_checkpointing = False
@@ -1034,7 +1039,7 @@ def forward(
# expand attention_mask
if attention_mask is not None:
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
attention_mask = attention_mask if 0 in attention_mask else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
@@ -1135,6 +1140,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] =
)
self.layers = nn.ModuleList([M2M100DecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layer_norm = nn.LayerNorm(config.d_model)
@@ -1246,7 +1252,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
combined_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
@@ -1266,7 +1272,7 @@ def forward(
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
@@ -1406,7 +1412,7 @@ def __init__(self, config: M2M100Config):
self.encoder = M2M100Encoder(config, self.shared)
self.decoder = M2M100Decoder(config, self.shared)
- if config._attn_implementation == "flash_attention_2":
+ if config._attn_implementation == "flash_attention_2" or config._attn_implementation == "flash_attention_3":
logger.warning_once(
"Attention with Flash Attention 2 does not support `layer_head_mask`. If you need this feature, please use standard attention."
)
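
For encoder-decoder models such as M2M100, the mask handling changes all have the same shape: when either flash attention implementation is active, the 2D padding mask is passed through untouched (or dropped entirely if it contains no zeros), because the kernel unpads internally, while the eager and SDPA paths still expand it to a 4D additive bias. A self-contained sketch of that selection follows, where `expand_to_4d` is a simplified stand-in for `_prepare_4d_attention_mask`.

```py
import torch


def expand_to_4d(mask_2d: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Stand-in for `_prepare_4d_attention_mask`: 1 -> 0.0, 0 -> large negative bias,
    # broadcastable to [batch, heads, query_len, key_len].
    inverted = 1.0 - mask_2d[:, None, None, :].to(dtype)
    return inverted * torch.finfo(dtype).min


def select_attention_mask(mask_2d, attn_implementation: str, dtype=torch.float32):
    if attn_implementation in ("flash_attention_2", "flash_attention_3"):
        # Flash attention consumes the 2D mask directly, and only if padding exists.
        return mask_2d if (mask_2d is not None and 0 in mask_2d) else None
    if mask_2d is None:
        return None
    return expand_to_4d(mask_2d, dtype)


padded = torch.tensor([[1, 1, 0], [1, 1, 1]])
print(select_attention_mask(padded, "flash_attention_3").shape)  # torch.Size([2, 3]), 2D passthrough
print(select_attention_mask(padded, "eager").shape)              # torch.Size([2, 1, 1, 3]), 4D bias
```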
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index ebb325073f9378..cc4eae5da5bbf4 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -283,15 +283,15 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MBart
-class MBartFlashAttention2(MBartAttention):
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->MBart
+class MBartFlashAttention(MBartAttention):
"""
MBart flash attention module. This module inherits from `MBartAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -299,6 +299,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
@@ -400,6 +401,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -521,7 +523,8 @@ def forward(
MBART_ATTENTION_CLASSES = {
"eager": MBartAttention,
"sdpa": MBartSdpaAttention,
- "flash_attention_2": MBartFlashAttention2,
+ "flash_attention_2": MBartFlashAttention,
+ "flash_attention_3": MBartFlashAttention,
}
@@ -746,6 +749,7 @@ class MBartPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MBartDecoderLayer", "MBartAttention"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
@@ -1044,7 +1048,10 @@ def forward(
# expand attention_mask
if attention_mask is not None:
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
attention_mask = attention_mask if 0 in attention_mask else None
elif self.config._attn_implementation == "sdpa" and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
@@ -1261,7 +1268,10 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self.config._attn_implementation == "sdpa" and not output_attentions and cross_attn_head_mask is None:
@@ -1281,7 +1291,10 @@ def forward(
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self.config._attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py
index 2a9faa29f0d358..38a1e0dcd6b7bf 100644
--- a/src/transformers/models/mimi/modeling_mimi.py
+++ b/src/transformers/models/mimi/modeling_mimi.py
@@ -42,6 +42,7 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
logger = logging.get_logger(__name__)
@@ -559,8 +560,8 @@ def forward(
return attn_output, attn_weights, past_key_value
-# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi
-class MimiFlashAttention2(MimiAttention):
+# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention with Gemma->Mimi
+class MimiFlashAttention(MimiAttention):
"""
Mimi flash attention module. This module inherits from `MimiAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -574,6 +575,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -659,6 +661,7 @@ def forward(
sliding_window=getattr(self, "sliding_window", None),
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -761,7 +764,8 @@ def forward(
MIMI_ATTENTION_CLASSES = {
"eager": MimiAttention,
- "flash_attention_2": MimiFlashAttention2,
+ "flash_attention_2": MimiFlashAttention,
+ "flash_attention_3": MimiFlashAttention,
"sdpa": MimiSdpaAttention,
}
@@ -1036,7 +1040,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
@@ -1371,6 +1378,7 @@ class MimiPreTrainedModel(PreTrainedModel):
_no_split_modules = ["MimiDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index ffa1a18307e982..2d980aee476193 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -53,6 +53,7 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "MistralConfig"
@@ -268,14 +269,14 @@ def forward(
return attn_output, attn_weights, past_key_value
-class MistralFlashAttention2(MistralAttention):
+class MistralFlashAttention(MistralAttention):
"""
Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -283,6 +284,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -392,6 +394,7 @@ def forward(
sliding_window=getattr(self.config, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()
@@ -495,7 +498,8 @@ def forward(
MISTRAL_ATTENTION_CLASSES = {
"eager": MistralAttention,
- "flash_attention_2": MistralFlashAttention2,
+ "flash_attention_2": MistralFlashAttention,
+ "flash_attention_3": MistralFlashAttention,
"sdpa": MistralSdpaAttention,
}
@@ -605,6 +609,7 @@ class MistralPreTrainedModel(PreTrainedModel):
_no_split_modules = ["MistralDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
@@ -859,7 +864,7 @@ def _update_causal_mask(
use_cache: bool,
output_attentions: bool,
):
- if self._attn_implementation == "flash_attention_2":
+ if self._attn_implementation == "flash_attention_2" or self._attn_implementation == "flash_attention_3":
if attention_mask is not None and use_cache:
is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
if is_padding_right:
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index e87054cd70f58b..f3fa10cc5e1c1a 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -55,6 +55,7 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
@@ -383,9 +384,9 @@ def forward(
return attn_output, attn_weights, past_key_value
-# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral
+# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention with Mistral->Mixtral
# TODO @longjie no longer copied from Mistral after static cache
-class MixtralFlashAttention2(MixtralAttention):
+class MixtralFlashAttention(MixtralAttention):
"""
Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -503,6 +504,7 @@ def forward(
dropout=dropout_rate,
sliding_window=getattr(self.config, "sliding_window", None),
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -608,7 +610,8 @@ def forward(
MIXTRAL_ATTENTION_CLASSES = {
"eager": MixtralAttention,
- "flash_attention_2": MixtralFlashAttention2,
+ "flash_attention_2": MixtralFlashAttention,
+ "flash_attention_3": MixtralFlashAttention,
"sdpa": MixtralSdpaAttention,
}
@@ -810,6 +813,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
_no_split_modules = ["MixtralDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
@@ -1080,7 +1084,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
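
Mixtral's `_update_causal_mask` gets the same early return seen in Jamba, Llama, and the other decoder-only models: with flash attention 2 or 3 the causal structure is handled inside the kernel, so the method returns the raw 2D mask only when it actually contains padding, and `None` otherwise. The function below sketches just that branch; the non-flash path that builds a 4D causal mask is omitted.

```py
import torch


def update_causal_mask_for_flash(attention_mask, attn_implementation: str):
    """Sketch of the early-return branch added in the diff for flash attention 2/3."""
    if attn_implementation in ("flash_attention_2", "flash_attention_3"):
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
        return None
    # Eager/SDPA paths would go on to build a 4D causal mask here (omitted in this sketch).
    raise NotImplementedError


no_padding = torch.ones(2, 6)
with_padding = torch.tensor([[1.0, 1.0, 1.0, 0.0], [1.0, 1.0, 1.0, 1.0]])
print(update_causal_mask_for_flash(no_padding, "flash_attention_3"))          # None
print(update_causal_mask_for_flash(with_padding, "flash_attention_3").shape)  # torch.Size([2, 4])
```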
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index c9f3b88c68d87a..87ffdb9516457d 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -64,6 +64,7 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
if TYPE_CHECKING:
from ...generation.streamers import BaseStreamer
@@ -316,15 +317,15 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Musicgen
-class MusicgenFlashAttention2(MusicgenAttention):
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->Musicgen
+class MusicgenFlashAttention(MusicgenAttention):
"""
Musicgen flash attention module. This module inherits from `MusicgenAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -332,6 +333,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
@@ -433,6 +435,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -570,7 +573,8 @@ def forward(
MUSICGEN_ATTENTION_CLASSES = {
"eager": MusicgenAttention,
"sdpa": MusicgenSdpaAttention,
- "flash_attention_2": MusicgenFlashAttention2,
+ "flash_attention_2": MusicgenFlashAttention,
+ "flash_attention_3": MusicgenFlashAttention,
}
@@ -708,6 +712,7 @@ class MusicgenPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
@@ -1003,7 +1008,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
- if self.attn_implementation == "flash_attention_2":
+ if self.attn_implementation == "flash_attention_2" or self.attn_implementation == "flash_attention_3":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
@@ -1021,7 +1026,7 @@ def forward(
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self.attn_implementation == "flash_attention_2":
+ if self.attn_implementation == "flash_attention_2" or self.attn_implementation == "flash_attention_3":
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
@@ -1669,6 +1674,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin):
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def __init__(
diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
index 15cad4072dddab..d70114ee828400 100644
--- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -56,6 +56,7 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
if TYPE_CHECKING:
from ...generation.streamers import BaseStreamer
@@ -332,15 +333,15 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MusicgenMelody
-class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention):
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->MusicgenMelody
+class MusicgenMelodyFlashAttention(MusicgenMelodyAttention):
"""
MusicgenMelody flash attention module. This module inherits from `MusicgenMelodyAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -348,6 +349,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
@@ -449,6 +451,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -570,7 +573,8 @@ def forward(
MUSICGEN_MELODY_ATTENTION_CLASSES = {
"eager": MusicgenMelodyAttention,
"sdpa": MusicgenMelodySdpaAttention,
- "flash_attention_2": MusicgenMelodyFlashAttention2,
+ "flash_attention_2": MusicgenMelodyFlashAttention,
+ "flash_attention_3": MusicgenMelodyFlashAttention,
}
@@ -667,6 +671,7 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
@@ -950,7 +955,7 @@ def forward(
input_shape = inputs_embeds.size()[:-1]
- if self.attn_implementation == "flash_attention_2":
+ if self.attn_implementation == "flash_attention_2" or self.attn_implementation == "flash_attention_3":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self.attn_implementation == "sdpa" and not output_attentions:
# output_attentions=True can not be supported when using SDPA, and we fall back on
@@ -1595,6 +1600,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def __init__(
diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py
index 9411f0bcae5a50..d20d96c0ea82a2 100644
--- a/src/transformers/models/nemotron/modeling_nemotron.py
+++ b/src/transformers/models/nemotron/modeling_nemotron.py
@@ -301,8 +301,8 @@ def forward(
return attn_output, attn_weights, past_key_value
-# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
-class NemotronFlashAttention2(NemotronAttention):
+# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron
+class NemotronFlashAttention(NemotronAttention):
"""
Nemotron flash attention module. This module inherits from `NemotronAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -316,6 +316,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
# Ignore copy
def forward(
@@ -404,6 +405,7 @@ def forward(
sliding_window=getattr(self, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
@@ -509,7 +511,8 @@ def forward(
NEMOTRON_ATTENTION_CLASSES = {
"eager": NemotronAttention,
- "flash_attention_2": NemotronFlashAttention2,
+ "flash_attention_2": NemotronFlashAttention,
+ "flash_attention_3": NemotronFlashAttention,
"sdpa": NemotronSdpaAttention,
}
@@ -624,6 +627,7 @@ class NemotronPreTrainedModel(PreTrainedModel):
_no_split_modules = ["NemotronDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -867,7 +871,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 668722fc9e3f86..35ddf5e63b244c 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -335,14 +335,14 @@ def forward(
return attn_output, attn_weights, past_key_value
-class OlmoFlashAttention2(OlmoAttention):
+class OlmoFlashAttention(OlmoAttention):
"""
OLMo flash attention module. This module inherits from `OlmoAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -350,6 +350,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -434,6 +435,7 @@ def forward(
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -540,7 +542,8 @@ def forward(
OLMO_ATTENTION_CLASSES = {
"eager": OlmoAttention,
- "flash_attention_2": OlmoFlashAttention2,
+ "flash_attention_2": OlmoFlashAttention,
+ "flash_attention_3": OlmoFlashAttention,
"sdpa": OlmoSdpaAttention,
}
@@ -651,6 +654,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
_no_split_modules = ["OlmoDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -909,7 +913,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py
index 875317732ff06b..1c163d8e7089ff 100644
--- a/src/transformers/models/olmoe/modeling_olmoe.py
+++ b/src/transformers/models/olmoe/modeling_olmoe.py
@@ -416,14 +416,14 @@ def forward(
return attn_output, attn_weights, past_key_value
-class OlmoeFlashAttention2(OlmoeAttention):
+class OlmoeFlashAttention(OlmoeAttention):
"""
OLMoE flash attention module. This module inherits from `OlmoeAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -431,6 +431,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -514,6 +515,7 @@ def forward(
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -622,7 +624,8 @@ def forward(
OLMOE_ATTENTION_CLASSES = {
"eager": OlmoeAttention,
- "flash_attention_2": OlmoeFlashAttention2,
+ "flash_attention_2": OlmoeFlashAttention,
+ "flash_attention_3": OlmoeFlashAttention,
"sdpa": OlmoeSdpaAttention,
}
@@ -791,6 +794,7 @@ class OlmoePreTrainedModel(PreTrainedModel):
_no_split_modules = ["OlmoeDecoderLayer"]
_skip_keys_device_placement = ["past_key_values"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -1067,7 +1071,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py
index f7782b8f6172b9..9202df3e213a6c 100644
--- a/src/transformers/models/opt/modeling_opt.py
+++ b/src/transformers/models/opt/modeling_opt.py
@@ -241,14 +241,14 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-class OptFlashAttention2(OPTAttention):
+class OptFlashAttention(OPTAttention):
"""
OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched.
The only required change would be on the forward pass where it needs to correctly call the public API of flash
attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -256,6 +256,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -349,6 +350,7 @@ def forward(
dropout=attn_dropout,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim)
@@ -362,7 +364,8 @@ def forward(
OPT_ATTENTION_CLASSES = {
"eager": OPTAttention,
- "flash_attention_2": OptFlashAttention2,
+ "flash_attention_2": OptFlashAttention,
+ "flash_attention_3": OptFlashAttention,
}
@@ -489,6 +492,7 @@ class OPTPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["OPTDecoderLayer"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
def _init_weights(self, module):
std = self.config.init_std
@@ -605,6 +609,7 @@ def __init__(self, config: OPTConfig):
self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -703,7 +708,7 @@ def forward(
mask_seq_length = past_key_values_length + seq_length
# embed positions
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
attention_mask = (
diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py
index b5fddce1d6a914..a2211ea7c138f1 100644
--- a/src/transformers/models/paligemma/modeling_paligemma.py
+++ b/src/transformers/models/paligemma/modeling_paligemma.py
@@ -193,6 +193,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["PaliGemmaMultiModalProjector"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = False
+ _supports_flash_attn_3 = False
_supports_cache_class = True
_supports_quantized_cache = True
_supports_static_cache = True
diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py
index 7d40c481ac0685..11ae16efe1c557 100644
--- a/src/transformers/models/persimmon/modeling_persimmon.py
+++ b/src/transformers/models/persimmon/modeling_persimmon.py
@@ -735,7 +735,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index cb59bd0df9a1b4..608c312cb3cfac 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -382,14 +382,14 @@ def forward(
return attn_output, attn_weights, past_key_value
-class PhiFlashAttention2(PhiAttention):
+class PhiFlashAttention(PhiAttention):
"""
Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -397,6 +397,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -511,6 +512,7 @@ def forward(
softmax_scale=None,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -650,7 +652,8 @@ def forward(
PHI_ATTENTION_CLASSES = {
"eager": PhiAttention,
- "flash_attention_2": PhiFlashAttention2,
+ "flash_attention_2": PhiFlashAttention,
+ "flash_attention_3": PhiFlashAttention,
"sdpa": PhiSdpaAttention,
}
@@ -759,6 +762,7 @@ class PhiPreTrainedModel(PreTrainedModel):
_no_split_modules = ["PhiDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
@@ -877,6 +881,7 @@ def __init__(self, config: PhiConfig):
self.rotary_emb = PhiRotaryEmbedding(config=config)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
self.gradient_checkpointing = False
@@ -1027,7 +1032,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index 1c1bb34171b613..0141fcf26e0aed 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -51,6 +51,7 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct"
@@ -428,14 +429,14 @@ def forward(
return attn_output, attn_weights, past_key_value
-class Phi3FlashAttention2(Phi3Attention):
+class Phi3FlashAttention(Phi3Attention):
"""
Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -443,6 +444,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -568,6 +570,7 @@ def forward(
sliding_window=getattr(self.config, "sliding_window", None),
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -675,7 +678,8 @@ def forward(
PHI3_ATTENTION_CLASSES = {
"eager": Phi3Attention,
- "flash_attention_2": Phi3FlashAttention2,
+ "flash_attention_2": Phi3FlashAttention,
+ "flash_attention_3": Phi3FlashAttention,
"sdpa": Phi3SdpaAttention,
}
@@ -789,6 +793,7 @@ class Phi3PreTrainedModel(PreTrainedModel):
_no_split_modules = ["Phi3DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
@@ -1048,7 +1053,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index d15e079770a3ed..1757536a113bf0 100644
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -686,6 +686,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] =
)
self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
@@ -774,7 +775,7 @@ def forward(
# expand attention_mask
if attention_mask is not None:
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
attention_mask = attention_mask if 0 in attention_mask else None
elif self._use_sdpa and head_mask is None and not output_attentions:
# output_attentions=True & head_mask can not be supported when using SDPA, fall back to
@@ -871,6 +872,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] =
)
self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)])
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
@@ -991,7 +993,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input)
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None:
@@ -1011,7 +1013,7 @@ def forward(
# expand encoder attention mask
if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
# output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index 9a970a4a1b2fc6..c381eeb3e0a1d5 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -342,7 +342,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-class Qwen2FlashAttention2(Qwen2Attention):
+class Qwen2FlashAttention(Qwen2Attention):
"""
Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention`
as the weights of the module stays untouched. The only required change would be on the forward pass
@@ -351,7 +351,7 @@ class Qwen2FlashAttention2(Qwen2Attention):
config.max_window_layers layers.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -359,6 +359,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -476,6 +477,7 @@ def forward(
sliding_window=sliding_window,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -585,7 +587,8 @@ def forward(
QWEN2_ATTENTION_CLASSES = {
"eager": Qwen2Attention,
- "flash_attention_2": Qwen2FlashAttention2,
+ "flash_attention_2": Qwen2FlashAttention,
+ "flash_attention_3": Qwen2FlashAttention,
"sdpa": Qwen2SdpaAttention,
}
@@ -595,7 +598,11 @@ def __init__(self, config: Qwen2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- if config.sliding_window and config._attn_implementation != "flash_attention_2":
+ if (
+ config.sliding_window
+ and config._attn_implementation != "flash_attention_2"
+ and config._attn_implementation != "flash_attention_3"
+ ):
logger.warning_once(
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
"unexpected results may be encountered."
@@ -702,6 +709,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
_no_split_modules = ["Qwen2DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -966,7 +974,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
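
For both flash-attention backends, `_update_causal_mask` short-circuits: the 2D padding mask is returned as-is when it actually contains padding, and dropped otherwise, since the kernels build the causal structure themselves. A standalone sketch of that branch (torch-only, illustrative function name):

```py
import torch


def update_causal_mask_sketch(attn_implementation, attention_mask):
    # Flash-attention 2/3 kernels handle causality internally, so only a 2D
    # padding mask (one that contains at least one 0) needs to be forwarded.
    if attn_implementation in ("flash_attention_2", "flash_attention_3"):
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
        return None
    # Eager/SDPA paths would expand this into a 4D causal mask instead.
    raise NotImplementedError


mask = torch.tensor([[1.0, 1.0, 0.0]])
print(update_causal_mask_sketch("flash_attention_3", mask))              # mask returned as-is
print(update_causal_mask_sketch("flash_attention_3", torch.ones(1, 3)))  # None
```
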
diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
index a5ac1f83638545..18a26c4128849d 100644
--- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
+++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py
@@ -215,15 +215,15 @@ def forward(
return attn_output, attn_weights, past_key_value
-# Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention2 with Whisper->Qwen2Audio
-class Qwen2AudioFlashAttention2(Qwen2AudioAttention):
+# Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention with Whisper->Qwen2Audio
+class Qwen2AudioFlashAttention(Qwen2AudioAttention):
"""
Qwen2Audio flash attention module. This module inherits from `Qwen2AudioAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -231,6 +231,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -328,6 +329,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, tgt_len, -1)
@@ -440,7 +442,8 @@ def forward(
QWEN2AUDIO_ATTENTION_CLASSES = {
"eager": Qwen2AudioAttention,
- "flash_attention_2": Qwen2AudioFlashAttention2,
+ "flash_attention_2": Qwen2AudioFlashAttention,
+ "flash_attention_3": Qwen2AudioFlashAttention,
"sdpa": Qwen2AudioSdpaAttention,
}
@@ -544,6 +547,7 @@ class Qwen2AudioPreTrainedModel(PreTrainedModel):
_no_split_modules = ["Qwen2AudioAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
def _init_weights(self, module):
# important: this ported version of Qwen2Audio isn't meant for training from scratch - only
diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
index 2274e96245d3c4..f02c2ba2df08a6 100644
--- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
+++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py
@@ -55,6 +55,7 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
logger = logging.get_logger(__name__)
_CHECKPOINT_FOR_DOC = "Qwen/Qwen1.5-MoE-A2.7B"
@@ -429,8 +430,8 @@ def forward(
return attn_output, attn_weights, past_key_value
-# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe
-class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
+# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention with Qwen2->Qwen2Moe
+class Qwen2MoeFlashAttention(Qwen2MoeAttention):
"""
Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention`
as the weights of the module stays untouched. The only required change would be on the forward pass
@@ -439,7 +440,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention):
config.max_window_layers layers.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -447,6 +448,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -564,6 +566,7 @@ def forward(
sliding_window=sliding_window,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -674,7 +677,8 @@ def forward(
QWEN2MOE_ATTENTION_CLASSES = {
"eager": Qwen2MoeAttention,
- "flash_attention_2": Qwen2MoeFlashAttention2,
+ "flash_attention_2": Qwen2MoeFlashAttention,
+ "flash_attention_3": Qwen2MoeFlashAttention,
"sdpa": Qwen2MoeSdpaAttention,
}
@@ -867,6 +871,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel):
_no_split_modules = ["Qwen2MoeDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
@@ -1147,7 +1152,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
index 85418a134aa17e..8b6c153b0ad95d 100644
--- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
+++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
@@ -45,6 +45,7 @@
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
@@ -59,6 +60,11 @@
else:
flash_attn_varlen_func = None
+if is_flash_attn_3_available():
+ from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+
+else:
+ flash_attn_3_varlen_func = None
logger = logging.get_logger(__name__)
@@ -323,8 +329,9 @@ def forward(self, x) -> torch.Tensor:
class VisionAttention(nn.Module):
- def __init__(self, dim: int, num_heads: int = 16) -> None:
+ def __init__(self, dim: int, num_heads: int = 16, config: Optional[Qwen2VLVisionConfig] = None) -> None:
super().__init__()
+ self.config = config
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=True)
@@ -357,12 +364,10 @@ def forward(
return attn_output
-class VisionFlashAttention2(nn.Module):
- def __init__(self, dim: int, num_heads: int = 16) -> None:
- super().__init__()
- self.num_heads = num_heads
- self.qkv = nn.Linear(dim, dim * 3, bias=True)
- self.proj = nn.Linear(dim, dim)
+class VisionFlashAttention(VisionAttention):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
@@ -373,19 +378,21 @@ def forward(
k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0)
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
- attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
- seq_length, -1
- )
+ if self._flash_attn_3:
+ attn_output = flash_attn_3_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
+ seq_length, -1
+ )
+ else:
+ attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
+ seq_length, -1
+ )
attn_output = self.proj(attn_output)
return attn_output
-class VisionSdpaAttention(nn.Module):
- def __init__(self, dim: int, num_heads: int = 16) -> None:
- super().__init__()
- self.num_heads = num_heads
- self.qkv = nn.Linear(dim, dim * 3, bias=True)
- self.proj = nn.Linear(dim, dim)
+class VisionSdpaAttention(VisionAttention):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
def forward(
self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None
@@ -410,7 +417,8 @@ def forward(
QWEN2_VL_VISION_ATTENTION_CLASSES = {
"eager": VisionAttention,
- "flash_attention_2": VisionFlashAttention2,
+ "flash_attention_2": VisionFlashAttention,
+ "flash_attention_3": VisionFlashAttention,
"sdpa": VisionSdpaAttention,
}
@@ -423,7 +431,7 @@ def __init__(self, config, attn_implementation: str = "sdpa") -> None:
mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio)
self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation](
- config.embed_dim, num_heads=config.num_heads
+ config.embed_dim, num_heads=config.num_heads, config=config
)
self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act)
@@ -608,7 +616,7 @@ def forward(
return attn_output, attn_weights, past_key_value
-class Qwen2VLFlashAttention2(Qwen2VLAttention):
+class Qwen2VLFlashAttention(Qwen2VLAttention):
"""
Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention`
as the weights of the module stays untouched. The only required change would be on the forward pass
@@ -624,6 +632,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -753,6 +762,7 @@ def forward(
sliding_window=sliding_window,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -868,7 +878,8 @@ def forward(
QWEN2_VL_ATTENTION_CLASSES = {
"eager": Qwen2VLAttention,
- "flash_attention_2": Qwen2VLFlashAttention2,
+ "flash_attention_2": Qwen2VLFlashAttention,
+ "flash_attention_3": Qwen2VLFlashAttention,
"sdpa": Qwen2VLSdpaAttention,
}
@@ -878,7 +889,11 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
- if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
+ if (
+ config.use_sliding_window
+ and config._attn_implementation != "flash_attention_2"
+ and config._attn_implementation != "flash_attention_3"
+ ):
logger.warning_once(
f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
"unexpected results may be encountered."
@@ -985,6 +1000,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel):
_no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
@@ -1228,7 +1244,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
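
The vision path calls the varlen kernels directly, so it must supply `cu_seqlens` (cumulative sequence boundaries for the packed batch) and the derived `max_seqlen`. A small torch-only sketch of what those two arguments look like, independent of either flash-attention package:

```py
import torch

# Hypothetical lengths of three packed (unpadded) sequences.
seq_lens = torch.tensor([5, 3, 7], dtype=torch.int32)

# cu_seqlens marks where each sequence starts and ends in the packed tensor.
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0))

# Same reduction the vision attention module performs before calling the kernel.
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()

print(cu_seqlens.tolist(), max_seqlen)  # [0, 5, 8, 15] 7
```
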
diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index e0492948998434..07f21ce8be180c 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -541,6 +541,7 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["RecurrentGemmaDecoderLayer"]
_skip_keys_device_placement = ["cache"]
_supports_flash_attn_2 = False
+ _supports_flash_attn_3 = False
_supports_sdpa = False # we can't compare with eager for now
_supports_cache_class = True
_supports_quantized_cache = True
diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py
index 5dfe54e24ac20a..c420cc7cc6c67b 100644
--- a/src/transformers/models/sew/modeling_sew.py
+++ b/src/transformers/models/sew/modeling_sew.py
@@ -554,15 +554,15 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->SEW
-class SEWFlashAttention2(SEWAttention):
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->SEW
+class SEWFlashAttention(SEWAttention):
"""
SEW flash attention module. This module inherits from `SEWAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -570,6 +570,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
@@ -671,6 +672,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -792,7 +794,8 @@ def forward(
SEW_ATTENTION_CLASSES = {
"eager": SEWAttention,
"sdpa": SEWSdpaAttention,
- "flash_attention_2": SEWFlashAttention2,
+ "flash_attention_2": SEWFlashAttention,
+ "flash_attention_3": SEWFlashAttention,
}
@@ -869,6 +872,7 @@ def __init__(self, config):
self.upsample = SEWUpsampling(config)
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -882,7 +886,7 @@ def forward(
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# make sure padded tokens output 0
hidden_states[~attention_mask] = 0.0
# 2d mask is passed through the layers
@@ -979,6 +983,7 @@ class SEWPreTrainedModel(PreTrainedModel):
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py
index 1d35d1d44cfd97..13b5305a393c7f 100644
--- a/src/transformers/models/siglip/modeling_siglip.py
+++ b/src/transformers/models/siglip/modeling_siglip.py
@@ -355,8 +355,8 @@ def forward(
class SiglipAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__
- def __init__(self, config):
+ # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ with CLIP->Siglip
+ def __init__(self, config: SiglipConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
@@ -428,7 +428,7 @@ def forward(
return attn_output, attn_weights
-class SiglipFlashAttention2(SiglipAttention):
+class SiglipFlashAttention(SiglipAttention):
"""
SiglipAttention flash attention module. This module inherits from `SiglipAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
@@ -437,7 +437,7 @@ class SiglipFlashAttention2(SiglipAttention):
is_causal = False
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -445,8 +445,9 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
- # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward
+ # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
@@ -511,6 +512,7 @@ def forward(
dropout=dropout_rate,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous()
@@ -590,7 +592,8 @@ def forward(
SIGLIP_ATTENTION_CLASSES = {
"eager": SiglipAttention,
- "flash_attention_2": SiglipFlashAttention2,
+ "flash_attention_2": SiglipFlashAttention,
+ "flash_attention_3": SiglipFlashAttention,
"sdpa": SiglipSdpaAttention,
}
@@ -677,6 +680,7 @@ class SiglipPreTrainedModel(PreTrainedModel):
"SiglipMultiheadAttentionPoolingHead",
]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
@@ -930,6 +934,7 @@ def __init__(self, config: SiglipTextConfig):
self.head = nn.Linear(embed_dim, embed_dim)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
@add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig)
@@ -962,7 +967,7 @@ def forward(
# note: SigLIP's text model does not use a causal mask, unlike the original CLIP model.
# expand attention_mask
- if attention_mask is not None and not self._use_flash_attention_2:
+ if attention_mask is not None and not self._use_flash_attention_2 and not self._use_flash_attention_3:
# [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len]
attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
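
The SigLIP text tower illustrates the mask-handling split that the rest of the patch repeats: flash-attention paths keep the raw 2D padding mask (or drop it when nothing is padded), while eager/SDPA paths expand it into a 4D additive mask. A hedged torch-only sketch of the two shapes involved (a rough stand-in for `_prepare_4d_attention_mask`, not the actual helper):

```py
import torch

# Toy padding mask: batch of 2, sequence length 4 (1 = token, 0 = padding).
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]])

# Flash-attention 2/3: pass the 2D mask through, or None if the batch is dense.
mask_for_flash = attention_mask if 0 in attention_mask else None

# Eager/SDPA: roughly what the 4D expansion produces, i.e. an additive mask with
# a very large negative value at padded key positions, broadcast over queries.
additive = (1.0 - attention_mask[:, None, None, :].float()) * torch.finfo(torch.float32).min

print(mask_for_flash.shape, additive.shape)  # torch.Size([2, 4]) torch.Size([2, 1, 1, 4])
```
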
diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py
index 9c457d869ac5c9..24bc9e6b1d0016 100755
--- a/src/transformers/models/stablelm/modeling_stablelm.py
+++ b/src/transformers/models/stablelm/modeling_stablelm.py
@@ -506,14 +506,14 @@ def forward(
return attn_output, None, past_key_value
-class StableLmFlashAttention2(StableLmAttention):
+class StableLmFlashAttention(StableLmAttention):
"""
StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -521,6 +521,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -608,6 +609,7 @@ def forward(
dropout=dropout_rate,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
is_causal=self.is_causal,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -622,7 +624,8 @@ def forward(
ATTENTION_CLASSES = {
"eager": StableLmAttention,
"sdpa": StableLmSdpaAttention,
- "flash_attention_2": StableLmFlashAttention2,
+ "flash_attention_2": StableLmFlashAttention,
+ "flash_attention_3": StableLmFlashAttention,
}
@@ -746,6 +749,7 @@ class StableLmPreTrainedModel(PreTrainedModel):
_no_split_modules = ["StableLmDecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
_supports_sdpa = True
_supports_quantized_cache = True
@@ -1010,7 +1014,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index 89a36fefe77ace..808b76280f8e2b 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -323,14 +323,14 @@ def forward(
return attn_output, attn_weights, past_key_value
-class Starcoder2FlashAttention2(Starcoder2Attention):
+class Starcoder2FlashAttention(Starcoder2Attention):
"""
Starcoder2 flash attention module. This module inherits from `Starcoder2Attention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -338,6 +338,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
# Ignore copy
def forward(
@@ -447,6 +448,7 @@ def forward(
sliding_window=getattr(self.config, "sliding_window", None),
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
@@ -560,7 +562,8 @@ def forward(
STARCODER2_ATTENTION_CLASSES = {
"eager": Starcoder2Attention,
- "flash_attention_2": Starcoder2FlashAttention2,
+ "flash_attention_2": Starcoder2FlashAttention,
+ "flash_attention_3": Starcoder2FlashAttention,
"sdpa": Starcoder2SdpaAttention,
}
@@ -675,6 +678,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
_no_split_modules = ["Starcoder2DecoderLayer"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_quantized_cache = True
@@ -940,7 +944,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py
index d2779fc200f2df..94fd8743640b5b 100755
--- a/src/transformers/models/unispeech/modeling_unispeech.py
+++ b/src/transformers/models/unispeech/modeling_unispeech.py
@@ -586,15 +586,15 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->UniSpeech
-class UniSpeechFlashAttention2(UniSpeechAttention):
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->UniSpeech
+class UniSpeechFlashAttention(UniSpeechAttention):
"""
UniSpeech flash attention module. This module inherits from `UniSpeechAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -602,6 +602,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
@@ -703,6 +704,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -824,7 +826,8 @@ def forward(
UNISPEECH_ATTENTION_CLASSES = {
"eager": UniSpeechAttention,
"sdpa": UniSpeechSdpaAttention,
- "flash_attention_2": UniSpeechFlashAttention2,
+ "flash_attention_2": UniSpeechFlashAttention,
+ "flash_attention_3": UniSpeechFlashAttention,
}
@@ -972,6 +975,7 @@ def __init__(self, config):
self.layers = nn.ModuleList([UniSpeechEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -988,7 +992,7 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1060,6 +1064,7 @@ def __init__(self, config):
)
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -1076,7 +1081,7 @@ def forward(
# make sure padded tokens are not attended to
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1218,6 +1223,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
index 7bb98434482d35..dce30e76959411 100755
--- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -603,15 +603,15 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->UniSpeechSat
-class UniSpeechSatFlashAttention2(UniSpeechSatAttention):
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->UniSpeechSat
+class UniSpeechSatFlashAttention(UniSpeechSatAttention):
"""
UniSpeechSat flash attention module. This module inherits from `UniSpeechSatAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -619,6 +619,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
@@ -720,6 +721,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -841,7 +843,8 @@ def forward(
UNISPEECHSAT_ATTENTION_CLASSES = {
"eager": UniSpeechSatAttention,
"sdpa": UniSpeechSatSdpaAttention,
- "flash_attention_2": UniSpeechSatFlashAttention2,
+ "flash_attention_2": UniSpeechSatFlashAttention,
+ "flash_attention_3": UniSpeechSatFlashAttention,
}
@@ -989,6 +992,7 @@ def __init__(self, config):
self.layers = nn.ModuleList([UniSpeechSatEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -1005,7 +1009,7 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1077,6 +1081,7 @@ def __init__(self, config):
)
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -1093,7 +1098,7 @@ def forward(
# make sure padded tokens are not attended to
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1235,6 +1240,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py
index 97bc9f5802029a..5d7dd64843c5e6 100644
--- a/src/transformers/models/video_llava/modeling_video_llava.py
+++ b/src/transformers/models/video_llava/modeling_video_llava.py
@@ -127,6 +127,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["VideoLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
def _init_weights(self, module):
diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py
index 55ccd12367401f..384564cd19b7c2 100644
--- a/src/transformers/models/vipllava/modeling_vipllava.py
+++ b/src/transformers/models/vipllava/modeling_vipllava.py
@@ -133,6 +133,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel):
_no_split_modules = ["VipLlavaVisionAttention"]
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_cache_class = True
def _init_weights(self, module):
diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index d79936ab2b8420..e8be08be7eef43 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -64,6 +64,7 @@
if is_flash_attn_2_available():
from ...modeling_flash_attention_utils import _flash_attention_forward
+
logger = logging.get_logger(__name__)
@@ -650,15 +651,15 @@ def forward(
return attn_output, attn_weights_reshaped, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Wav2Vec2
-class Wav2Vec2FlashAttention2(Wav2Vec2Attention):
+# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->Wav2Vec2
+class Wav2Vec2FlashAttention(Wav2Vec2Attention):
"""
Wav2Vec2 flash attention module. This module inherits from `Wav2Vec2Attention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -666,6 +667,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
@@ -767,6 +769,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, q_len, -1)
@@ -888,7 +891,8 @@ def forward(
WAV2VEC2_ATTENTION_CLASSES = {
"eager": Wav2Vec2Attention,
"sdpa": Wav2Vec2SdpaAttention,
- "flash_attention_2": Wav2Vec2FlashAttention2,
+ "flash_attention_2": Wav2Vec2FlashAttention,
+ "flash_attention_3": Wav2Vec2FlashAttention,
}
@@ -1006,6 +1010,7 @@ def __init__(self, config):
self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -1022,7 +1027,7 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1093,6 +1098,7 @@ def __init__(self, config):
)
self.gradient_checkpointing = False
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -1109,7 +1115,7 @@ def forward(
# make sure padded tokens are not attended to
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
+ if self._use_flash_attention_2 or self._use_flash_attention_3:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
@@ -1331,6 +1337,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
main_input_name = "input_values"
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
def _init_weights(self, module):
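
The speech encoders (SEW, UniSpeech, UniSpeechSat, Wav2Vec2) share one more padding detail: before attention, padded frames are zeroed in the hidden states so they output 0, and only then is the 2D mask forwarded to the flash kernels. A self-contained sketch of that masking step:

```py
import torch

# Toy encoder activations: batch 2, 4 frames, hidden size 3.
hidden_states = torch.randn(2, 4, 3)
attention_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]], dtype=torch.bool)

# Broadcast the frame mask over the hidden dimension and zero padded frames,
# mirroring the `hidden_states[~expand_attention_mask] = 0` lines above.
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0

print(hidden_states[0, 2:].abs().sum().item())  # 0.0 (padded frames are silenced)
```
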
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index b10fc258c8ef45..c1b16f38656eea 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -347,14 +347,14 @@ def forward(
return attn_output, attn_weights, past_key_value
-class WhisperFlashAttention2(WhisperAttention):
+class WhisperFlashAttention(WhisperAttention):
"""
Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays
untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
flash attention and deal with padding tokens in case the input contains any of them.
"""
- # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
+ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -362,6 +362,7 @@ def __init__(self, *args, **kwargs):
# flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
+ self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3"
def forward(
self,
@@ -459,6 +460,7 @@ def forward(
dropout=self.dropout if self.training else 0.0,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
+ use_flash_attn_3=self._flash_attn_3,
)
attn_output = attn_output.reshape(bsz, tgt_len, -1)
@@ -570,7 +572,8 @@ def forward(
WHISPER_ATTENTION_CLASSES = {
"eager": WhisperAttention,
- "flash_attention_2": WhisperFlashAttention2,
+ "flash_attention_2": WhisperFlashAttention,
+ "flash_attention_3": WhisperFlashAttention,
"sdpa": WhisperSdpaAttention,
}
@@ -770,6 +773,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
_supports_flash_attn_2 = True
+ _supports_flash_attn_3 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
@@ -1113,6 +1117,7 @@ def __init__(self, config: WhisperConfig):
[WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)]
)
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+ self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
self._use_sdpa = config._attn_implementation == "sdpa"
self.layer_norm = nn.LayerNorm(config.d_model)
@@ -1375,7 +1380,10 @@ def _update_causal_mask(
past_key_values: Cache,
output_attentions: bool,
):
- if self.config._attn_implementation == "flash_attention_2":
+ if (
+ self.config._attn_implementation == "flash_attention_2"
+ or self.config._attn_implementation == "flash_attention_3"
+ ):
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py
index d05db378443e0b..343a8d5b28bfaf 100644
--- a/src/transformers/models/x_clip/modeling_x_clip.py
+++ b/src/transformers/models/x_clip/modeling_x_clip.py
@@ -220,7 +220,7 @@ def forward(
class XCLIPAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
- def __init__(self, config):
+ def __init__(self, config: XCLIPConfig):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py
index 8eda45bd40efb4..95468c4a936721 100644
--- a/src/transformers/testing_utils.py
+++ b/src/transformers/testing_utils.py
@@ -74,6 +74,7 @@
is_faiss_available,
is_fbgemm_gpu_available,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flax_available,
is_fsdp_available,
is_ftfy_available,
@@ -556,6 +557,16 @@ def require_flash_attn(test_case):
return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case)
+def require_flash_attn_3(test_case):
+ """
+ Decorator marking a test that requires Flash Attention 3.
+
+ These tests are skipped when Flash Attention 3 isn't installed.
+
+ """
+ return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case)
+
+
def require_torch_sdpa(test_case):
"""
Decorator marking a test that requires PyTorch's SDPA.
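
A hedged usage sketch of the new decorator together with the `flash_attn_3_test` pytest marker that the workflow and model tests rely on (the test class and body here are placeholders):

```py
import unittest

import pytest

from transformers.testing_utils import require_flash_attn_3, require_torch_gpu


class ExampleFlashAttention3Test(unittest.TestCase):
    @require_flash_attn_3
    @require_torch_gpu
    @pytest.mark.flash_attn_3_test
    def test_placeholder(self):
        # Runs only when flash_attn_interface is importable and a GPU is present.
        self.assertTrue(True)
```
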
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index 3b33127be4ba53..2faf3b4c56b5a0 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -135,6 +135,7 @@
is_faiss_available,
is_fbgemm_gpu_available,
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flash_attn_greater_or_equal,
is_flash_attn_greater_or_equal_2_10,
is_flax_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 519755489a3373..bec001763422be 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -62,6 +62,14 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
except ImportError:
# If the package can't be imported, it's not available
package_exists = False
+ elif pkg_name == "flash_attn_interface":
+ try:
+ package = importlib.import_module(pkg_name)
+ package_version = getattr(package, "__version__", "N/A")
+ package_exists = True
+ except ImportError:
+ # If the package can't be imported, it's not available
+ package_exists = False
else:
# For packages other than "torch", don't attempt the fallback and set as not available
package_exists = False
@@ -922,6 +930,16 @@ def is_flash_attn_greater_or_equal(library_version: str):
return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
+def is_flash_attn_3_available():
+ if not is_flash_attn_2_available():
+ return False
+
+ if not _is_package_available("flash_attn_interface"):
+ return False
+
+ return True
+
+
def is_torchdistx_available():
return _torchdistx_available
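
With `is_flash_attn_3_available` exported, callers can pick the backend at runtime instead of hard-coding the string. A usage sketch, assuming the checkpoint name is a placeholder and that falling back to SDPA is acceptable when the FA3 interface package is missing:

```py
import torch
from transformers import AutoModelForCausalLM
from transformers.utils import is_flash_attn_3_available

# "your-org/your-model" is a placeholder checkpoint, not part of the patch.
attn_impl = "flash_attention_3" if is_flash_attn_3_available() else "sdpa"

model = AutoModelForCausalLM.from_pretrained(
    "your-org/your-model",
    torch_dtype=torch.bfloat16,
    attn_implementation=attn_impl,
)
```
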
diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py
index 9bb8ef33d75998..9c6064cf55a294 100644
--- a/tests/models/bark/test_modeling_bark.py
+++ b/tests/models/bark/test_modeling_bark.py
@@ -35,6 +35,7 @@
)
from transformers.testing_utils import (
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_fp16,
require_torch_gpu,
@@ -981,6 +982,63 @@ def test_flash_attn_2_inference_equivalence(self):
model.train()
_ = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(reason="Model does not support flash_attention_3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ dummy_input = inputs_dict["input_ids"][:1]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ dummy_attention_mask = dummy_attention_mask[:1]
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+
+ outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+
+ logits = outputs.hidden_states[-1]
+ logits_fa = outputs_fa.hidden_states[-1]
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ other_inputs = {"output_hidden_states": True}
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+ outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+
+ logits = outputs.hidden_states[-1]
+ logits_fa = outputs_fa.hidden_states[-1]
+
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -1039,6 +1097,64 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(reason="Model does not support flash_attention_3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.bfloat16,
+ )
+ model.to(torch_device)
+
+ dummy_input = inputs_dict["input_ids"][:1]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ dummy_attention_mask = dummy_attention_mask[:1]
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True)
+
+ logits = outputs.hidden_states[-1]
+ logits_fa = outputs_fa.hidden_states[-1]
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+ outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs)
+
+ logits = outputs.hidden_states[-1]
+ logits_fa = outputs_fa.hidden_states[-1]
+
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
@require_torch
class BarkModelIntegrationTests(unittest.TestCase):
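
These overrides are tagged with a new `flash_attn_3_test` pytest marker so they can be selected or deselected independently of the FlashAttention-2 tests. For that marker to be usable without "unknown mark" warnings it has to be registered somewhere; the sketch below shows the kind of `conftest.py` hook this would rely on, as an assumption rather than the repository's actual configuration (which could equally live in `setup.cfg` or `pyproject.toml`):

```py
# conftest.py -- illustrative sketch, not necessarily the project's actual marker registration.
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "flash_attn_3_test: mark a test as exercising the FlashAttention-3 code path",
    )
```

Tests carrying the marker can then be selected with pytest's standard `-m flash_attn_3_test` filter.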
diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py
index 00e3ad40a57652..f3906191e2c87a 100644
--- a/tests/models/chameleon/test_modeling_chameleon.py
+++ b/tests/models/chameleon/test_modeling_chameleon.py
@@ -24,6 +24,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_gpu,
@@ -366,6 +367,43 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertListEqual(output_native, output_fa_2)
+ @require_flash_attn_3
+ @require_read_token
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ """
+        Overwriting the common test as the test is flaky on tiny models
+ """
+ model = ChameleonForConditionalGeneration.from_pretrained(
+ "facebook/chameleon-7b",
+ load_in_4bit=True,
+ device_map={"": 0},
+ )
+
+ processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
+ texts = ["hi", "Hello this is a very long sentence"]
+
+ processor.tokenizer.padding_side = "right"
+
+ inputs = processor(texts, return_tensors="pt", padding=True).to(0)
+
+ output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_native = processor.tokenizer.batch_decode(output_native)
+
+ model = ChameleonForConditionalGeneration.from_pretrained(
+ "facebook/chameleon-7b",
+ load_in_4bit=True,
+ attn_implementation="flash_attention_3",
+ )
+
+ output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_fa_3 = processor.tokenizer.batch_decode(output_fa_3)
+
+ self.assertListEqual(output_native, output_fa_3)
+
@unittest.skip("Chameleon forces some token ids to be -inf!")
def test_batching_equivalence(self):
pass
diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py
index 88824756a6fb54..235fc38e3fa534 100644
--- a/tests/models/clip/test_modeling_clip.py
+++ b/tests/models/clip/test_modeling_clip.py
@@ -31,6 +31,7 @@
is_flax_available,
is_pt_flax_cross_test,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_sdpa,
@@ -1022,6 +1023,45 @@ def test_flash_attn_2_inference_equivalence(self):
f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
+ dummy_input_ids = inputs_dict["input_ids"]
+
+ outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(
+ pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
+ )
+
+ self.assertTrue(
+ torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
+ f"Image logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
+ )
+ self.assertTrue(
+ torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
+ f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
+ )
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -1073,6 +1113,57 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
f"Text logits max diff: {torch.max(torch.abs(logits_per_text_eager - logits_per_text_sdpa))}",
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager"
+ )
+ model.to(torch_device)
+
+ dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
+ dummy_input_ids = inputs_dict["input_ids"]
+ dummy_pixel_mask = inputs_dict["attention_mask"]
+
+ # right padding
+ dummy_pixel_mask[:] = 1
+ dummy_pixel_mask[:, -1:] = 0
+
+ outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(
+ pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
+ )
+
+ logits_per_image_eager = outputs.logits_per_image[:, :-1]
+ logits_per_text_eager = outputs.logits_per_text[:, :-1]
+
+ logits_per_image_sdpa = outputs_fa.logits_per_image[:, :-1]
+ logits_per_text_sdpa = outputs_fa.logits_per_text[:, :-1]
+
+ self.assertTrue(
+ torch.allclose(logits_per_image_eager, logits_per_image_sdpa, atol=4e-2, rtol=4e-2),
+ f"Image logits max diff: {torch.max(torch.abs(logits_per_image_eager - logits_per_image_sdpa))}",
+ )
+ self.assertTrue(
+ torch.allclose(logits_per_text_eager, logits_per_text_sdpa, atol=4e-2, rtol=4e-2),
+ f"Text logits max diff: {torch.max(torch.abs(logits_per_text_eager - logits_per_text_sdpa))}",
+ )
+
class CLIPForImageClassificationModelTester(CLIPModelTester):
def __init__(self, parent):
diff --git a/tests/models/distilbert/test_modeling_distilbert.py b/tests/models/distilbert/test_modeling_distilbert.py
index 3a74a1557cf9ba..00c5eb2a0e1ec0 100644
--- a/tests/models/distilbert/test_modeling_distilbert.py
+++ b/tests/models/distilbert/test_modeling_distilbert.py
@@ -19,7 +19,14 @@
import pytest
from transformers import DistilBertConfig, is_torch_available
-from transformers.testing_utils import require_flash_attn, require_torch, require_torch_accelerator, slow, torch_device
+from transformers.testing_utils import (
+ require_flash_attn,
+ require_flash_attn_3,
+ require_torch,
+ require_torch_accelerator,
+ slow,
+ torch_device,
+)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
@@ -348,6 +355,58 @@ def test_flash_attn_2_inference_equivalence(self):
self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
+ # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
+ @require_flash_attn_3
+ @require_torch_accelerator
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence(self):
+ import torch
+
+ for model_class in self.all_model_classes:
+ dummy_input = torch.LongTensor(
+ [
+ [1, 2, 3, 4],
+ [1, 2, 8, 9],
+ [1, 2, 11, 12],
+ [1, 2, 13, 14],
+ ]
+ ).to(torch_device)
+ dummy_attention_mask = torch.LongTensor(
+ [
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ ]
+ ).to(torch_device)
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
+ logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
+
+ self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
+
+ output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
+ logits_fa = output_fa.hidden_states[-1]
+
+ output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
+ logits = output.hidden_states[-1]
+
+ self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2))
+
# Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
@require_flash_attn
@require_torch_accelerator
@@ -403,6 +462,61 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2))
+ # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test.
+ @require_flash_attn_3
+ @require_torch_accelerator
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ import torch
+
+ for model_class in self.all_model_classes:
+ dummy_input = torch.LongTensor(
+ [
+ [1, 2, 3, 4],
+ [1, 2, 8, 9],
+ [1, 2, 11, 12],
+ [1, 2, 13, 14],
+ ]
+ ).to(torch_device)
+ dummy_attention_mask = torch.LongTensor(
+ [
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ [0, 1, 1, 1],
+ ]
+ ).to(torch_device)
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.bfloat16,
+ )
+ model.to(torch_device)
+
+ logits = model(dummy_input, output_hidden_states=True).hidden_states[-1]
+ logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1]
+
+ self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2))
+
+ output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
+ logits_fa = output_fa.hidden_states[-1]
+
+ output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True)
+ logits = output.hidden_states[-1]
+
+ self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2))
+
@require_torch
class DistilBertModelIntergrationTest(unittest.TestCase):
diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py
index a02541d585447c..fb4a11ced8d53c 100644
--- a/tests/models/gemma/test_modeling_gemma.py
+++ b/tests/models/gemma/test_modeling_gemma.py
@@ -25,6 +25,7 @@
is_flaky,
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_accelerator,
@@ -453,6 +454,51 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Gemma apparently does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -460,6 +506,13 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="Gemma flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest(reason="Gemma flash attention 3 does not support right padding")
+
@require_torch_sdpa
@require_torch_accelerator
@slow
@@ -526,6 +579,40 @@ def test_flash_attn_2_equivalence(self):
# gemma flash attention 2 needs a high tolerance
assert torch.allclose(logits_fa, logits, atol=3e-3)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @is_flaky()
+ @slow
+ def test_flash_attn_3_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(reason="Model does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
+ model.to(torch_device)
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ dummy_input = dummy_input.to(torch_device)
+ outputs = model(dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = outputs.hidden_states[-1]
+ logits_fa = outputs_fa.hidden_states[-1]
+
+ # gemma flash attention 3 needs a high tolerance
+ assert torch.allclose(logits_fa, logits, atol=3e-3)
+
@slow
@require_torch_accelerator
@@ -652,6 +739,29 @@ def test_model_2b_flash_attn(self):
self.assertEqual(output_text, EXPECTED_TEXTS)
+ @require_flash_attn_3
+ @require_read_token
+ @pytest.mark.flash_attn_3_test
+ def test_model_2b_flash_attn_fa3(self):
+ model_id = "google/gemma-2b"
+ EXPECTED_TEXTS = [
+ "Hello I am doing a project on the 1990s and I need to know what the most popular music",
+ "Hi today I am going to share with you a very easy and simple recipe of Kaju Kat",
+ ]
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model.to(torch_device)
+
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+
+ output = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+
+ self.assertEqual(output_text, EXPECTED_TEXTS)
+
@require_bitsandbytes
@require_read_token
def test_model_2b_4bit(self):
diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index 4e7b3553460f89..75a09f1a83e73d 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -22,6 +22,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline
from transformers.testing_utils import (
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_gpu,
@@ -306,3 +307,27 @@ def test_model_9b_flash_attn(self):
output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
self.assertEqual(output_text, EXPECTED_TEXTS)
+
+ @require_read_token
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_model_9b_flash_attn_3(self):
+ # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context
+ model_id = "google/gemma-2-9b"
+ EXPECTED_TEXTS = [
+ 'Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few',
+ "Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the"
+ ] # fmt: skip
+
+ model = AutoModelForCausalLM.from_pretrained(
+ model_id, attn_implementation="flash_attention_3", torch_dtype="float16"
+ ).to(torch_device)
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device)
+
+ output = model.generate(**inputs, max_new_tokens=100, do_sample=False)
+ output_text = tokenizer.batch_decode(output, skip_special_tokens=False)
+
+ self.assertEqual(output_text, EXPECTED_TEXTS)
diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py
index 3f96c20ab2dbd9..30d3c4d1b9bab3 100644
--- a/tests/models/gpt2/test_modeling_gpt2.py
+++ b/tests/models/gpt2/test_modeling_gpt2.py
@@ -25,6 +25,7 @@
from transformers.testing_utils import (
backend_empty_cache,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
slow,
@@ -939,3 +940,40 @@ def test_flash_attn_2_generate_padding_left(self):
self.assertListEqual(output_native, output_fa_2)
self.assertListEqual(output_native, expected_output)
+
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_left(self):
+ """
+        Overwriting the common test as the test is flaky on tiny models
+ """
+ model = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch.float16).to(0)
+
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+ texts = ["hi", "Hello this is a very long sentence"]
+
+ tokenizer.padding_side = "left"
+ tokenizer.pad_token = tokenizer.eos_token
+
+ inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
+
+ output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_native = tokenizer.batch_decode(output_native)
+
+ model = GPT2LMHeadModel.from_pretrained(
+ "gpt2", device_map={"": 0}, attn_implementation="flash_attention_3", torch_dtype=torch.float16
+ )
+
+ output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_fa_3 = tokenizer.batch_decode(output_fa_3)
+
+ expected_output = [
+ "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>hi, who was born in the city of Kolkata, was a member of the Kolkata",
+ "Hello this is a very long sentence. I'm sorry. I'm sorry. I'm sorry. I'm sorry. I'm sorry",
+ ]
+
+ self.assertListEqual(output_native, output_fa_3)
+ self.assertListEqual(output_native, expected_output)
diff --git a/tests/models/gptj/test_modeling_gptj.py b/tests/models/gptj/test_modeling_gptj.py
index 71c121dbaa5a9e..19757d08dcf1bf 100644
--- a/tests/models/gptj/test_modeling_gptj.py
+++ b/tests/models/gptj/test_modeling_gptj.py
@@ -23,6 +23,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
slow,
@@ -536,6 +537,44 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertListEqual(expected_outputs, output_fa_2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ """
+        Overwriting the common test as the test is flaky on tiny models
+ """
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b")
+
+ texts = ["hi", "Hello this is a very long sentence"]
+ expected_outputs = [
+ "hi<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>Q: I have a question about the new version of the game. I have a question about the",
+ "Hello this is a very long sentence.\n\nA:\n\nI think the best way to understand this is to think of it",
+ ]
+
+ tokenizer.padding_side = "right"
+ tokenizer.pad_token = tokenizer.eos_token
+
+ inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
+
+ quantization_config = BitsAndBytesConfig(load_in_4bit=True)
+
+ model = GPTJForCausalLM.from_pretrained(
+ "EleutherAI/gpt-j-6b",
+ device_map={"": 0},
+ attn_implementation="flash_attention_3",
+ revision="float16",
+ torch_dtype=torch.float16,
+ quantization_config=quantization_config,
+ )
+
+ output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_fa_3 = tokenizer.batch_decode(output_fa_3)
+
+ self.assertListEqual(expected_outputs, output_fa_3)
+
@require_torch
class GPTJModelLanguageGenerationTest(unittest.TestCase):
diff --git a/tests/models/granite/test_modeling_granite.py b/tests/models/granite/test_modeling_granite.py
index 0f4d7640a1bb7d..198d53e52b5f00 100644
--- a/tests/models/granite/test_modeling_granite.py
+++ b/tests/models/granite/test_modeling_granite.py
@@ -24,6 +24,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_gpu,
@@ -463,6 +464,46 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertListEqual(output_native, output_fa_2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_3_test
+ @require_read_token
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ """
+        Overwriting the common test as the test is flaky on tiny models
+ """
+ model = GraniteForCausalLM.from_pretrained(
+ "ibm/PowerLM-3b",
+ load_in_4bit=True,
+ device_map={"": 0},
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained("ibm/PowerLM-3b")
+
+ texts = ["hi", "Hello this is a very long sentence"]
+
+ tokenizer.padding_side = "right"
+ tokenizer.pad_token = tokenizer.eos_token
+
+ inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
+
+ output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_native = tokenizer.batch_decode(output_native)
+
+ model = GraniteForCausalLM.from_pretrained(
+ "ibm/PowerLM-3b",
+ load_in_4bit=True,
+ device_map={"": 0},
+ attn_implementation="flash_attention_3",
+ )
+
+ output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_fa_3 = tokenizer.batch_decode(output_fa_3)
+
+ self.assertListEqual(output_native, output_fa_3)
+
@require_flash_attn
@require_torch_gpu
@slow
diff --git a/tests/models/granitemoe/test_modeling_granitemoe.py b/tests/models/granitemoe/test_modeling_granitemoe.py
index 158259ed5fb4c0..23cf45e93e8e60 100644
--- a/tests/models/granitemoe/test_modeling_granitemoe.py
+++ b/tests/models/granitemoe/test_modeling_granitemoe.py
@@ -24,6 +24,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_gpu,
@@ -462,6 +463,46 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertListEqual(output_native, output_fa_2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_3_test
+ @require_read_token
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ """
+        Overwriting the common test as the test is flaky on tiny models
+ """
+ model = GraniteMoeForCausalLM.from_pretrained(
+ "ibm-granite/granitemoe-3b",
+ load_in_4bit=True,
+ device_map={"": 0},
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained("ibm-granite/granitemoe-3b")
+
+ texts = ["hi", "Hello this is a very long sentence"]
+
+ tokenizer.padding_side = "right"
+ tokenizer.pad_token = tokenizer.eos_token
+
+ inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
+
+ output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_native = tokenizer.batch_decode(output_native)
+
+ model = GraniteMoeForCausalLM.from_pretrained(
+ "ibm-granite/granitemoe-3b",
+ load_in_4bit=True,
+ device_map={"": 0},
+ attn_implementation="flash_attention_3",
+ )
+
+ output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_fa_3 = tokenizer.batch_decode(output_fa_3)
+
+ self.assertListEqual(output_native, output_fa_3)
+
@require_flash_attn
@require_torch_gpu
@slow
diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py
index e02c5b4c9f09c6..51f8604aad17b2 100644
--- a/tests/models/idefics2/test_modeling_idefics2.py
+++ b/tests/models/idefics2/test_modeling_idefics2.py
@@ -32,6 +32,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_multi_gpu,
@@ -608,3 +609,37 @@ def test_flash_attn_2_eager_equivalence(self):
)
self.assertEqual(generated_texts_eager[0], generated_texts_flash_attention_2[0])
+
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ def test_flash_attn_3_eager_equivalence(self):
+ # Create inputs
+ text = "In this image, we see"
+ images = self.image1
+ inputs = self.processor(text=text, images=images, return_tensors="pt", padding=True)
+ inputs.to(torch_device)
+
+ # Eager model
+ model_eager = Idefics2ForConditionalGeneration.from_pretrained(
+ "HuggingFaceM4/idefics2-8b-base",
+ attn_implementation="eager",
+ load_in_4bit=True,
+ )
+ generated_ids_eager = model_eager.generate(**inputs, max_new_tokens=10)
+ generated_texts_eager = self.processor.batch_decode(generated_ids_eager, skip_special_tokens=True)
+
+ del model_eager
+
+ # Flash Attention 3 model
+ model_flash_attention_3 = Idefics2ForConditionalGeneration.from_pretrained(
+ "HuggingFaceM4/idefics2-8b-base",
+ attn_implementation="flash_attention_3",
+ load_in_4bit=True,
+ )
+ generated_ids_flash_attention_3 = model_flash_attention_3.generate(**inputs, max_new_tokens=10)
+ generated_texts_flash_attention_3 = self.processor.batch_decode(
+ generated_ids_flash_attention_3, skip_special_tokens=True
+ )
+
+ self.assertEqual(generated_texts_eager[0], generated_texts_flash_attention_3[0])
diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py
index 6e1a2cf2cf9c44..28fa0e8749c827 100644
--- a/tests/models/jamba/test_modeling_jamba.py
+++ b/tests/models/jamba/test_modeling_jamba.py
@@ -25,6 +25,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
slow,
@@ -539,6 +540,45 @@ def test_flash_attn_2_fp32_ln(self):
# with attention mask
_ = model(dummy_input, attention_mask=dummy_attention_mask)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_fp32_ln(self):
+ r"""
+ Overriding the test_flash_attn_3_fp32_ln test as the Jamba model, like Mixtral, doesn't support
+ right padding + use cache with FA3
+ """
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_input = inputs_dict[model.main_input_name]
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Jamba does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ load_in_4bit=True,
+ )
+
+ for _, param in model.named_parameters():
+ # upcast only layer norms
+ if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+ param.data = param.data.to(torch.float32)
+
+ _ = model(dummy_input)
+ # with attention mask
+ _ = model(dummy_input, attention_mask=dummy_attention_mask)
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -577,6 +617,44 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ r"""
+ Overriding the test_flash_attn_3_generate_padding_right test as the Jamba model, like Mixtral, doesn't support
+ right padding + use cache with FA3
+ """
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -626,6 +704,55 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ r"""
+ Overriding the test_flash_attn_3_generate_use_cache test as the Jamba model, like Mixtral, doesn't support
+ right padding + use cache with FA3
+ """
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Jamba does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -637,6 +764,17 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
"""
self.skipTest(reason="Jamba flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ r"""
+        Overriding the test_flash_attn_3_inference_equivalence_right_padding test as the Jamba model, like Mixtral, doesn't support
+ right padding + use cache with FA3
+ """
+ self.skipTest(reason="Jamba flash attention does not support right padding")
+
@unittest.skip(reason="Jamba has its own special cache type")
@parameterized.expand([(1, False), (1, True), (4, False)])
def test_new_cache_format(self, num_beams, do_sample):
diff --git a/tests/models/jetmoe/test_modeling_jetmoe.py b/tests/models/jetmoe/test_modeling_jetmoe.py
index 50fd7a27e1e6d1..238d2d7219312f 100644
--- a/tests/models/jetmoe/test_modeling_jetmoe.py
+++ b/tests/models/jetmoe/test_modeling_jetmoe.py
@@ -26,6 +26,7 @@
backend_empty_cache,
is_flaky,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_sdpa,
@@ -420,6 +421,40 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -465,6 +500,51 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: JetMoe apparently does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -472,6 +552,13 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="JetMoe flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest(reason="JetMoe flash attention does not support right padding")
+
@require_torch
class JetMoeIntegrationTest(unittest.TestCase):
diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py
index a21665c822f2f9..572d1c6c62a480 100644
--- a/tests/models/llama/test_modeling_llama.py
+++ b/tests/models/llama/test_modeling_llama.py
@@ -27,6 +27,7 @@
backend_empty_cache,
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_accelerator,
@@ -619,6 +620,43 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertListEqual(output_native, output_fa_2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_3_test
+ @require_read_token
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ """
+        Overwriting the common test as the test is flaky on tiny models
+ """
+ model = LlamaForCausalLM.from_pretrained(
+ "meta-llama/Llama-2-7b-hf",
+ load_in_4bit=True,
+ device_map={"": 0},
+ )
+
+ tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+
+ texts = ["hi", "Hello this is a very long sentence"]
+
+ tokenizer.padding_side = "right"
+ tokenizer.pad_token = tokenizer.eos_token
+
+ inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
+
+ output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_native = tokenizer.batch_decode(output_native)
+
+ model = LlamaForCausalLM.from_pretrained(
+ "meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_3"
+ )
+
+ output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_fa_3 = tokenizer.batch_decode(output_fa_3)
+
+ self.assertListEqual(output_native, output_fa_3)
+
@require_flash_attn
@require_torch_gpu
@slow
diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py
index a29a9c8a9ec0dc..35701f0c0eb9d8 100644
--- a/tests/models/m2m_100/test_modeling_m2m_100.py
+++ b/tests/models/m2m_100/test_modeling_m2m_100.py
@@ -23,6 +23,7 @@
from transformers import M2M100Config, is_torch_available
from transformers.testing_utils import (
require_flash_attn,
+ require_flash_attn_3,
require_sentencepiece,
require_tokenizers,
require_torch,
@@ -465,3 +466,48 @@ def test_flash_attn_2_seq_to_seq_generation(self):
hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
)
assert generated == expected_en
+
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_seq_to_seq_generation(self):
+ """
+        Overwriting the common test as the test is flaky on tiny models
+ """
+ model = M2M100ForConditionalGeneration.from_pretrained(
+ "facebook/m2m100_418M", attn_implementation="flash_attention_3"
+ ).to(torch_device)
+
+ tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en")
+
+ src_fr = [
+ "L'affaire NSA souligne l'absence totale de débat sur le renseignement",
+ "Selon moi, il y a deux niveaux de réponse de la part du gouvernement français.",
+ "Lorsque François Hollande téléphone à Barack Obama ou quand le ministre des affaires étrangères Laurent"
+ " Fabius convoque l'ambassadeur des Etats-Unis, ils réagissent à une vraie découverte, qui est celle de"
+ " l'ampleur de la surveillance américaine sur l'ensemble des communications en France.",
+ ]
+
+ # The below article tests that we don't add any hypotheses outside of the top n_beams
+ dct = tokenizer(src_fr, padding=True, return_tensors="pt")
+
+ hypotheses_batch = model.generate(
+ input_ids=dct["input_ids"].to(torch_device),
+ attention_mask=dct["attention_mask"].to(torch_device),
+ num_beams=5,
+ forced_bos_token_id=tokenizer.get_lang_id("en"),
+ )
+
+ expected_en = [
+ "The NSA case highlights the total absence of intelligence debate",
+ "I think there are two levels of response from the French government.",
+ "When François Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S."
+ " Ambassador, they respond to a real discovery, which is that of the scale of U.S. surveillance on all"
+ " communications in France.",
+ ]
+
+ generated = tokenizer.batch_decode(
+ hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True
+ )
+ assert generated == expected_en
diff --git a/tests/models/mimi/test_modeling_mimi.py b/tests/models/mimi/test_modeling_mimi.py
index dd0f77421be728..137d64279b4695 100644
--- a/tests/models/mimi/test_modeling_mimi.py
+++ b/tests/models/mimi/test_modeling_mimi.py
@@ -30,6 +30,7 @@
is_flaky,
is_torch_available,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_sdpa,
@@ -738,10 +739,46 @@ def test_flash_attn_2_inference_equivalence(self):
assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ @is_flaky()
+ def test_flash_attn_3_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ dummy_input = inputs_dict[model.main_input_name][:1]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ outputs = model(dummy_input)
+ outputs_fa = model_fa(dummy_input)
+
+ logits = outputs[1]
+ logits_fa = outputs_fa[1]
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
@unittest.skip(reason="The MimiModel does not support right padding")
def test_flash_attn_2_inference_equivalence_right_padding(self):
pass
+ @unittest.skip(reason="The MimiModel does not support right padding")
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ pass
+
@unittest.skip(reason="The MimiModel does not have support dynamic compile yet")
def test_sdpa_can_compile_dynamic(self):
pass
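
The new tests gate on `require_flash_attn_3` and `require_torch_gpu`, but FlashAttention-3 additionally assumes recent NVIDIA hardware at runtime. A hedged sketch of a capability guard one might pair with the package check when reproducing these tests locally; the `can_run_fa3` name and the compute-capability threshold are illustrative assumptions, not something this diff defines:

```py
import torch

from transformers.utils import is_flash_attn_3_available


def can_run_fa3() -> bool:
    # Assumption: FlashAttention-3 kernels target Hopper-class GPUs (compute capability >= 9.0).
    if not (torch.cuda.is_available() and is_flash_attn_3_available()):
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 9
```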
diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py
index 0730f8ba444140..a64eac8eee37f1 100644
--- a/tests/models/mistral/test_modeling_mistral.py
+++ b/tests/models/mistral/test_modeling_mistral.py
@@ -27,6 +27,7 @@
is_flaky,
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_accelerator,
@@ -440,6 +441,40 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -485,6 +520,51 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Mistral apparently does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -492,6 +572,13 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="Mistral flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest(reason="Mistral flash attention does not support right padding")
+
@require_torch_gpu
class MistralIntegrationTest(unittest.TestCase):
@@ -602,6 +689,31 @@ def test_model_7b_long_prompt(self):
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+ @require_flash_attn_3
+ @require_bitsandbytes
+ @slow
+ @pytest.mark.flash_attn_3_test
+ def test_model_7b_long_prompt_fa3(self):
+ EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
+ # An input with 4097 tokens that is above the size of the sliding window
+ input_ids = [1] + [306, 338] * 2048
+ model = MistralForCausalLM.from_pretrained(
+ "mistralai/Mistral-7B-v0.1",
+ device_map={"": torch_device},
+ load_in_4bit=True,
+ attn_implementation="flash_attention_3",
+ )
+ input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+ generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+ self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
+ # Assisted generation
+ assistant_model = model
+ assistant_model.generation_config.num_assistant_tokens = 2
+ assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
+ generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+ self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
@slow
@require_torch_sdpa
def test_model_7b_long_prompt_sdpa(self):
diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py
index db9641e3dcb2a9..6e074dedbe134e 100644
--- a/tests/models/mixtral/test_modeling_mixtral.py
+++ b/tests/models/mixtral/test_modeling_mixtral.py
@@ -23,6 +23,7 @@
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_sdpa,
@@ -440,6 +441,40 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -485,6 +520,51 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Mixtral apparently does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -492,6 +572,13 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="Mixtral flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest(reason="Mixtral flash attention does not support right padding")
+
# Ignore copy
def test_load_balancing_loss(self):
r"""
diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py
index a385a18b91c5d5..50ed7321709bc6 100644
--- a/tests/models/musicgen/test_modeling_musicgen.py
+++ b/tests/models/musicgen/test_modeling_musicgen.py
@@ -35,6 +35,7 @@
from transformers.testing_utils import (
is_torch_available,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_accelerator,
require_torch_fp16,
@@ -404,6 +405,86 @@ def test_flash_attn_2_inference_equivalence(self):
model.train()
_ = model_fa(dummy_input, **other_inputs)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence
+ def test_flash_attn_3_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+
+ # Ignore copy
+ outputs = model(dummy_input, output_hidden_states=True)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ # Ignore copy
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(dummy_input, **other_inputs)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -483,6 +564,85 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence_right_padding
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ if model.config.is_encoder_decoder:
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
+
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ else:
+ outputs = model(dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+ # Ignore copy
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -529,6 +689,52 @@ def test_flash_attn_2_generate_left_padding(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_left_padding
+ def test_flash_attn_3_generate_left_padding(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # make sure we do left padding
+ dummy_attention_mask[:, :-1] = 0
+ dummy_attention_mask[:, -1:] = 1
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -575,6 +781,52 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_padding_right
+ def test_flash_attn_3_generate_padding_right(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # make sure we do right padding
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -621,6 +873,52 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_use_cache
+ def test_flash_attn_3_generate_use_cache(self):
+ max_new_tokens = 30
+
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@@ -1599,36 +1897,121 @@ def test_greedy_generate_stereo_outputs(self):
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.audio_channels = 2
- model = model_class(config).to(torch_device).eval()
- output_generate = self._greedy_generate(
- model=model,
- input_ids=input_ids.to(torch_device),
- attention_mask=attention_mask.to(torch_device),
- output_scores=True,
- output_hidden_states=True,
- output_attentions=True,
- return_dict_in_generate=True,
- )
+ model = model_class(config).to(torch_device).eval()
+ output_generate = self._greedy_generate(
+ model=model,
+ input_ids=input_ids.to(torch_device),
+ attention_mask=attention_mask.to(torch_device),
+ output_scores=True,
+ output_hidden_states=True,
+ output_attentions=True,
+ return_dict_in_generate=True,
+ )
+
+ self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
+
+ self.assertNotIn(config.pad_token_id, output_generate)
+
+ @unittest.skip(
+        reason="MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composite model"
+ )
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ @require_flash_attn
+ @require_torch_gpu
+ @mark.flash_attn_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
+ def test_flash_attn_2_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_2:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+
+ # Ignore copy
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
+ # Ignore copy
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+ # Ignore copy
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ # Ignore copy
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+ # Ignore copy
+ outputs = model(dummy_input, **other_inputs)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, **other_inputs)
- self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput)
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
- self.assertNotIn(config.pad_token_id, output_generate)
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
- @unittest.skip(
- reason="MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model"
- )
- def test_save_load_fast_init_from_base(self):
- pass
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(dummy_input, **other_inputs)
- @require_flash_attn
+ @require_flash_attn_3
@require_torch_gpu
- @mark.flash_attn_test
+ @mark.flash_attn_3_test
@slow
- # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
- def test_flash_attn_2_inference_equivalence(self):
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence
+ def test_flash_attn_3_inference_equivalence(self):
for model_class in self.all_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -1636,7 +2019,7 @@ def test_flash_attn_2_inference_equivalence(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
)
model_fa.to(torch_device)
@@ -1787,6 +2170,88 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence_right_padding
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ # Ignore copy
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
+ # Ignore copy
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ # Ignore copy
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ # Ignore copy
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+ # Ignore copy
+ outputs = model(dummy_input, **other_inputs)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -1836,6 +2301,55 @@ def test_flash_attn_2_generate_left_padding(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_left_padding
+ def test_flash_attn_3_generate_left_padding(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask")
+ if dummy_attention_mask is None:
+ dummy_attention_mask = torch.ones_like(dummy_input)
+
+ # make sure we do left padding
+ dummy_attention_mask[:, :-1] = 0
+ dummy_attention_mask[:, -1:] = 1
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -1884,6 +2398,54 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_padding_right
+ def test_flash_attn_3_generate_padding_right(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask")
+ if dummy_attention_mask is None:
+ dummy_attention_mask = torch.ones_like(dummy_input)
+ # make sure we do right padding
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -1930,6 +2492,52 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_use_cache
+ def test_flash_attn_3_generate_use_cache(self):
+ max_new_tokens = 30
+
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
index e8584e238d3cd9..1a3956e3d2eabc 100644
--- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
+++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
@@ -35,6 +35,7 @@
is_torch_available,
is_torchaudio_available,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_accelerator,
require_torch_fp16,
@@ -406,6 +407,86 @@ def test_flash_attn_2_inference_equivalence(self):
model.train()
_ = model_fa(dummy_input, **other_inputs)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_3_inference_equivalence
+ def test_flash_attn_3_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+
+ # Ignore copy
+ outputs = model(dummy_input, output_hidden_states=True)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ # Ignore copy
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(dummy_input, **other_inputs)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -485,6 +566,85 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_3_inference_equivalence_right_padding
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ if model.config.is_encoder_decoder:
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
+
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ else:
+ outputs = model(dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+ # Ignore copy
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -531,6 +691,52 @@ def test_flash_attn_2_generate_left_padding(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_left_padding
+ def test_flash_attn_3_generate_left_padding(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # make sure we do left padding
+ dummy_attention_mask[:, :-1] = 0
+ dummy_attention_mask[:, -1:] = 1
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -577,6 +783,52 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_padding_right
+ def test_flash_attn_3_generate_padding_right(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+            if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # make sure we do right padding
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -623,6 +875,52 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_3_generate_use_cache
+ def test_flash_attn_3_generate_use_cache(self):
+ max_new_tokens = 30
+
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+            if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
@@ -1583,36 +1881,121 @@ def test_greedy_generate_stereo_outputs(self):
config, input_ids, attention_mask = self._get_input_ids_and_config()
config.audio_channels = 2
- model = model_class(config).to(torch_device).eval()
- output_generate = self._greedy_generate(
- model=model,
- input_ids=input_ids.to(torch_device),
- attention_mask=attention_mask.to(torch_device),
- output_scores=True,
- output_hidden_states=True,
- output_attentions=True,
- return_dict_in_generate=True,
- )
+ model = model_class(config).to(torch_device).eval()
+ output_generate = self._greedy_generate(
+ model=model,
+ input_ids=input_ids.to(torch_device),
+ attention_mask=attention_mask.to(torch_device),
+ output_scores=True,
+ output_hidden_states=True,
+ output_attentions=True,
+ return_dict_in_generate=True,
+ )
+
+ self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
+
+ self.assertNotIn(config.pad_token_id, output_generate)
+
+ @unittest.skip(
+        reason="MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composite model"
+ )
+ def test_save_load_fast_init_from_base(self):
+ pass
+
+ @require_flash_attn
+ @require_torch_gpu
+ @mark.flash_attn_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
+ def test_flash_attn_2_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_2:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+
+ # Ignore copy
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
+ # Ignore copy
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+ # Ignore copy
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ # Ignore copy
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+ # Ignore copy
+ outputs = model(dummy_input, **other_inputs)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, **other_inputs)
- self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput)
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
- self.assertNotIn(config.pad_token_id, output_generate)
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
- @unittest.skip(
- reason="MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composit model"
- )
- def test_save_load_fast_init_from_base(self):
- pass
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(dummy_input, **other_inputs)
- @require_flash_attn
+ @require_flash_attn_3
@require_torch_gpu
- @mark.flash_attn_test
+ @mark.flash_attn_3_test
@slow
- # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence
- def test_flash_attn_2_inference_equivalence(self):
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence
+ def test_flash_attn_3_inference_equivalence(self):
for model_class in self.all_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -1620,7 +2003,7 @@ def test_flash_attn_2_inference_equivalence(self):
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
model_fa = model_class.from_pretrained(
- tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
)
model_fa.to(torch_device)
@@ -1771,6 +2154,88 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence_right_padding
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ # Ignore copy
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ # Ignore copy
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ # Ignore copy
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)
+ # Ignore copy
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ # Ignore copy
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ # Ignore copy
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+ # Ignore copy
+ outputs = model(dummy_input, **other_inputs)
+ # Ignore copy
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -1820,6 +2285,55 @@ def test_flash_attn_2_generate_left_padding(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_left_padding
+ def test_flash_attn_3_generate_left_padding(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask")
+ if dummy_attention_mask is None:
+ dummy_attention_mask = torch.ones_like(dummy_input)
+
+ # make sure we do left padding
+ dummy_attention_mask[:, :-1] = 0
+ dummy_attention_mask[:, -1:] = 1
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -1868,6 +2382,54 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_padding_right
+ def test_flash_attn_3_generate_padding_right(self):
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask")
+ if dummy_attention_mask is None:
+ dummy_attention_mask = torch.ones_like(dummy_input)
+ # make sure we do right padding
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -1914,6 +2476,52 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_use_cache
+ def test_flash_attn_3_generate_use_cache(self):
+ max_new_tokens = 30
+
+ # Ignore copy
+ for model_class in self.greedy_sample_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
diff --git a/tests/models/nemotron/test_modeling_nemotron.py b/tests/models/nemotron/test_modeling_nemotron.py
index 4f8f4cc77fe8d0..450939bd0d36a1 100644
--- a/tests/models/nemotron/test_modeling_nemotron.py
+++ b/tests/models/nemotron/test_modeling_nemotron.py
@@ -25,6 +25,7 @@
from transformers.testing_utils import (
is_flaky,
require_flash_attn,
+ require_flash_attn_3,
require_read_token,
require_torch,
require_torch_gpu,
@@ -178,6 +179,40 @@ def test_flash_attn_2_equivalence(self):
# nemotron flash attention 2 needs a high tolerance
assert torch.allclose(logits_fa, logits, atol=1e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @is_flaky()
+ @slow
+ def test_flash_attn_3_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(reason="Model does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager")
+ model.to(torch_device)
+
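+            # run the same input through the eager and FA3 models and compare their last hidden states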
+ dummy_input = inputs_dict[model_class.main_input_name]
+ dummy_input = dummy_input.to(torch_device)
+ outputs = model(dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = outputs.hidden_states[-1]
+ logits_fa = outputs_fa.hidden_states[-1]
+
+ # nemotron flash attention 3 needs a high tolerance
+ assert torch.allclose(logits_fa, logits, atol=1e-2)
+
@require_torch_gpu
class NemotronIntegrationTest(unittest.TestCase):
diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py
index 95b0b01c0a23d9..fca9c563361edd 100644
--- a/tests/models/phi/test_modeling_phi.py
+++ b/tests/models/phi/test_modeling_phi.py
@@ -24,6 +24,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
slow,
@@ -498,6 +499,43 @@ def test_flash_attn_2_generate_padding_right(self):
self.assertListEqual(output_native, output_fa_2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ @pytest.mark.flash_attn_3_test
+ @slow
+ # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_3_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->microsoft/phi-1
+ def test_flash_attn_3_generate_padding_right(self):
+ """
+        Overwriting the common test as the test is flaky on tiny models
+ """
+ model = PhiForCausalLM.from_pretrained(
+ "microsoft/phi-1",
+ load_in_4bit=True,
+ device_map={"": 0},
+ )
+
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
+
+ texts = ["hi", "Hello this is a very long sentence"]
+
+ tokenizer.padding_side = "right"
+ tokenizer.pad_token = tokenizer.eos_token
+
+ inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0)
+
+ output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_native = tokenizer.batch_decode(output_native)
+
+ model = PhiForCausalLM.from_pretrained(
+ "microsoft/phi-1", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_3"
+ )
+
+ output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False)
+ output_fa_3 = tokenizer.batch_decode(output_fa_3)
+
+ self.assertListEqual(output_native, output_fa_3)
+
@slow
@require_torch
diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py
index 4d6c432f20424d..59745222cec592 100644
--- a/tests/models/qwen2/test_modeling_qwen2.py
+++ b/tests/models/qwen2/test_modeling_qwen2.py
@@ -25,6 +25,7 @@
backend_empty_cache,
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_sdpa,
@@ -450,6 +451,40 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
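+                # small batch where the second sequence is right-padded (its last position is masked out)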
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
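+                # FA3 is expected to reject right-padded inputs at generation time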
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -495,6 +530,51 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Qwen2 apparently does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -502,6 +582,13 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="Qwen2 flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest(reason="Qwen2 flash attention does not support right padding")
+
@require_torch
class Qwen2IntegrationTest(unittest.TestCase):
@@ -571,6 +658,36 @@ def test_model_450m_long_prompt(self):
backend_empty_cache(torch_device)
gc.collect()
+ @require_bitsandbytes
+ @slow
+ @require_flash_attn_3
+ @pytest.mark.flash_attn_3_test
+ def test_model_450m_long_prompt_fav3(self):
+ EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
+ # An input with 4097 tokens that is above the size of the sliding window
+ input_ids = [1] + [306, 338] * 2048
+ model = Qwen2ForCausalLM.from_pretrained(
+ "Qwen/Qwen2-450m-beta",
+ device_map="auto",
+ load_in_4bit=True,
+ attn_implementation="flash_attention_3",
+ )
+ input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+ generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+ self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
+ # Assisted generation
+ assistant_model = model
+ assistant_model.generation_config.num_assistant_tokens = 2
+ assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
+ generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+ self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
+ del assistant_model
+ del model
+ backend_empty_cache(torch_device)
+ gc.collect()
+
@slow
@require_torch_sdpa
def test_model_450m_long_prompt_sdpa(self):
diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
index 0425172a6fba4d..15eb85b5bcc351 100644
--- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
+++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py
@@ -25,6 +25,7 @@
backend_empty_cache,
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_sdpa,
@@ -475,6 +476,40 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
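+                # right padding during batched generation is expected to raise, matching the FA2 behaviour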
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -520,6 +555,51 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Qwen2Moe apparently does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -527,6 +607,13 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="Qwen2Moe flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest(reason="Qwen2Moe flash attention does not support right padding")
+
# Ignore copy
def test_load_balancing_loss(self):
r"""
@@ -633,6 +720,36 @@ def test_model_a2_7b_long_prompt(self):
backend_empty_cache(torch_device)
gc.collect()
+ @require_bitsandbytes
+ @slow
+ @require_flash_attn_3
+ @pytest.mark.flash_attn_3_test
+ def test_model_a2_7b_long_prompt_fav3(self):
+ EXPECTED_OUTPUT_TOKEN_IDS = [306, 338]
+ # An input with 4097 tokens that is above the size of the sliding window
+ input_ids = [1] + [306, 338] * 2048
+ model = Qwen2MoeForCausalLM.from_pretrained(
+ "Qwen/Qwen1.5-MoE-A2.7B",
+ device_map="auto",
+ load_in_4bit=True,
+ attn_implementation="flash_attention_3",
+ )
+ input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+ generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+ self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
+ # Assisted generation
+ assistant_model = model
+ assistant_model.generation_config.num_assistant_tokens = 2
+ assistant_model.generation_config.num_assistant_tokens_schedule = "constant"
+        generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0, assistant_model=assistant_model)
+ self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist())
+
+ del assistant_model
+ del model
+ backend_empty_cache(torch_device)
+ gc.collect()
+
@slow
@require_torch_sdpa
def test_model_a2_7b_long_prompt_sdpa(self):
diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
index 956243dccebebf..312cd6df4a9d57 100644
--- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
+++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py
@@ -28,6 +28,7 @@
)
from transformers.testing_utils import (
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
slow,
@@ -478,6 +479,38 @@ def test_small_model_integration_test_batch_flashatt2(self):
self.processor.batch_decode(output, skip_special_tokens=True)[1],
)
+ @slow
+ @require_flash_attn_3
+ @require_torch_gpu
+ def test_small_model_integration_test_batch_flashatt3(self):
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
+ "Qwen/Qwen2-VL-7B-Instruct",
+ torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_3",
+ device_map="auto",
+ )
+ text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
+ inputs = self.processor(text=[text, text], images=[self.image, self.image], return_tensors="pt").to(
+ torch_device
+ )
+
+ # it should not matter whether two images are the same size or not
+ output = model.generate(**inputs, max_new_tokens=30)
+
+ EXPECTED_DECODED_TEXT = [
+ "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets",
+ "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets",
+ ]
+
+ self.assertEqual(
+ self.processor.batch_decode(output, skip_special_tokens=True),
+ EXPECTED_DECODED_TEXT,
+ )
+ self.assertEqual(
+ self.processor.batch_decode(output, skip_special_tokens=True)[0],
+ self.processor.batch_decode(output, skip_special_tokens=True)[1],
+ )
+
@slow
@require_flash_attn
@require_torch_gpu
@@ -510,3 +543,36 @@ def test_small_model_integration_test_batch_wo_image_flashatt2(self):
self.processor.batch_decode(output, skip_special_tokens=True),
EXPECTED_DECODED_TEXT,
)
+
+ @slow
+ @require_flash_attn_3
+ @require_torch_gpu
+ def test_small_model_integration_test_batch_wo_image_flashatt3(self):
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
+ "Qwen/Qwen2-VL-7B-Instruct",
+ torch_dtype=torch.bfloat16,
+ attn_implementation="flash_attention_3",
+ device_map="auto",
+ )
+ text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True)
+ messages2 = [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Who are you?"},
+ ]
+ text2 = self.processor.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
+ inputs = self.processor(text=[text, text2], images=[self.image], padding=True, return_tensors="pt").to(
+ torch_device
+ )
+
+        # the second prompt is text-only, so this checks batching with a mix of image and no-image inputs
+ output = model.generate(**inputs, max_new_tokens=30)
+
+ EXPECTED_DECODED_TEXT = [
+ "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets",
+ "system\nYou are a helpful assistant.\nuser\nWho are you?\nassistant\nI am Qwen, a large language model created by Alibaba Cloud. I am designed to answer a wide range of questions and provide information on various topics",
+ ]
+
+ self.assertEqual(
+ self.processor.batch_decode(output, skip_special_tokens=True),
+ EXPECTED_DECODED_TEXT,
+ )
diff --git a/tests/models/siglip/test_modeling_siglip.py b/tests/models/siglip/test_modeling_siglip.py
index 9d1e3109b313c3..a3c1cfcb66e40f 100644
--- a/tests/models/siglip/test_modeling_siglip.py
+++ b/tests/models/siglip/test_modeling_siglip.py
@@ -28,6 +28,7 @@
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from transformers.testing_utils import (
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
require_torch_sdpa,
@@ -834,12 +835,95 @@ def test_flash_attn_2_inference_equivalence(self):
output_hidden_states=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence(self):
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16)
+ dummy_input_ids = inputs_dict["input_ids"]
+
+ outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(
+ pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True
+ )
+
+ self.assertTrue(
+ torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
+ f"Image logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
+ )
+ self.assertTrue(
+ torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
+ f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
+ )
+
+ # Test with attention mask
+ dummy_attention_mask = inputs_dict["attention_mask"]
+
+ if dummy_attention_mask is not None:
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+
+ outputs = model(
+ pixel_values=dummy_pixel_values,
+ input_ids=dummy_input_ids,
+ attention_mask=dummy_attention_mask,
+ output_hidden_states=True,
+ )
+ outputs_fa = model_fa(
+ pixel_values=dummy_pixel_values,
+ input_ids=dummy_input_ids,
+ attention_mask=dummy_attention_mask,
+ output_hidden_states=True,
+ )
+
+ self.assertTrue(
+ torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2),
+ f"Logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}",
+ )
+ self.assertTrue(
+ torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2),
+ f"Logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}",
+ )
+
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(
+ pixel_values=dummy_pixel_values,
+ input_ids=dummy_input_ids,
+ attention_mask=dummy_attention_mask,
+ output_hidden_states=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest("SigLIP does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest("SigLIP does not support right padding")
+
@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
@require_torch_sdpa
@slow
diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py
index 36cad89bcfdf06..5437330bc96a59 100644
--- a/tests/models/stablelm/test_modeling_stablelm.py
+++ b/tests/models/stablelm/test_modeling_stablelm.py
@@ -24,6 +24,7 @@
is_flaky,
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_sdpa,
slow,
@@ -559,6 +560,24 @@ def test_model_3b_long_prompt(self):
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-3:].tolist())
+ @require_bitsandbytes
+ @slow
+ @require_flash_attn_3
+ @pytest.mark.flash_attn_3_test
+ def test_model_3b_long_prompt_fav3(self):
+ EXPECTED_OUTPUT_TOKEN_IDS = [3, 3, 3]
+ input_ids = [306, 338] * 2047
+ model = StableLmForCausalLM.from_pretrained(
+ "stabilityai/stablelm-3b-4e1t",
+ device_map="auto",
+ torch_dtype="auto",
+ load_in_4bit=True,
+ attn_implementation="flash_attention_3",
+ )
+ input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
+ generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0)
+ self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-3:].tolist())
+
# Copied from transformers.tests.models.llama.test_modeling_llama.LlamaModelTest.test_eager_matches_sdpa_generate with Llama->StableLm,saibo/llama-1B->stabilityai/stablelm-3b-4e1t
# TODO: @Fxmarty
@is_flaky(max_attempts=3, description="flaky on some models.")
diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py
index c1c7d45d4f18d7..04a5f657f833b0 100644
--- a/tests/models/starcoder2/test_modeling_starcoder2.py
+++ b/tests/models/starcoder2/test_modeling_starcoder2.py
@@ -23,6 +23,7 @@
from transformers.testing_utils import (
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_torch,
require_torch_gpu,
slow,
@@ -431,6 +432,40 @@ def test_flash_attn_2_generate_padding_right(self):
dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
+ import torch
+
+ for model_class in self.all_generative_model_classes:
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device)
+
+ model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
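+                # right padding during batched generation is expected to raise, matching the FA2 behaviour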
+ with self.assertRaises(ValueError):
+ _ = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -476,6 +511,51 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ import torch
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # NOTE: Starcoder2 apparently does not support right padding + use_cache with FA3.
+ dummy_attention_mask[:, -1] = 1
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -483,6 +563,13 @@ def test_flash_attn_2_generate_use_cache(self):
def test_flash_attn_2_inference_equivalence_right_padding(self):
self.skipTest(reason="Starcoder2 flash attention does not support right padding")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ self.skipTest(reason="Starcoder2 flash attention does not support right padding")
+
@slow
@require_torch_gpu
@@ -549,6 +636,28 @@ def test_starcoder2_batched_generation_fa2(self):
output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
self.assertEqual(EXPECTED_TEXT, output_text)
+ @require_flash_attn_3
+ @pytest.mark.flash_attn_3_test
+ def test_starcoder2_batched_generation_fa3(self):
+ EXPECTED_TEXT = [
+ "Hello my name is Younes and I am a student at the University of Liverpool. I am currently studying for my MSc in Computer Science. I am interested in the field of Machine Learning and I am currently working on",
+ "def hello_world():\n\treturn 'Hello World!'\n\n@app.route('/hello/')\ndef hello_name(name):\n\treturn 'Hello %s!' % name\n\n@app",
+ ]
+ model_id = "bigcode/starcoder2-7b"
+
+ model = Starcoder2ForCausalLM.from_pretrained(
+ model_id, torch_dtype=torch.float16, device_map="auto", attn_implementation="flash_attention_3"
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ text = ["Hello my name is Younes and", "def hello_world():"]
+ inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch_device)
+
+ output = model.generate(**inputs, max_new_tokens=40, do_sample=False)
+ output_text = tokenizer.batch_decode(output, skip_special_tokens=True)
+ self.assertEqual(EXPECTED_TEXT, output_text)
+
@require_bitsandbytes
def test_starcoder2_batched_generation_4bit(self):
EXPECTED_TEXT = [
diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py
index ff7a85218d3a00..1bce68348cc91b 100644
--- a/tests/models/wav2vec2/test_modeling_wav2vec2.py
+++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py
@@ -35,6 +35,7 @@
is_pyctcdecode_available,
is_torchaudio_available,
require_flash_attn,
+ require_flash_attn_3,
require_pyctcdecode,
require_soundfile,
require_torch,
@@ -2023,6 +2024,28 @@ def test_inference_ctc_fa2(self):
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ def test_inference_ctc_fa3(self):
+ model_fa = Wav2Vec2ForCTC.from_pretrained(
+ "facebook/wav2vec2-base-960h", attn_implementation="flash_attention_3", torch_dtype=torch.bfloat16
+ )
+ model_fa.to(torch_device)
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
+ input_speech = self._load_datasamples(1)
+
+ input_values = processor(input_speech, return_tensors="pt").input_values.to(torch_device)
+
+ with torch.no_grad():
+ logits = model_fa(input_values.to(torch.bfloat16)).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_trans = processor.batch_decode(predicted_ids)
+
+ EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
+ self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -2049,3 +2072,30 @@ def test_inference_ctc_fa2_batched(self):
"sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore",
]
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
+
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ def test_inference_ctc_fa3_batched(self):
+ model_fa = Wav2Vec2ForCTC.from_pretrained(
+ "facebook/wav2vec2-base-960h", attn_implementation="flash_attention_3", torch_dtype=torch.bfloat16
+ )
+ model_fa.to(torch_device)
+ processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
+
+ input_speech = self._load_datasamples(2)
+
+ inputs = processor(input_speech, return_tensors="pt", padding=True, return_attention_mask=True)
+ inputs = inputs.to(torch_device)
+
+ with torch.no_grad():
+ logits = model_fa(inputs.input_values.to(torch.bfloat16), attention_mask=inputs.attention_mask).logits
+
+ predicted_ids = torch.argmax(logits, dim=-1)
+ predicted_trans = processor.batch_decode(predicted_ids)
+
+ EXPECTED_TRANSCRIPTIONS = [
+ "a man said to the universe sir i exist",
+ "sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore",
+ ]
+ self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py
index b4e71ca72e56ed..71f9e9fb9b490e 100644
--- a/tests/models/whisper/test_modeling_whisper.py
+++ b/tests/models/whisper/test_modeling_whisper.py
@@ -34,6 +34,7 @@
is_flaky,
is_pt_flax_cross_test,
require_flash_attn,
+ require_flash_attn_3,
require_non_xpu,
require_torch,
require_torch_accelerator,
@@ -956,6 +957,52 @@ def test_flash_attn_2_inference_equivalence(self):
model.train()
_ = model_fa(dummy_input, decoder_input_ids=decoder_input_ids)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence(self):
+ import torch
+
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(reason="Model does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.bfloat16,
+ )
+ model.to(torch_device)
+
+ dummy_input = inputs_dict[model.main_input_name][:1]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
+
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+
+ logits = outputs.decoder_hidden_states[-1]
+ logits_fa = outputs_fa.decoder_hidden_states[-1]
+
+ # whisper FA3 needs very high tolerance
+ assert torch.allclose(logits_fa, logits, atol=4e-1)
+
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(dummy_input, decoder_input_ids=decoder_input_ids)
+
@require_flash_attn
@require_torch_gpu
@pytest.mark.flash_attn_test
@@ -1012,6 +1059,62 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
# whisper FA2 needs very high tolerance
assert torch.allclose(logits_fa[:, -2:], logits[:, -2:], atol=4e-1)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ import torch
+
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(reason="Model does not support flash_attention_3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16)
+ model.to(torch_device)
+
+ dummy_input = inputs_dict[model.main_input_name][:1]
+ dummy_input = dummy_input.to(torch.float16)
+
+ decoder_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=dummy_input.device, dtype=torch.long)
+ decoder_attention_mask = torch.tensor(
+ [[0, 0, 0, 1, 1, 1]], device=dummy_input.device, dtype=torch.long
+ )
+
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+
+ logits = outputs.decoder_hidden_states[-1]
+ logits_fa = outputs_fa.decoder_hidden_states[-1]
+
+ # whisper FA3 needs very high tolerance
+ assert torch.allclose(logits_fa, logits, atol=4e-1)
+
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": decoder_attention_mask,
+ "output_hidden_states": True,
+ }
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = outputs.decoder_hidden_states[-1]
+ logits_fa = outputs_fa.decoder_hidden_states[-1]
+
+ # whisper FA3 needs very high tolerance
+ assert torch.allclose(logits_fa[:, -2:], logits[:, -2:], atol=4e-1)
+
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
self.skipTest(reason="test_torchscript is set to False")
@@ -1664,6 +1767,59 @@ def test_flash_attn_2_generate_reuse_cache(self):
past_key_values=past_key_values,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @pytest.mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_reuse_cache(self):
+ max_new_tokens = 2
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name][..., :10]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # run generate once to get filled cache
+ output = model.generate(
+ dummy_input,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ return_dict_in_generate=True,
+ )
+ past_key_values = output.past_key_values
+
+                # Try to continue generation from where we left off, given that we have more than 1 new token to process
+ # e.g. this can happen in speculative decoding when feeding candidate tokens back to target model
+ _ = model.generate(
+ dummy_input,
+ decoder_input_ids=output.sequences,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ past_key_values=past_key_values,
+ )
+
def test_labels_sequence_max_length_correct(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 4d96b229284089..8726f2736156c3 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -76,6 +76,7 @@
require_accelerate,
require_bitsandbytes,
require_flash_attn,
+ require_flash_attn_3,
require_non_xpu,
require_read_token,
require_safetensors,
@@ -3488,6 +3489,34 @@ def test_flash_attn_2_conversion(self):
self.assertTrue(False, "FlashAttention2 modules not found in model")
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_conversion(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_3"
+ ).to(torch_device)
+
+ for _, module in model.named_modules():
+ if "FlashAttention" in module.__class__.__name__:
+ return
+
+ self.assertTrue(False, "FlashAttention3 modules not found in model")
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -3585,6 +3614,103 @@ def test_flash_attn_2_inference_equivalence(self):
model.train()
_ = model_fa(dummy_input, **other_inputs)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ @is_flaky()
+ def test_flash_attn_3_inference_equivalence(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ for model_class in self.all_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ dummy_input = inputs_dict[model.main_input_name][:1]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
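+                # mask out the first position to emulate left padding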
+ if dummy_attention_mask is not None:
+ dummy_attention_mask = dummy_attention_mask[:1]
+ dummy_attention_mask[:, 1:] = 1
+ dummy_attention_mask[:, :1] = 0
+
+ if model.config.is_encoder_decoder:
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
+
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ else:
+ outputs = model(dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ if model.config.is_encoder_decoder:
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+ else:
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
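+                # skip the first position, which is treated as padding when an attention mask is present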
+ assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)
+
+ # check with inference + dropout
+ model.train()
+ _ = model_fa(dummy_input, **other_inputs)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -3678,6 +3804,99 @@ def test_flash_attn_2_inference_equivalence_right_padding(self):
assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ @is_flaky()
+ def test_flash_attn_3_inference_equivalence_right_padding(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ for model_class in self.all_model_classes:
+            if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model_fa = model_class.from_pretrained(
+ tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3"
+ )
+ model_fa.to(torch_device)
+
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16)
+ model.to(torch_device)
+
+ dummy_input = inputs_dict[model.main_input_name][:1]
+ if dummy_input.dtype in [torch.float32, torch.float16]:
+ dummy_input = dummy_input.to(torch.bfloat16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", None)
+
+ if dummy_attention_mask is not None:
+ dummy_attention_mask = dummy_attention_mask[:1]
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ if model.config.is_encoder_decoder:
+ decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1]
+
+ outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True)
+ else:
+ outputs = model(dummy_input, output_hidden_states=True)
+ outputs_fa = model_fa(dummy_input, output_hidden_states=True)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
+ assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)
+
+ if model.config.is_encoder_decoder:
+ other_inputs = {
+ "decoder_input_ids": decoder_input_ids,
+ "decoder_attention_mask": dummy_attention_mask,
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+ else:
+ other_inputs = {
+ "output_hidden_states": True,
+ }
+ if dummy_attention_mask is not None:
+ other_inputs["attention_mask"] = dummy_attention_mask
+
+ outputs = model(dummy_input, **other_inputs)
+ outputs_fa = model_fa(dummy_input, **other_inputs)
+
+ logits = (
+ outputs.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs.decoder_hidden_states[-1]
+ )
+ logits_fa = (
+ outputs_fa.hidden_states[-1]
+ if not model.config.is_encoder_decoder
+ else outputs_fa.decoder_hidden_states[-1]
+ )
+
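+                # skip the last position, which is treated as padding when an attention mask is present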
+ assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -3726,18 +3945,114 @@ def test_flash_attn_2_generate_left_padding(self):
self.assertTrue(torch.allclose(out, out_fa))
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ @is_flaky()
+ def test_flash_attn_3_generate_left_padding(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ for model_class in self.all_generative_model_classes:
+            if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # make sure we do left padding
+ dummy_attention_mask[:, :-1] = 0
+ dummy_attention_mask[:, -1:] = 1
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@is_flaky()
@slow
- def test_flash_attn_2_generate_padding_right(self):
+ def test_flash_attn_2_generate_padding_right(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_flash_attn_2:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+ model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to(
+ torch_device
+ )
+
+ dummy_input = inputs_dict[model.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ # make sure we do right padding
+ dummy_attention_mask[:, :-1] = 1
+ dummy_attention_mask[:, -1:] = 0
+
+ out = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_2",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ out_fa = model.generate(
+ dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False
+ )
+
+ self.assertTrue(torch.allclose(out, out_fa))
+
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @is_flaky()
+ @slow
+ def test_flash_attn_3_generate_padding_right(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")
for model_class in self.all_generative_model_classes:
-            if not model_class._supports_flash_attn_2:
-                self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
+            if not model_class._supports_flash_attn_3:
+                self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
model = model_class(config)
@@ -3764,7 +4079,7 @@ def test_flash_attn_2_generate_padding_right(self):
model = model_class.from_pretrained(
tmpdirname,
torch_dtype=torch.float16,
- attn_implementation="flash_attention_2",
+ attn_implementation="flash_attention_3",
low_cpu_mem_usage=True,
).to(torch_device)
@@ -4355,6 +4670,65 @@ def test_flash_attn_2_generate_use_cache(self):
use_cache=True,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_use_cache(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+            if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # Just test that a large cache works as expected
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
+ # Generate with one batch only to test generation when attention mask will be None
+ # when real inputs are used, because there is no padding. See issue #32237 for more
+ dummy_input = dummy_input[:1, ...]
+ dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...])
+ _ = model.generate(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ )
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -4411,6 +4785,62 @@ def test_flash_attn_2_generate_reuse_cache(self):
past_key_values=past_key_values,
)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_generate_reuse_cache(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ max_new_tokens = 2
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ ).to(torch_device)
+
+ # run generate once to get filled cache
+ output = model.generate(
+ dummy_input,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ return_dict_in_generate=True,
+ )
+ past_key_values = output.past_key_values
+
+                # Try to continue generation from where we left off, given that we have more than 1 new token to process
+ # e.g. this can happen in speculative decoding when feeding candidate tokens back to target model
+ dummy_input_updated = torch.cat([dummy_input, output.sequences], dim=-1)
+ _ = model.generate(
+ dummy_input_updated,
+ max_new_tokens=max_new_tokens,
+ do_sample=False,
+ use_cache=True,
+ past_key_values=past_key_values,
+ )
+
@require_flash_attn
@require_torch_gpu
@require_bitsandbytes
@@ -4468,6 +4898,63 @@ def test_flash_attn_2_fp32_ln(self):
# with attention mask
_ = model(dummy_input, attention_mask=dummy_attention_mask)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @require_bitsandbytes
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_fp32_ln(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ model = model_class(config)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ dummy_input = inputs_dict[model.main_input_name]
+ dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input))
+ batch_size = dummy_attention_mask.shape[0]
+
+ is_padding_right = dummy_attention_mask[:, -1].sum().item() != batch_size
+
+ # To avoid errors with padding_side=="right"
+ if is_padding_right:
+ dummy_attention_mask = torch.ones_like(dummy_input)
+
+ model = model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ load_in_4bit=True,
+ )
+
+ for _, param in model.named_parameters():
+                # upcast the non-quantized half-precision params (mostly layer norms) to fp32
+ if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
+ param.data = param.data.to(torch.float32)
+
+ if model.config.is_encoder_decoder:
+ dummy_decoder_input_ids = inputs_dict["decoder_input_ids"]
+ dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"]
+
+ _ = model(dummy_input, decoder_input_ids=dummy_decoder_input_ids)
+ # with attention mask
+ _ = model(
+ dummy_input,
+ attention_mask=dummy_attention_mask,
+ decoder_input_ids=dummy_decoder_input_ids,
+ decoder_attention_mask=dummy_decoder_attention_mask,
+ )
+ else:
+ _ = model(dummy_input)
+ # with attention mask
+ _ = model(dummy_input, attention_mask=dummy_attention_mask)
+
@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@@ -4541,6 +5028,79 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, atol=tol, rtol=tol)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ max_new_tokens = 30
+
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
+ self.skipTest("Model dummy inputs should contain padding in their attention mask")
+
+ dummy_input = inputs_dict[model_class.main_input_name]
+ if dummy_input.dtype in [torch.float32, torch.bfloat16]:
+ dummy_input = dummy_input.to(torch.float16)
+
+ # make sure that all models have enough positions for generation
+ if hasattr(config, "max_position_embeddings"):
+ config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1
+
+ model = model_class(config)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ model.save_pretrained(tmpdirname)
+
+ # ensure left padding, to adapt for some models
+ if 0 in inputs_dict["attention_mask"][:, -1]:
+ inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
+ dummy_attention_mask = inputs_dict["attention_mask"]
+ inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.pad_token_id
+
+ model = (
+ model_class.from_pretrained(
+ tmpdirname,
+ torch_dtype=torch.float16,
+ attn_implementation="flash_attention_3",
+ low_cpu_mem_usage=True,
+ )
+ .to(torch_device)
+ .eval()
+ )
+
+ # flatten
+ padfree_inputs_dict = {
+ k: v[dummy_attention_mask.bool()].unsqueeze(0)
+ for k, v in inputs_dict.items()
+ if not k == "attention_mask"
+ }
+ # add position_ids
+ padfree_inputs_dict["position_ids"] = (
+ torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()])
+ .long()
+ .unsqueeze(0)
+ .to(torch_device)
+ )
+
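+                # the padded and padding-free forward passes should agree on the non-padded tokens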
+ res_padded = model(**inputs_dict)
+ res_padfree = model(**padfree_inputs_dict)
+
+ logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
+ logits_padfree = res_padfree.logits[0]
+
+ torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), atol=0, rtol=0)
+ # acceptable numerical instability
+ tol = torch.finfo(torch.float16).eps
+ torch.testing.assert_close(logits_padded, logits_padfree, atol=tol, rtol=tol)
+
@is_pt_tf_cross_test
def test_tf_from_pt_safetensors(self):
for model_class in self.all_model_classes:
@@ -4636,6 +5196,54 @@ def test_flash_attn_2_from_config(self):
self.assertFalse(fa2_correctly_converted)
+ @require_flash_attn_3
+ @require_torch_gpu
+ @mark.flash_attn_3_test
+ @slow
+ def test_flash_attn_3_from_config(self):
+ if not self.has_attentions:
+ self.skipTest(reason="Model architecture does not support attentions")
+
+ for model_class in self.all_generative_model_classes:
+ if not model_class._supports_flash_attn_3:
+ self.skipTest(f"{model_class.__name__} does not support Flash Attention 3")
+
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
+ # TODO: to change it in the future with other relevant auto classes
+ fa3_model = AutoModelForCausalLM.from_config(
+ config, attn_implementation="flash_attention_3", torch_dtype=torch.bfloat16
+ ).to(torch_device)
+
+ dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
+ dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)
+
+ fa3_correctly_converted = False
+
+ for _, module in fa3_model.named_modules():
+ if "FlashAttention" in module.__class__.__name__:
+ fa3_correctly_converted = True
+ break
+
+ self.assertTrue(fa3_correctly_converted)
+
+ _ = fa3_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ fa3_model.save_pretrained(tmpdirname)
+
+ model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname)
+
+ self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_3")
+
+ fa3_correctly_converted = False
+
+ for _, module in model_from_pretrained.named_modules():
+ if "FlashAttention" in module.__class__.__name__:
+ fa3_correctly_converted = True
+ break
+
+ self.assertFalse(fa3_correctly_converted)
+
def _get_custom_4d_mask_test_data(self):
# Sequence in which all but the last token is the same
input_ids = torch.tensor(
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 3317a47d75603c..7e2fda9a2c0f25 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -67,6 +67,7 @@
)
from transformers.utils.import_utils import (
is_flash_attn_2_available,
+ is_flash_attn_3_available,
is_flax_available,
is_tf_available,
is_torch_sdpa_available,
@@ -550,10 +551,14 @@ def test_model_from_pretrained_attn_implementation(self):
if is_flash_attn_2_available():
attn_implementation_available.append("flash_attention_2")
+ if is_flash_attn_3_available():
+ attn_implementation_available.append("flash_attention_3")
+
mistral_attention_classes = {
"eager": "MistralAttention",
"sdpa": "MistralSdpaAttention",
- "flash_attention_2": "MistralFlashAttention2",
+ "flash_attention_2": "MistralFlashAttention",
+ "flash_attention_3": "MistralFlashAttention",
}
for requested_attn_implementation in attn_implementation_available:
model = AutoModelForCausalLM.from_pretrained(
@@ -589,10 +594,14 @@ def test_model_from_config_attn_implementation(self):
if is_flash_attn_2_available():
attn_implementation_available.append("flash_attention_2")
+ if is_flash_attn_3_available():
+ attn_implementation_available.append("flash_attention_3")
+
mistral_attention_classes = {
"eager": "MistralAttention",
"sdpa": "MistralSdpaAttention",
- "flash_attention_2": "MistralFlashAttention2",
+ "flash_attention_2": "MistralFlashAttention",
+ "flash_attention_3": "MistralFlashAttention",
}
for requested_attn_implementation in attn_implementation_available:
config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation)
@@ -2520,6 +2529,14 @@ def test_error_no_flash_available(self):
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
+ def test_error_no_flash_3_available(self):
+ with self.assertRaises(ValueError) as cm:
+ _ = AutoModel.from_pretrained(
+ "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="flash_attention_3"
+ )
+
+ self.assertTrue("does not support Flash Attention 3.0" in str(cm.exception))
+
def test_error_no_flash_available_with_config(self):
with self.assertRaises(ValueError) as cm:
config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")
@@ -2530,6 +2547,16 @@ def test_error_no_flash_available_with_config(self):
self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception))
+ def test_error_no_flash_3_available_with_config(self):
+ with self.assertRaises(ValueError) as cm:
+ config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel")
+
+ _ = AutoModel.from_pretrained(
+ "hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_3"
+ )
+
+ self.assertTrue("does not support Flash Attention 3.0" in str(cm.exception))
+
def test_error_wrong_attn_implementation(self):
with self.assertRaises(ValueError) as cm:
_ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo")
@@ -2546,6 +2573,16 @@ def test_not_available_flash(self):
)
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
+ def test_not_available_flash_3(self):
+ if is_flash_attn_3_available():
+ self.skipTest(reason="Please uninstall flash_attn_interface package to run test_not_available_flash_3")
+
+ with self.assertRaises(ImportError) as cm:
+ _ = AutoModel.from_pretrained(
+ "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_3"
+ )
+ self.assertTrue("the package flash_attn_interface seems to be not installed" in str(cm.exception))
+
def test_not_available_flash_with_config(self):
if is_flash_attn_2_available():
self.skipTest(reason="Please uninstall flash-attn package to run test_not_available_flash")
@@ -2561,6 +2598,23 @@ def test_not_available_flash_with_config(self):
self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception))
+ def test_not_available_flash_3_with_config(self):
+ if is_flash_attn_3_available():
+ self.skipTest(
+ reason="Please uninstall flash_attn_interface package to run test_not_available_flash_3_with_config"
+ )
+
+ config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel")
+
+ with self.assertRaises(ImportError) as cm:
+ _ = AutoModel.from_pretrained(
+ "hf-internal-testing/tiny-random-GPTBigCodeModel",
+ config=config,
+ attn_implementation="flash_attention_3",
+ )
+
+ self.assertTrue("the package flash_attn_interface seems to be not installed" in str(cm.exception))
+
def test_not_available_sdpa(self):
if is_torch_sdpa_available():
self.skipTest(reason="This test requires torch<=2.0")