diff --git a/.github/workflows/push-important-models.yml b/.github/workflows/push-important-models.yml index 41bcd43fcc6fc2..17c1b6c86fb066 100644 --- a/.github/workflows/push-important-models.yml +++ b/.github/workflows/push-important-models.yml @@ -87,6 +87,11 @@ jobs: run: pytest -rsfE -m "flash_attn_test" --make-reports=${{ matrix.model-name }}_fa2_tests/ tests/${{ matrix.model-name }}/test_modeling_* + - name: Run FA3 tests + id: run_fa3_tests + run: + pytest -rsfE -m "flash_attn_3_test" --make-reports=${{ matrix.model-name }}_fa3_tests/ tests/${{ matrix.model-name }}/test_modeling_* + - name: "Test suite reports artifacts: ${{ matrix.model-name }}_fa2_tests" if: ${{ always() }} uses: actions/upload-artifact@v4 diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index 16be638498dfd4..4357e96735cd6d 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -348,6 +348,24 @@ model = AutoModelForCausalLM.from_pretrained( ) ``` +### FlashAttention-3 + +FlashAttention and [FlashAttention-3](./perf_infer_gpu_one#flashattention-3) break up the attention computation into smaller chunks and reduces the number of intermediate read/write operations to GPU memory to speed up inference. FlashAttention-3 improves on FlashAttention-2 algorithm by taking advantage of new features on Hopper GPUs to maximize performance. + +To use FlashAttention-3, set `attn_implementation="flash_attention_3"` in the [`~PreTrainedModel.from_pretrained`] method. + +```py +from transformers import AutoModelForCausalLM, BitsAndBytesConfig + +quant_config = BitsAndBytesConfig(load_in_8bit=True) +model = AutoModelForCausalLM.from_pretrained( + "google/gemma-2b", + quantization_config=quant_config, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_3", +) +``` + ### PyTorch scaled dot product attention Scaled dot product attention (SDPA) is automatically enabled in PyTorch 2.0 and it supports FlashAttention, xFormers, and PyTorch's C++ implementation. SDPA chooses the most performant attention algorithm if you're using a CUDA backend. For other backends, SDPA defaults to the PyTorch C++ implementation. diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index 4c220dd0f1483c..4337890e083d1f 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -199,6 +199,141 @@ FlashAttention is more memory efficient, meaning you can train on much larger se + +## FlashAttention-3 + + + +FlashAttention-3 is experimental and may change considerably in future versions. + + + +[FlashAttention-3](https://huggingface.co/papers/2407.08608) improves on FlashAttention-2 algorithm by taking advantage of new features on Hopper GPUs to maximize performance: + +1. overlap overall computation and data movement via warp-specialization +2. interleave block-wise matmul and softmax operations +3. 
block quantization and incoherent processing that leverages hardware support for FP8 low-precision + +FlashAttention-3 is currently supported for the following architectures: +* [Bark](https://huggingface.co/docs/transformers/model_doc/bark#transformers.BarkModel) +* [Bart](https://huggingface.co/docs/transformers/model_doc/bart#transformers.BartModel) +* [Chameleon](https://huggingface.co/docs/transformers/model_doc/chameleon#transformers.Chameleon) +* [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPModel) +* [Cohere](https://huggingface.co/docs/transformers/model_doc/cohere#transformers.CohereModel) +* [Dbrx](https://huggingface.co/docs/transformers/model_doc/dbrx#transformers.DbrxModel) +* [DistilBert](https://huggingface.co/docs/transformers/model_doc/distilbert#transformers.DistilBertModel) +* [Gemma](https://huggingface.co/docs/transformers/model_doc/gemma#transformers.GemmaModel) +* [Gemma2](https://huggingface.co/docs/transformers/model_doc/gemma2#transformers.Gemma2Model) +* [GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2) +* [GPTBigCode](https://huggingface.co/docs/transformers/model_doc/gpt_bigcode#transformers.GPTBigCodeModel) +* [GPTNeo](https://huggingface.co/docs/transformers/model_doc/gpt_neo#transformers.GPTNeoModel) +* [GPTNeoX](https://huggingface.co/docs/transformers/model_doc/gpt_neox#transformers.GPTNeoXModel) +* [GPT-J](https://huggingface.co/docs/transformers/model_doc/gptj#transformers.GPTJModel) +* [Granite](https://huggingface.co/docs/transformers/model_doc/granite#transformers.GraniteModel) +* [Idefics2](https://huggingface.co/docs/transformers/model_doc/idefics2#transformers.Idefics2Model) +* [Falcon](https://huggingface.co/docs/transformers/model_doc/falcon#transformers.FalconModel) +* [JetMoe](https://huggingface.co/docs/transformers/model_doc/jetmoe#transformers.JetMoeModel) +* [Jamba](https://huggingface.co/docs/transformers/model_doc/jamba#transformers.JambaModel) +* [Llama](https://huggingface.co/docs/transformers/model_doc/llama#transformers.LlamaModel) +* [Llava](https://huggingface.co/docs/transformers/model_doc/llava) +* [Llava-NeXT](https://huggingface.co/docs/transformers/model_doc/llava_next) +* [Llava-NeXT-Video](https://huggingface.co/docs/transformers/model_doc/llava_next_video) +* [LLaVA-Onevision](https://huggingface.co/docs/transformers/model_doc/llava_onevision) +* [Mimi](https://huggingface.co/docs/transformers/model_doc/mimi) +* [VipLlava](https://huggingface.co/docs/transformers/model_doc/vipllava) +* [VideoLlava](https://huggingface.co/docs/transformers/model_doc/video_llava) +* [M2M100](https://huggingface.co/docs/transformers/model_doc/m2m_100) +* [MBart](https://huggingface.co/docs/transformers/model_doc/mbart#transformers.MBartModel) +* [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) +* [Mixtral](https://huggingface.co/docs/transformers/model_doc/mixtral#transformers.MixtralModel) +* [Musicgen](https://huggingface.co/docs/transformers/model_doc/musicgen#transformers.MusicgenModel) +* [MusicGen Melody](https://huggingface.co/docs/transformers/model_doc/musicgen_melody#transformers.MusicgenMelodyModel) +* [Nemotron](https://huggingface.co/docs/transformers/model_doc/nemotron) +* [NLLB](https://huggingface.co/docs/transformers/model_doc/nllb) +* [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel) +* [OLMoE](https://huggingface.co/docs/transformers/model_doc/olmoe#transformers.OlmoeModel) +* 
[OPT](https://huggingface.co/docs/transformers/model_doc/opt#transformers.OPTModel) +* [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel) +* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model) +* [StableLm](https://huggingface.co/docs/transformers/model_doc/stablelm#transformers.StableLmModel) +* [Starcoder2](https://huggingface.co/docs/transformers/model_doc/starcoder2#transformers.Starcoder2Model) +* [Qwen2](https://huggingface.co/docs/transformers/model_doc/qwen2#transformers.Qwen2Model) +* [Qwen2Audio](https://huggingface.co/docs/transformers/model_doc/qwen2_audio#transformers.Qwen2AudioEncoder) +* [Qwen2MoE](https://huggingface.co/docs/transformers/model_doc/qwen2_moe#transformers.Qwen2MoeModel) +* [Qwen2VL](https://huggingface.co/docs/transformers/model_doc/qwen2_vl#transformers.Qwen2VLModel) +* [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) +* [Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2#transformers.Wav2Vec2Model) +* [Hubert](https://huggingface.co/docs/transformers/model_doc/hubert#transformers.HubertModel) +* [data2vec_audio](https://huggingface.co/docs/transformers/main/en/model_doc/data2vec#transformers.Data2VecAudioModel) +* [Sew](https://huggingface.co/docs/transformers/main/en/model_doc/sew#transformers.SEWModel) +* [SigLIP](https://huggingface.co/docs/transformers/model_doc/siglip) +* [UniSpeech](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech#transformers.UniSpeechModel) +* [unispeech_sat](https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/unispeech-sat#transformers.UniSpeechSatModel) + +You can request to add FlashAttention-3 support for another model by opening a GitHub Issue or Pull Request. + +Before you begin, make sure you have FlashAttention-3 installed. + + + + +```bash +git clone https://github.com/Dao-AILab/flash-attention +cd flash-attention/hopper +python setup.py install +``` + + + + +To enable FlashAttention-3, pass the argument `attn_implementation="flash_attention_3"` to [`~AutoModelForCausalLM.from_pretrained`]: + +```python +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM + +model_id = "tiiuae/falcon-7b" +tokenizer = AutoTokenizer.from_pretrained(model_id) + +model = AutoModelForCausalLM.from_pretrained( + model_id, + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_3", +) +``` + + + +FlashAttention-3 can only be used when the model's dtype is `fp16` or `bf16`. Make sure to cast your model to the appropriate dtype and load them on a supported device before using FlashAttention-3. + +
+ +
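+For example, a checkpoint loaded in full precision can be cast and moved to the GPU before generation. The snippet below is only an illustrative sketch that reuses the `tiiuae/falcon-7b` checkpoint from the example above:
+
+```py
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+model_id = "tiiuae/falcon-7b"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    attn_implementation="flash_attention_3",
+)
+# FlashAttention-3 kernels only accept fp16/bf16 activations on a supported (Hopper) GPU,
+# so cast the weights and move the model before calling generate().
+model = model.to(torch.bfloat16).to("cuda")
+```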
+ +FlashAttention-3 can be combined with other optimization techniques like quantization to further speedup inference. For example, you can combine FlashAttention-3 with 8-bit or 4-bit quantization: + +```py +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM + +model_id = "tiiuae/falcon-7b" +tokenizer = AutoTokenizer.from_pretrained(model_id) + +# load in 8bit +model = AutoModelForCausalLM.from_pretrained( + model_id, + load_in_8bit=True, + attn_implementation="flash_attention_3", +) + +# load in 4bit +model = AutoModelForCausalLM.from_pretrained( + model_id, + load_in_4bit=True, + attn_implementation="flash_attention_3", +) +``` + ## PyTorch scaled dot product attention PyTorch's [`torch.nn.functional.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html) (SDPA) can also call FlashAttention and memory-efficient attention kernels under the hood. SDPA support is currently being added natively in Transformers and is used by default for `torch>=2.1.1` when an implementation is available. You may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. diff --git a/pyproject.toml b/pyproject.toml index bf78e0174394f5..fdcccd96884caa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ addopts = "--doctest-glob='**/*.md'" doctest_optionflags="NUMBER NORMALIZE_WHITESPACE ELLIPSIS" markers = [ "flash_attn_test: marks tests related to flash attention (deselect with '-m \"not flash_attn_test\"')", + "flash_attn_3_test: marks tests related to flash attention 3 (deselect with '-m \"not flash_attn_3_test\"')", "bitsandbytes: select (or deselect with `not`) bitsandbytes integration tests", "generate: marks tests that use the GenerationTesterMixin" ] diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index aa13a97fe46150..9aac69fab90d05 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1334,6 +1334,7 @@ "convert_and_export_with_cache", ] + _import_structure["modeling_flash_attention_3_utils"] = [] _import_structure["modeling_flash_attention_utils"] = [] _import_structure["modeling_outputs"] = [] _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"] diff --git a/src/transformers/modeling_flash_attention_3_utils.py b/src/transformers/modeling_flash_attention_3_utils.py new file mode 100644 index 00000000000000..0a3faf784a06af --- /dev/null +++ b/src/transformers/modeling_flash_attention_3_utils.py @@ -0,0 +1,295 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
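+
+"""
+Utilities for the `flash_attention_3` attention implementation.
+
+This module mirrors `modeling_flash_attention_utils.py`: it provides helpers for
+unpadding/repadding ragged batches and `_flash_attention_3_forward`, a thin wrapper
+that dispatches to the FlashAttention-3 kernels exposed by `flash_attn_interface`.
+"""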
+ +import os +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + +from .utils import is_flash_attn_3_available + + +if is_flash_attn_3_available(): + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + from flash_attn_interface import _flash_attn_forward, flash_attn_func, flash_attn_varlen_func + + +def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). 
+ """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def prepare_fa2_from_position_ids(query, key, value, position_ids): + """ + This function returns necessary arguments to call `flash_attn_varlen_func`. + All three query, key, value states will be flattened. + Cummulative lengths of each examples in the batch will be extracted from position_ids. + + NOTE: ideally cummulative lengths should be prepared at the data collator stage + + Arguments: + query (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + position_ids (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + query (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). 
+ """ + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.view(-1, key.size(-2), key.size(-1)) + value = value.view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + + cu_seq_lens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) + + max_length = position_ids.max() + 1 + + return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + + +def _flash_attention_3_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + is_causal: bool, + position_ids: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + use_top_left_mask: bool = False, + deterministic: bool = None, + descale: float = 1.0, + use_fp8: bool = False, +): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_top_left_mask (`bool`, defaults to `False`): + flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. + deterministic (`bool`, *optional*): + Determines if the deterministic option enabled. + """ + use_fp8 = os.environ.get("FLASH_ATTENTION_3_FP8", "0") == "1" + + softmax_scale = softmax_scale or query_states.shape[-1] ** (-0.5) + + if not use_top_left_mask: + causal = is_causal + else: + causal = is_causal and query_length != 1 + + flash_kwargs = {} + + if deterministic is None: + deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + flash_kwargs["deterministic"] = deterministic + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad, _ = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + + # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing + # then we probably have one sequence, otherwise it is packed. 
Additionally check we are in pre-fill/training stage. + # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach + elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1: + batch_size = query_states.size(0) + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output, _ = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + + else: + if use_fp8: + # NOTE: descale? + attn_output = _flash_attn_forward( + query_states.to(torch.float8_e4m3fn), + key_states.to(torch.float8_e4m3fn), + value_states.to(torch.float8_e4m3fn), + softmax_scale=softmax_scale, + causal=causal, + )[0] + else: + attn_output, _ = flash_attn_func( + query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs + ) + + return attn_output diff --git a/src/transformers/modeling_flash_attention_utils.py b/src/transformers/modeling_flash_attention_utils.py index 44e61825dd9cd6..3096e121022c59 100644 --- a/src/transformers/modeling_flash_attention_utils.py +++ b/src/transformers/modeling_flash_attention_utils.py @@ -223,7 +223,7 @@ def _flash_attention_forward( if not use_top_left_mask: causal = is_causal else: - # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__. + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__. causal = is_causal and query_length != 1 # Assuming 4D tensors, key_states.shape[1] is the key/value sequence length (source length). diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d4069766636053..314e3ee22cafdf 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -82,6 +82,7 @@ is_accelerate_available, is_bitsandbytes_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_offline_mode, is_optimum_available, is_peft_available, @@ -1367,6 +1368,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Flash Attention 2 support _supports_flash_attn_2 = False + # Flash Attention 3 support + _supports_flash_attn_3 = False + # SDPA support _supports_sdpa = False @@ -1550,10 +1554,12 @@ def _autoset_attn_implementation( ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' ) - if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]: + if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2", "flash_attention_3"]: message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. 
The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' if cls._supports_flash_attn_2: message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' + if cls._supports_flash_attn_3: + message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)' if cls._supports_sdpa: message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' raise ValueError(message + ".") @@ -1575,6 +1581,14 @@ def _autoset_attn_implementation( hard_check_only=False, check_device_map=check_device_map, ) + elif config._attn_implementation == "flash_attention_3": + cls._check_and_enable_flash_attn_3( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available(): # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. config = cls._check_and_enable_sdpa( @@ -1734,6 +1748,83 @@ def _check_and_enable_flash_attn_2( config._attn_implementation = "flash_attention_2" return config + @classmethod + def _check_and_enable_flash_attn_3( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + hard_check_only: bool = False, + ) -> PretrainedConfig: + """ + Checks the availability of Flash Attention 3 and compatibility with the current model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module. + """ + if not cls._supports_flash_attn_3: + raise ValueError( + f"{cls.__name__} does not support Flash Attention 3.0 yet. Please request to add support where" + f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new" + " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" + ) + + if not is_flash_attn_3_available(): + preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:" + # TODO: docs + install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-3 to install Flash Attention 3." + + if importlib.util.find_spec("flash_attn_interface") is None: + raise ImportError( + f"{preface} the package flash_attn_interface seems to be not installed. {install_message}" + ) + + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + + if _is_bettertransformer: + raise ValueError( + "Flash Attention 3 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()" + ) + + if torch_dtype is None: + logger.warning_once( + "You are attempting to use Flash Attention 3.0 without specifying a torch dtype. This might lead to unexpected behaviour" + ) + elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]: + logger.warning_once( + "Flash Attention 3.0 only supports torch.float16 and torch.bfloat16 dtypes, but" + f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," + ' or load the model with the `torch_dtype` argument. 
Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`' + ) + + # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, + # or the model may be initialized under the context manager `with torch.device("cuda"):`. + if check_device_map and device_map is None and torch.empty(0).device.type != "cuda": + if torch.cuda.is_available(): + logger.warning_once( + "You are attempting to use Flash Attention 3.0 with a model not initialized on GPU. Make sure to move the model to GPU" + " after initializing it on CPU with `model.to('cuda')`." + ) + else: + raise ValueError( + "You are attempting to use Flash Attention 3.0 with a model not initialized on GPU and with no GPU available. " + "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " + "or initialising the model on CPU and then moving it to GPU." + ) + elif ( + check_device_map + and device_map is not None + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + raise ValueError( + "You are attempting to use Flash Attention 3.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to " + "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." + ) + if not hard_check_only: + config._attn_implementation = "flash_attention_3" + return config + @classmethod def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: """ diff --git a/src/transformers/models/altclip/modeling_altclip.py b/src/transformers/models/altclip/modeling_altclip.py index 4ed0930605e899..4e6010096daa63 100755 --- a/src/transformers/models/altclip/modeling_altclip.py +++ b/src/transformers/models/altclip/modeling_altclip.py @@ -721,7 +721,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class AltCLIPAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: AltCLIPConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 5aad7b23a8a672..114bfa143b2b0b 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -35,6 +35,7 @@ add_start_docstrings_to_model_forward, is_accelerate_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, ) @@ -56,6 +57,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -68,9 +72,9 @@ class BarkSelfAttention(nn.Module): # adapted from GPTNeoSelfAttention and Bark code # BarkSelfAttention can have two attention type, i.e full attention or causal attention - def __init__(self, config, is_causal=False): + def __init__(self, config: BarkConfig, is_causal=False): super().__init__() - + self.config = config # regularization self.dropout = config.dropout self.attn_dropout = nn.Dropout(config.dropout) @@ -189,14 +193,14 @@ def forward( return outputs -class BarkSelfFlashAttention2(BarkSelfAttention): +class 
BarkSelfFlashAttention(BarkSelfAttention): """ Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -204,6 +208,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def _split_heads(self, tensor, num_heads, attn_head_size): """ @@ -256,16 +261,26 @@ def forward( else: present = None - attn_output = _flash_attention_forward( - query, - key, - value, - attention_mask, - query_len, - dropout=self.dropout, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query, + key, + value, + attention_mask, + query_len, + dropout=self.dropout, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) attn_output = self.out_proj(attn_output) @@ -281,7 +296,8 @@ def forward( BARK_ATTENTION_CLASSES = { "eager": BarkSelfAttention, - "flash_attention_2": BarkSelfFlashAttention2, + "flash_attention_2": BarkSelfFlashAttention, + "flash_attention_3": BarkSelfFlashAttention, } @@ -376,6 +392,7 @@ class BarkPreTrainedModel(PreTrainedModel): config_class = BarkConfig supports_gradient_checkpointing = False _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module): """Initialize the weights.""" @@ -561,6 +578,7 @@ def __init__(self, config): self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias) @@ -702,7 +720,7 @@ def forward( if attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None else: attention_mask = attention_mask.view(batch_size, -1) @@ -1157,6 +1175,7 @@ def __init__(self, config): self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + 
self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self.layernorm_final = nn.LayerNorm(config.hidden_size) @@ -1341,7 +1360,7 @@ def forward( if attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None else: # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length] @@ -1811,3 +1830,39 @@ def _check_and_enable_flash_attn_2( config.coarse_acoustics_config._attn_implementation = config._attn_implementation config.fine_acoustics_config._attn_implementation = config._attn_implementation return config + + @classmethod + def _check_and_enable_flash_attn_3( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + hard_check_only: bool = False, + check_device_map: bool = False, + ): + """ + `_check_and_enable_flash_attn_3` originally don't expand flash attention enabling to the model + sub-configurations. We override the original method to make sure that Bark sub-models are using Flash Attention + if necessary. + + If you don't know about Flash Attention, check out the official repository of flash attention: + https://github.com/Dao-AILab/flash-attention + + For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this + specific section of the documentation to learn more about it: + https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models + + The method checks if the current setup is compatible with Flash Attention as it requires the model to be in + half precision and not ran on CPU. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_3" so that the model + can initialize the correct attention module + """ + config = super()._check_and_enable_flash_attn_3( + config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map + ) + + config.semantic_config._attn_implementation = config._attn_implementation + config.coarse_acoustics_config._attn_implementation = config._attn_implementation + config.fine_acoustics_config._attn_implementation = config._attn_implementation + return config diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index fa928d05caa89b..fb04b4ec9b7d8e 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -47,6 +47,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -57,6 +58,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -286,14 +289,14 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -class BartFlashAttention2(BartAttention): +class BartFlashAttention(BartAttention): """ Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays untouched. 
The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -301,6 +304,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -393,16 +397,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) @@ -522,7 +536,8 @@ def forward( BART_ATTENTION_CLASSES = { "eager": BartAttention, "sdpa": BartSdpaAttention, - "flash_attention_2": BartFlashAttention2, + "flash_attention_2": BartFlashAttention, + "flash_attention_3": BartFlashAttention, } @@ -748,6 +763,7 @@ class BartPreTrainedModel(PreTrainedModel): _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): @@ -979,6 +995,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No ) self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) @@ -1067,7 +1084,7 @@ def forward( # expand attention_mask if attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None elif self._use_sdpa and head_mask is None and not output_attentions: # output_attentions=True & head_mask can not be supported when using SDPA, fall back to @@ -1163,6 +1180,7 @@ def __init__(self, config: BartConfig, embed_tokens: 
Optional[nn.Embedding] = No ) self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -1283,7 +1301,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: @@ -1303,7 +1321,7 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index b5b221b6b37f83..1c1b604a2947a8 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -794,7 +794,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 23334311ca9511..09d912d9e9a3b7 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -419,9 +420,9 @@ def forward( return attn_output, attn_weights, past_key_value -# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Chameleon +# copied from transformers.models.llama.modeling_llama.LlamaFlashAttention with Llama->Chameleon # TODO(joao): add me back asap :) -class ChameleonFlashAttention2(ChameleonAttention): +class ChameleonFlashAttention(ChameleonAttention): """ Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of @@ -435,6 +436,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. 
This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" # Ignore copy def forward( @@ -517,17 +519,27 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -633,7 +645,8 @@ def forward( CHAMELEON_ATTENTION_CLASSES = { "eager": ChameleonAttention, - "flash_attention_2": ChameleonFlashAttention2, + "flash_attention_2": ChameleonFlashAttention, + "flash_attention_3": ChameleonFlashAttention, "sdpa": ChameleonSdpaAttention, } @@ -1155,6 +1168,7 @@ class ChameleonPreTrainedModel(PreTrainedModel): _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"] _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_quantized_cache = True _supports_cache_class = True @@ -1433,7 +1447,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 64eb027e9e220c..fc90ef77504ea6 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -33,6 +33,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -43,6 +44,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -244,7 +247,7 @@ def forward( class CLIPAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: CLIPConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -345,14 +348,14 @@ def forward( return attn_output, attn_weights_reshaped -class 
CLIPFlashAttention2(CLIPAttention): +class CLIPFlashAttention(CLIPAttention): """ CLIPAttention flash attention module. This module inherits from `CLIPAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -360,8 +363,9 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" - # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward + # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention.forward def forward( self, hidden_states: torch.Tensor, @@ -412,16 +416,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - is_causal=causal_attention_mask is not None, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=causal_attention_mask is not None, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + is_causal=causal_attention_mask is not None, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() attn_output = self.out_proj(attn_output) @@ -508,7 +522,8 @@ def forward( CLIP_ATTENTION_CLASSES = { "eager": CLIPAttention, "sdpa": CLIPSdpaAttention, - "flash_attention_2": CLIPFlashAttention2, + "flash_attention_2": CLIPFlashAttention, + "flash_attention_3": CLIPFlashAttention, } @@ -588,6 +603,7 @@ class CLIPPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module): """Initialize the weights""" @@ -857,6 +873,7 @@ def __init__(self, config: CLIPTextConfig): # For attention mask, it differs between `flash_attention_2` and other attention implementations self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) @@ -894,7 +911,7 @@ def forward( ) # expand attention_mask - if attention_mask is not None and not self._use_flash_attention_2: + if 
attention_mask is not None and not self._use_flash_attention_2 and not self._use_flash_attention_3: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) diff --git a/src/transformers/models/clipseg/modeling_clipseg.py b/src/transformers/models/clipseg/modeling_clipseg.py index a6507e431f68e2..d40acbcfc53a40 100644 --- a/src/transformers/models/clipseg/modeling_clipseg.py +++ b/src/transformers/models/clipseg/modeling_clipseg.py @@ -238,7 +238,7 @@ def forward( class CLIPSegAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: CLIPSegConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index be57838975c0b0..e8ea79da9ced84 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -637,7 +637,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index cb1b3f885798c8..3336cd42294591 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -44,6 +44,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -55,6 +56,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -416,8 +420,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->Cohere -class CohereFlashAttention2(CohereAttention): +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention with Llama->Cohere +class CohereFlashAttention(CohereAttention): """ Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of @@ -431,6 +435,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" # Ignore copy def forward( @@ -519,16 +524,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) @@ -644,7 +659,8 @@ def forward( COHERE_ATTENTION_CLASSES = { "eager": CohereAttention, - "flash_attention_2": CohereFlashAttention2, + "flash_attention_2": CohereFlashAttention, + "flash_attention_3": CohereFlashAttention, "sdpa": CohereSdpaAttention, } @@ -751,6 +767,7 @@ class CoherePreTrainedModel(PreTrainedModel): _no_split_modules = ["CohereDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1008,7 +1025,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index dd2a676b26c27f..4d36f4b1bcef98 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -40,6 +40,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_peft_available, logging, @@ -50,6 +51,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -480,15 +484,15 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Data2VecAudio -class Data2VecAudioFlashAttention2(Data2VecAudioAttention): +# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->Data2VecAudio +class Data2VecAudioFlashAttention(Data2VecAudioAttention): """ Data2VecAudio flash attention module. This module inherits from `Data2VecAudioAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -496,6 +500,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -588,16 +593,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) @@ -718,7 +733,8 @@ def forward( DATA2VEC2AUDIO_ATTENTION_CLASSES = { "eager": Data2VecAudioAttention, "sdpa": Data2VecAudioSdpaAttention, - "flash_attention_2": Data2VecAudioFlashAttention2, + "flash_attention_2": Data2VecAudioFlashAttention, + "flash_attention_3": Data2VecAudioFlashAttention, } @@ -794,6 +810,7 @@ def __init__(self, config): self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -810,7 +827,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -931,6 +948,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 7263713c084007..bf8f97d2503531 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -30,6 +30,7 @@ add_start_docstrings, 
add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -40,6 +41,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DbrxConfig" @@ -363,7 +367,7 @@ def forward( return attn_output, attn_weights, past_key_value -class DbrxFlashAttention2(DbrxAttention): +class DbrxFlashAttention(DbrxAttention): """Dbrx flash attention module. This module inherits from `DbrxAttention` as the weights of the module stays @@ -371,7 +375,7 @@ class DbrxFlashAttention2(DbrxAttention): calls the public API of flash attention. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -379,6 +383,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -463,17 +468,28 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.out_proj(attn_output) @@ -581,7 +597,8 @@ def forward( DBRX_ATTENTION_CLASSES = { "eager": DbrxAttention, - "flash_attention_2": DbrxFlashAttention2, + "flash_attention_2": DbrxFlashAttention, + "flash_attention_3": DbrxFlashAttention, "sdpa": DbrxSdpaAttention, } @@ -882,6 +899,7 @@ class DbrxPreTrainedModel(PreTrainedModel): _no_split_modules = ["DbrxBlock"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1167,7 +1185,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if 
attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index b8eb9f5a8b4222..47c921100c2c94 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -100,9 +100,9 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): return model -# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2->DecisionTransformerGPT2 +# Copied from transformers.models.gpt2.modeling_gpt2.GPT2Attention with GPT2Config->DecisionTransformerConfig,GPT2->DecisionTransformerGPT2 class DecisionTransformerGPT2Attention(nn.Module): - def __init__(self, config, is_cross_attention=False, layer_idx=None): + def __init__(self, config: DecisionTransformerConfig, is_cross_attention=False, layer_idx=None): super().__init__() self.config = config max_positions = config.max_position_embeddings diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index e80e3c41d22cb6..8a525b357ae4d6 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -44,6 +44,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -54,6 +55,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "distilbert-base-uncased" @@ -232,14 +235,14 @@ def unshape(x: torch.Tensor) -> torch.Tensor: return (context,) -class DistilBertFlashAttention2(MultiHeadSelfAttention): +class DistilBertFlashAttention(MultiHeadSelfAttention): """ DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -247,6 +250,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -309,16 +313,26 @@ def reshape(x: torch.Tensor) -> torch.Tensor: key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_weights = _flash_attention_forward( - query_states, - key_states, - value_states, - mask, - q_length, - dropout=attn_dropout, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_weights = _flash_attention_3_forward( + query_states, + key_states, + value_states, + mask, + q_length, + is_causal=self.is_causal, + ) + else: + attn_weights = _flash_attention_forward( + query_states, + key_states, + value_states, + mask, + q_length, + dropout=attn_dropout, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head) attn_output = self.out_lin(attn_weights_reshaped) @@ -352,7 +366,8 @@ def ff_chunk(self, input: torch.Tensor) -> torch.Tensor: DISTILBERT_ATTENTION_CLASSES = { "eager": MultiHeadSelfAttention, - "flash_attention_2": DistilBertFlashAttention2, + "flash_attention_2": DistilBertFlashAttention, + "flash_attention_3": DistilBertFlashAttention, } @@ -503,6 +518,7 @@ class DistilBertPreTrainedModel(PreTrainedModel): base_model_prefix = "distilbert" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module: nn.Module): """Initialize the weights.""" @@ -589,6 +605,7 @@ def __init__(self, config: PretrainedConfig): self.embeddings = Embeddings(config) # Embeddings self.transformer = Transformer(config) # Encoder self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" # Initialize weights and apply final processing self.post_init() @@ -694,7 +711,7 @@ def forward( embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: if attention_mask is None: diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 9a37fe22e1779e..524094c8ab9c64 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -43,6 +43,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, ) @@ -55,6 +56,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -574,14 +578,14 @@ def forward( return attn_output, layer_past -class FalconFlashAttention2(FalconAttention): +class FalconFlashAttention(FalconAttention): """ Falcon flash attention module. This module inherits from `FalconAttention` as the weights of the module stays untouched. 
The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -589,6 +593,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -667,17 +672,28 @@ def forward( key_layer = key_layer.to(target_dtype) value_layer = value_layer.to(target_dtype) - attn_output = _flash_attention_forward( - query_layer, - key_layer, - value_layer, - attention_mask, - query_length, - position_ids=position_ids, - dropout=attn_dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_layer, + key_layer, + value_layer, + attention_mask, + query_length, + position_ids=position_ids, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_layer, + key_layer, + value_layer, + attention_mask, + query_length, + position_ids=position_ids, + dropout=attn_dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) attn_output = self.dense(attn_weights) @@ -707,7 +723,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: FALCON_ATTENTION_CLASSES = { "eager": FalconAttention, "sdpa": FalconAttention, # FalconAttention originally implemented both a forward with & without SDPA - "flash_attention_2": FalconFlashAttention2, + "flash_attention_2": FalconFlashAttention, + "flash_attention_3": FalconFlashAttention, } @@ -909,6 +926,7 @@ class FalconPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["FalconDecoderLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -971,6 +989,7 @@ def __init__(self, config: FalconConfig): # Transformer blocks self.h = nn.ModuleList([FalconDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" # Final Layer Norm @@ -1161,7 +1180,10 @@ def _update_causal_mask( # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114 - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gemma/diff_gemma.py b/src/transformers/models/gemma/diff_gemma.py index 36f2d1c594abaa..5bb934b7feed0f 100644 --- a/src/transformers/models/gemma/diff_gemma.py +++ b/src/transformers/models/gemma/diff_gemma.py @@ -23,7 +23,7 @@ from transformers import PretrainedConfig from transformers.models.llama.modeling_llama import ( - LlamaFlashAttention2, + LlamaFlashAttention, LlamaForCausalLM, LlamaForSequenceClassification, LlamaForTokenClassification, @@ -34,6 +34,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import CausalLMOutputWithPast from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -340,7 +341,7 @@ def forward( # TODO felix: does this inheritance really work out in the end to GemmaFlashAttention2 inheriting form GemmaAttention? -class GemmaFlashAttention2(LlamaFlashAttention2): +class GemmaFlashAttention(LlamaFlashAttention): """ Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of @@ -420,17 +421,27 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index dd4c899d13d465..24292abc5d4bff 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -357,7 +358,7 @@ def forward( return attn_output, attn_weights, past_key_value -class GemmaFlashAttention2(GemmaAttention): +class GemmaFlashAttention(GemmaAttention): """ Gemma flash attention module. 
This module inherits from `GemmaAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of @@ -371,6 +372,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -445,18 +447,29 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -557,7 +570,8 @@ def forward( GEMMA_ATTENTION_CLASSES = { "eager": GemmaAttention, - "flash_attention_2": GemmaFlashAttention2, + "flash_attention_2": GemmaFlashAttention, + "flash_attention_3": GemmaFlashAttention, "sdpa": GemmaSdpaAttention, } @@ -665,6 +679,7 @@ class GemmaPreTrainedModel(PreTrainedModel): _no_split_modules = ["GemmaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -930,7 +945,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gemma2/diff_gemma2.py b/src/transformers/models/gemma2/diff_gemma2.py index 30f371a1b61267..3e326709b7a558 100644 --- a/src/transformers/models/gemma2/diff_gemma2.py +++ b/src/transformers/models/gemma2/diff_gemma2.py @@ -34,12 +34,15 @@ from ...cache_utils import Cache from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging +from ...utils import is_flash_attn_2_available, is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if 
is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -73,7 +76,7 @@ def __init__(self, config: Gemma2Config, layer_idx: Optional[int] = None): self.scaling = config.query_pre_attn_scalar**-0.5 -class Gemma2FlashAttention2(Gemma2Attention): +class Gemma2FlashAttention(Gemma2Attention): """ Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of @@ -87,6 +90,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -156,17 +160,28 @@ def forward( value_states = value_states.to(target_dtype) ########### ONLY DIFFERENCE IS WE USE SLIDING AND PASS THE SOFTMAX SCALING - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=self.scaling, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + softmax_scale=self.scaling, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.scaling, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -439,7 +454,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index be964c9aed018a..120ab8144ef71d 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -39,6 +39,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, @@ -51,6 +52,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -327,7 +331,7 @@ def forward( return attn_output, attn_weights, past_key_value -class Gemma2FlashAttention2(Gemma2Attention): +class 
Gemma2FlashAttention(Gemma2Attention): """ Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of @@ -341,6 +345,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -419,19 +424,30 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - softmax_scale=self.scaling, - is_causal=self.is_causal, - sliding_window=self.sliding_window, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + softmax_scale=self.scaling, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.scaling, + is_causal=self.is_causal, + sliding_window=self.sliding_window, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + softcap=self.config.attn_logit_softcapping if is_flash_attn_greater_or_equal("2.6.0") else None, + ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -534,7 +550,8 @@ def forward( GEMMA2_ATTENTION_CLASSES = { "eager": Gemma2Attention, - "flash_attention_2": Gemma2FlashAttention2, + "flash_attention_2": Gemma2FlashAttention, + "flash_attention_3": Gemma2FlashAttention, "sdpa": Gemma2SdpaAttention, } @@ -568,7 +585,10 @@ def forward( ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding # Flash-attn is a 2D tensor - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if past_key_value is not None: # when decoding attention_mask = attention_mask[:, -self.sliding_window :] else: @@ -641,6 +661,7 @@ class Gemma2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Gemma2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = False @@ -906,7 +927,10 @@ def _update_causal_mask( # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible # as it doesn't cause dynamic control issues. 
- if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): return attention_mask dtype, device = input_tensor.dtype, input_tensor.device @@ -1117,6 +1141,7 @@ def prepare_inputs_for_generation( isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2 and not self.config._attn_implementation == "flash_attention_2" + and not self.config._attn_implementation == "flash_attention_3" ): if model_inputs["inputs_embeds"] is not None: batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 7471eec811a065..01f37fbbe763d7 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -662,7 +662,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GitVisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: GitVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 8dfbfb9064444d..cd5980af02f652 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -45,6 +45,7 @@ add_start_docstrings_to_model_forward, get_torch_version, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -56,6 +57,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -120,7 +123,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path): class GPT2Attention(nn.Module): - def __init__(self, config, is_cross_attention=False, layer_idx=None): + def __init__(self, config: GPT2Config, is_cross_attention=False, layer_idx=None): super().__init__() self.config = config max_positions = config.max_position_embeddings @@ -341,14 +344,14 @@ def forward( return outputs # a, present, (attentions) -class GPT2FlashAttention2(GPT2Attention): +class GPT2FlashAttention(GPT2Attention): """ GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -356,6 +359,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -432,16 +436,26 @@ def forward( key = key.to(target_dtype) value = value.to(target_dtype) - attn_output = _flash_attention_forward( - query, - key, - value, - attention_mask, - query_length, - dropout=attn_dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_length, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query, + key, + value, + attention_mask, + query_length, + dropout=attn_dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) attn_output = self.c_proj(attn_weights_reshaped) @@ -578,7 +592,12 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl return hidden_states -GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention} +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention, + "flash_attention_3": GPT2FlashAttention, + "sdpa": GPT2SdpaAttention, +} class GPT2Block(nn.Module): @@ -674,6 +693,7 @@ class GPT2PreTrainedModel(PreTrainedModel): _no_split_modules = ["GPT2Block"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__(self, *inputs, **kwargs): @@ -1031,7 +1051,7 @@ def forward( # Attention mask. _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None - if self._attn_implementation == "flash_attention_2": + if self._attn_implementation == "flash_attention_2" or self._attn_implementation == "flash_attention_3": attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif _use_sdpa: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -1068,7 +1088,10 @@ def forward( encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1] ) - elif not self._attn_implementation == "flash_attention_2": + elif ( + not self._attn_implementation == "flash_attention_2" + and not self._attn_implementation == "flash_attention_3" + ): encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_attention_mask = None diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 0f927a72469dc9..fcf3ce05edd8cb 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -36,6 +36,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, ) @@ -45,6 +46,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) 
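Aside for reviewers: every model touched by this patch repeats the same three-part change visible above — a guarded import of `_flash_attention_3_forward` behind `is_flash_attn_3_available()`, a `_flash_attn_3` flag derived from `config._attn_implementation`, and a branch in `forward` that calls the FA3 helper without the `dropout`/`use_top_left_mask` (and, where present, `sliding_window`/`softcap`) arguments. Because each `*_ATTENTION_CLASSES` mapping now points both `"flash_attention_2"` and `"flash_attention_3"` at the same merged class, it is this flag rather than the class that selects the kernel. The sketch below only illustrates that pattern, assuming the utilities introduced by this patch are importable; `MyFlashAttention` and `_dispatch` are hypothetical stand-ins, not code from the patch.

```py
# Minimal sketch of the FA2/FA3 dispatch pattern repeated across this patch.
import torch

try:
    from transformers.utils import is_flash_attn_2_available, is_flash_attn_3_available
except ImportError:  # older transformers without the FA3 utilities added here
    is_flash_attn_2_available = is_flash_attn_3_available = lambda: False

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
if is_flash_attn_3_available():
    from transformers.modeling_flash_attention_3_utils import _flash_attention_3_forward


class MyFlashAttention(torch.nn.Module):
    """Hypothetical stand-in for a merged *FlashAttention class (e.g. GPTBigCodeFlashAttention)."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.is_causal = True
        # One flag, set from the config, decides which kernel the forward pass calls.
        self._flash_attn_3 = config._attn_implementation == "flash_attention_3"

    def _dispatch(self, query, key, value, attention_mask, q_len, dropout):
        if self._flash_attn_3:
            # FA3 path: no dropout or top-left-mask handling, mirroring the diff.
            return _flash_attention_3_forward(
                query, key, value, attention_mask, q_len, is_causal=self.is_causal
            )
        return _flash_attention_forward(
            query, key, value, attention_mask, q_len,
            dropout=dropout, is_causal=self.is_causal,
        )
```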
@@ -83,7 +87,7 @@ def masked_softmax(x: torch.Tensor, mask: torch.Tensor, mask_value: torch.Tensor class GPTBigCodeAttention(nn.Module): - def __init__(self, config, is_cross_attention=False, layer_idx=None): + def __init__(self, config: GPTBigCodeConfig, is_cross_attention=False, layer_idx=None): super().__init__() self.config = config @@ -270,14 +274,14 @@ def forward( return outputs # a, present, (attentions) -class GPTBigCodeFlashAttention2(GPTBigCodeAttention): +class GPTBigCodeFlashAttention(GPTBigCodeAttention): """ GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -285,6 +289,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -367,16 +372,26 @@ def forward( key = key.to(target_dtype) value = value.to(target_dtype) - attn_output = _flash_attention_forward( - query, - key, - value, - attention_mask, - query_length, - dropout=attn_dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_length, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query, + key, + value, + attention_mask, + query_length, + dropout=attn_dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) attn_output = self.c_proj(attn_weights_reshaped) @@ -560,7 +575,8 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl GPTBIGCODE_ATTENTION_CLASSES = { "eager": GPTBigCodeAttention, - "flash_attention_2": GPTBigCodeFlashAttention2, + "flash_attention_2": GPTBigCodeFlashAttention, + "flash_attention_3": GPTBigCodeFlashAttention, "sdpa": GPTBigCodeSdpaAttention, } @@ -666,6 +682,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTBigCodeBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__(self, *inputs, **kwargs): @@ -810,6 +827,7 @@ def __init__(self, config): self._use_sdpa = config._attn_implementation == "sdpa" self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" # Initialize weights and apply final processing self.post_init() @@ -891,7 
+909,7 @@ def forward( key_length = past_length + query_length self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None encoder_attention_mask = ( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 28309f7738ebe2..6ee3ac9c73efcd 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -41,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torch_fx_available, logging, @@ -51,6 +52,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. @@ -204,7 +207,7 @@ def load_tf_weights_in_gpt_neo(model, config, gpt_neo_checkpoint_path): class GPTNeoSelfAttention(nn.Module): - def __init__(self, config, attention_type, layer_id=None): + def __init__(self, config: GPTNeoConfig, attention_type, layer_id=None): super().__init__() self.config = config @@ -324,14 +327,14 @@ def forward( return outputs # a, past_kv, (attentions) -class GPTNeoFlashAttention2(GPTNeoSelfAttention): +class GPTNeoFlashAttention(GPTNeoSelfAttention): """ GPTNeo flash attention module. This module inherits from `GPTNeoSelfAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -339,6 +342,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -403,17 +407,28 @@ def forward( key = key.to(target_dtype) value = value.to(target_dtype) - attn_output = _flash_attention_forward( - query, - key, - value, - attention_mask, - query_length, - dropout=attn_dropout, - softmax_scale=1.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_length, + softmax_scale=1.0, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query, + key, + value, + attention_mask, + query_length, + dropout=attn_dropout, + softmax_scale=1.0, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) attn_output = self.out_proj(attn_weights_reshaped) @@ -428,7 +443,8 @@ def forward( GPT_NEO_ATTENTION_CLASSES = { "eager": GPTNeoSelfAttention, - "flash_attention_2": GPTNeoFlashAttention2, + "flash_attention_2": GPTNeoFlashAttention, + "flash_attention_3": GPTNeoFlashAttention, } @@ -550,6 +566,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTNeoBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = False # TODO: needs a HybridCache @@ -851,7 +868,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 274c571fa8939c..6ed39200cf3946 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -43,6 +43,7 @@ from ...utils import ( get_torch_version, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, ) @@ -52,6 +53,10 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM" @@ -125,6 +130,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTNeoXLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True @@ -146,7 +152,7 @@ def _init_weights(self, module): class GPTNeoXAttention(nn.Module): - def __init__(self, config, layer_idx=None): + def __init__(self, config: GPTNeoXConfig, layer_idx=None): super().__init__() self.config = config self.num_attention_heads = config.num_attention_heads @@ -355,14 +361,14 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None): return attn_output, attn_weights -class 
GPTNeoXFlashAttention2(GPTNeoXAttention): +class GPTNeoXFlashAttention(GPTNeoXAttention): """ GPTNeoX flash attention module. This module inherits from `GPTNeoXAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -370,6 +376,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -434,17 +441,28 @@ def forward( attention_dropout = self.config.attention_dropout if self.training else 0.0 # Compute attention - attn_weights = _flash_attention_forward( - query, - key, - value, - attention_mask, - query_length, - dropout=attention_dropout, - softmax_scale=self.norm_factor, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_weights = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_length, + softmax_scale=self.norm_factor, + is_causal=self.is_causal, + ) + else: + attn_weights = _flash_attention_forward( + query, + key, + value, + attention_mask, + query_length, + dropout=attention_dropout, + softmax_scale=self.norm_factor, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) # Reshape outputs attn_output = attn_weights.reshape( @@ -726,7 +744,8 @@ def forward(self, hidden_states): GPT_NEOX_ATTENTION_CLASSES = { "eager": GPTNeoXAttention, - "flash_attention_2": GPTNeoXFlashAttention2, + "flash_attention_2": GPTNeoXFlashAttention, + "flash_attention_3": GPTNeoXFlashAttention, "sdpa": GPTNeoXSdpaAttention, } @@ -1048,7 +1067,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 048e108a8ec2d7..2999c32a42ca09 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -752,7 +752,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 
in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 84f6d985f76474..a17aece9548f60 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -38,6 +38,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torch_fx_proxy, logging, @@ -50,6 +51,10 @@ from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "hf-internal-testing/tiny-random-gptj" @@ -136,7 +141,7 @@ def apply_rotary_pos_emb(tensor: torch.Tensor, sin: torch.Tensor, cos: torch.Ten class GPTJAttention(nn.Module): - def __init__(self, config, layer_idx=None): + def __init__(self, config: GPTJConfig, layer_idx=None): super().__init__() self.config = config max_positions = config.max_position_embeddings @@ -312,14 +317,14 @@ def forward( return outputs # a, present, (attentions) -class GPTJFlashAttention2(GPTJAttention): +class GPTJFlashAttention(GPTJAttention): """ GPTJ flash attention module. This module inherits from `GPTJAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -327,6 +332,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -432,16 +438,26 @@ def forward( query_length = query.shape[1] # Compute attention - attn_weights = _flash_attention_forward( - query, - key, - value, - attention_mask, - query_length, - dropout=attention_dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_weights = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_length, + is_causal=self.is_causal, + ) + else: + attn_weights = _flash_attention_forward( + query, + key, + value, + attention_mask, + query_length, + dropout=attention_dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) # Reshape outputs attn_output = attn_weights.reshape( @@ -459,7 +475,8 @@ def forward( GPTJ_ATTENTION_CLASSES = { "eager": GPTJAttention, - "flash_attention_2": GPTJFlashAttention2, + "flash_attention_2": GPTJFlashAttention, + "flash_attention_3": GPTJFlashAttention, } @@ -540,6 +557,7 @@ class GPTJPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTJBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True @@ -720,6 +738,7 @@ def __init__(self, config): self.post_init() self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): @@ -946,7 +965,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index f62de411a4fa5f..c91a1d7158d18c 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -23,6 +23,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -347,7 +348,7 @@ def forward( return attn_output, attn_weights, past_key_value -class GraniteFlashAttention2(GraniteAttention): +class GraniteFlashAttention(GraniteAttention): """ Granite flash attention module. This module inherits from `GraniteAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of @@ -361,6 +362,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. 
# Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -430,19 +432,31 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - softmax_scale=self.scaling, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + softmax_scale=self.scaling, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + softmax_scale=self.scaling, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -546,7 +560,8 @@ def forward( GRANITE_ATTENTION_CLASSES = { "eager": GraniteAttention, - "flash_attention_2": GraniteFlashAttention2, + "flash_attention_2": GraniteFlashAttention, + "flash_attention_3": GraniteFlashAttention, "sdpa": GraniteSdpaAttention, } @@ -662,6 +677,7 @@ class GranitePreTrainedModel(PreTrainedModel): _no_split_modules = ["GraniteDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -939,7 +955,10 @@ def _update_causal_mask( # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114 - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index da79c2894877b4..3681c2b84f5775 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -32,6 +32,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -42,6 +43,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -550,15 +553,15 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Hubert -class HubertFlashAttention2(HubertAttention): +# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->Hubert +class HubertFlashAttention(HubertAttention): """ Hubert flash attention module. This module inherits from `HubertAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -566,6 +569,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
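The `_update_causal_mask` hunks above (Granite here, and the same change in every decoder below) only widen an existing early return. A hedged, simplified sketch of what that early return does, not the actual implementation, which also handles the SDPA and static-cache paths:

```py
import torch


def update_causal_mask_sketch(attn_implementation: str, attention_mask):
    """For both flash-attention backends, skip building a 4D causal mask:
    forward the raw 2D padding mask, or drop the mask entirely when nothing
    is padded and the kernel can apply causality on its own."""
    if attn_implementation in ("flash_attention_2", "flash_attention_3"):
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
        return None
    raise NotImplementedError("eager/sdpa paths build a 4D float mask here")


padded = torch.tensor([[0.0, 1.0, 1.0]])
print(update_causal_mask_sketch("flash_attention_3", padded))             # 2D mask passed through
print(update_causal_mask_sketch("flash_attention_3", torch.ones(1, 3)))   # None
```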
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -658,16 +662,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) @@ -788,7 +802,8 @@ def forward( HUBERT_ATTENTION_CLASSES = { "eager": HubertAttention, "sdpa": HubertSdpaAttention, - "flash_attention_2": HubertFlashAttention2, + "flash_attention_2": HubertFlashAttention, + "flash_attention_3": HubertFlashAttention, } @@ -936,6 +951,7 @@ def __init__(self, config): self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -952,7 +968,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1024,6 +1040,7 @@ def __init__(self, config): ) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1040,7 +1057,7 @@ def forward( # make sure padded tokens are not attended to expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1112,6 +1129,7 @@ class HubertPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 1289bda2d0fd3b..a36d5c15aead10 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1479,7 +1479,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: 
bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/idefics/vision.py b/src/transformers/models/idefics/vision.py index 5339b706924d8f..8ee1c0f7d03bb8 100644 --- a/src/transformers/models/idefics/vision.py +++ b/src/transformers/models/idefics/vision.py @@ -164,7 +164,7 @@ def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: boo class IdeficsVisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: IdeficsVisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 41be300095e710..ba0c91664f1f4a 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -32,6 +32,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -44,6 +45,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -190,7 +193,7 @@ class Idefics2VisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ - def __init__(self, config): + def __init__(self, config: Idefics2VisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -265,14 +268,14 @@ def forward( return attn_output, attn_weights -class Idefics2VisionFlashAttention2(Idefics2VisionAttention): +class Idefics2VisionFlashAttention(Idefics2VisionAttention): """ Idefics2Vision flash attention module. This module inherits from `Idefics2VisionAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -280,6 +283,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -343,16 +347,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() attn_output = self.out_proj(attn_output) @@ -365,7 +379,8 @@ def forward( IDEFICS_VISION_ATTENTION_CLASSES = { "eager": Idefics2VisionAttention, - "flash_attention_2": Idefics2VisionFlashAttention2, + "flash_attention_2": Idefics2VisionFlashAttention, + "flash_attention_3": Idefics2VisionFlashAttention, } @@ -582,6 +597,7 @@ def __init__(self, config: Idefics2VisionConfig): self.encoder = Idefics2Encoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def get_input_embeddings(self): return self.embeddings @@ -623,7 +639,7 @@ def forward( # avoiding passing the attention_mask, which is equivalent to attending to the full sequence if not torch.any(~patch_attention_mask): patch_attention_mask = None - elif not self._use_flash_attention_2: + elif not self._use_flash_attention_2 and not self._use_flash_attention_3: patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( @@ -682,7 +698,7 @@ def extra_repr(self): class Idefics2PerceiverAttention(nn.Module): - def __init__(self, config, layer_idx: Optional[int] = None) -> None: + def __init__(self, config: Idefics2Config, layer_idx: Optional[int] = None) -> None: """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`""" super().__init__() @@ -782,15 +798,15 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2 -class Idefics2PerceiverFlashAttention2(Idefics2PerceiverAttention): +# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention with MistralAttention->Idefics2PerceiverAttention,MistralFlashAttention->Idefics2PerceiverFlashAttention,Mistral->Idefics2 +class Idefics2PerceiverFlashAttention(Idefics2PerceiverAttention): """ Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -798,6 +814,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" # Ignore copy def forward( @@ -888,17 +905,27 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - sliding_window=None, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=None, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() attn_output = self.o_proj(attn_output) @@ -911,7 +938,8 @@ def forward( IDEFICS2_PERCEIVER_ATTENTION_CLASSES = { "eager": Idefics2PerceiverAttention, - "flash_attention_2": Idefics2PerceiverFlashAttention2, + "flash_attention_2": Idefics2PerceiverFlashAttention, + "flash_attention_3": Idefics2PerceiverFlashAttention, } @@ -1010,6 +1038,7 @@ def __init__(self, config) -> None: self.norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1025,7 +1054,7 @@ def forward( attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1) attention_mask = ( _prepare_4d_attention_mask(attention_mask, latents.dtype, tgt_len=self.n_latents) - if not self._use_flash_attention_2 + if not self._use_flash_attention_2 and not self._use_flash_attention_3 else attention_mask ) @@ -1093,6 +1122,7 @@ class Idefics2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True def _init_weights(self, module): @@ -1227,6 +1257,7 @@ def __init__(self, config: Idefics2Config): self.image_token_id = self.config.image_token_id self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self.post_init() diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 
60e1670a3c2784..3f28f0e40a1259 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -48,6 +48,7 @@ from ...utils.import_utils import ( is_causal_conv1d_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_mamba_ssm_available, is_torchdynamo_compiling, @@ -58,6 +59,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + if is_mamba_ssm_available(): from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn @@ -374,15 +378,15 @@ def forward( return attn_output, attn_weights, past_key_value -# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba -class JambaFlashAttention2(JambaAttention): +# Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention with Mistral->Jamba +class JambaFlashAttention(JambaAttention): """ Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -390,6 +394,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
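Each touched module mirrors the existing FA2 import guard for FA3, so the helper is only imported when the kernel package is present. A standalone approximation of the pattern (absolute imports for illustration; `is_flash_attn_3_available` and `modeling_flash_attention_3_utils` are names introduced by this PR, hence the fallback):

```py
from transformers.utils import is_flash_attn_2_available

if is_flash_attn_2_available():
    # FA2 helper, already present in the library
    from transformers.modeling_flash_attention_utils import _flash_attention_forward  # noqa: F401

try:
    from transformers.utils import is_flash_attn_3_available
except ImportError:  # transformers build without this PR
    def is_flash_attn_3_available() -> bool:
        return False

if is_flash_attn_3_available():
    # FA3 helper added by this PR
    from transformers.modeling_flash_attention_3_utils import _flash_attention_3_forward  # noqa: F401
```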
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -477,17 +482,27 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=getattr(self.config, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) @@ -583,7 +598,8 @@ def forward( JAMBA_ATTENTION_CLASSES = { "eager": JambaAttention, - "flash_attention_2": JambaFlashAttention2, + "flash_attention_2": JambaFlashAttention, + "flash_attention_3": JambaFlashAttention, "sdpa": JambaSdpaAttention, } @@ -1121,6 +1137,7 @@ class JambaPreTrainedModel(PreTrainedModel): _no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True @@ -1377,7 +1394,10 @@ def forward( ) def _update_causal_mask(self, attention_mask, input_tensor, cache_position): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 7c4394d0e1a168..b55488dfd358f0 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -36,6 +36,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -47,6 +48,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "jetmoe" @@ -688,8 +692,8 @@ def forward( return attn_output, None, past_key_value, router_logits -class JetMoeFlashAttention2(JetMoeAttention): - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ +class JetMoeFlashAttention(JetMoeAttention): + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -697,6 +701,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is 
bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -782,16 +787,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ).to(input_dtype) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ).to(input_dtype) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ).to(input_dtype) # output projection attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size) @@ -806,7 +821,8 @@ def forward( JETMOE_ATTENTION_CLASSES = { "eager": JetMoeAttention, - "flash_attention_2": JetMoeFlashAttention2, + "flash_attention_2": JetMoeFlashAttention, + "flash_attention_3": JetMoeFlashAttention, "sdpa": JetMoeSdpaAttention, } @@ -879,6 +895,7 @@ class JetMoePreTrainedModel(PreTrainedModel): _no_split_modules = ["JetMoeBlock"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True @@ -1055,7 +1072,11 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + if ( + attention_mask is not None + and (self._attn_implementation == "flash_attention_2" or self._attn_implementation == "flash_attention_3") + and use_cache + ): batch_size = inputs_embeds.shape[0] is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: @@ -1143,7 +1164,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 69641790b2db8f..eb496bc81c19d5 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -416,7 +416,7 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: class Kosmos2VisionAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: Kosmos2VisionConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size diff --git a/src/transformers/models/llama/modeling_llama.py 
b/src/transformers/models/llama/modeling_llama.py index 0bc44f314b5e86..02a4e1c8b19fef 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -449,7 +450,7 @@ def forward( return attn_output, attn_weights, past_key_value -class LlamaFlashAttention2(LlamaAttention): +class LlamaFlashAttention(LlamaAttention): """ Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of @@ -463,6 +464,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -547,18 +549,29 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -670,7 +683,8 @@ def forward( LLAMA_ATTENTION_CLASSES = { "eager": LlamaAttention, - "flash_attention_2": LlamaFlashAttention2, + "flash_attention_2": LlamaFlashAttention, + "flash_attention_3": LlamaFlashAttention, "sdpa": LlamaSdpaAttention, } @@ -783,6 +797,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1042,7 +1057,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask 
is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index eb1c55341b0784..99682d7818fde1 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -125,6 +125,7 @@ class LlavaPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True def _init_weights(self, module): diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index bf76921090b244..ddb9e2c0822544 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -234,6 +234,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlavaNextVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True def _init_weights(self, module): diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 589bf346ceeb9e..9c3f296b38b0ec 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -272,6 +272,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlavaNextVideoVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True def _init_weights(self, module): diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index d3200fb5193d4b..c3ac94b05a0053 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -239,6 +239,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlavaOnevisionVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_static_cache = False # Qwen2 doesn't but llava has no reasons to not support _supports_quantized_cache = True diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 23a855fff25672..c43e91f23f15ed 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -37,6 +37,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -47,6 +48,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -333,8 +336,8 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -class M2M100FlashAttention2(M2M100Attention): - # Copied from 
transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ +class M2M100FlashAttention(M2M100Attention): + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -342,6 +345,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -406,17 +410,28 @@ def forward( if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout, - softmax_scale=None, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + softmax_scale=None, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout, + softmax_scale=None, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. 
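The registry hunk that follows is the same for every model: the renamed FlashAttention class is registered under both keys, so layers keep selecting their attention class from `config._attn_implementation` with no further changes. A small stand-alone illustration (stub classes, not the real modules):

```py
class M2M100AttentionStub:
    """Stub for the eager attention class."""


class M2M100FlashAttentionStub(M2M100AttentionStub):
    """Stub for the (renamed) flash attention class serving FA2 and FA3."""


M2M100_ATTENTION_CLASSES = {
    "eager": M2M100AttentionStub,
    "flash_attention_2": M2M100FlashAttentionStub,
    "flash_attention_3": M2M100FlashAttentionStub,  # same class, different kernel path at runtime
}


def build_self_attn(attn_implementation: str):
    # mirrors e.g. `M2M100_ATTENTION_CLASSES[config._attn_implementation](config, ...)`
    return M2M100_ATTENTION_CLASSES[attn_implementation]()


print(type(build_self_attn("flash_attention_3")).__name__)  # M2M100FlashAttentionStub
```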
@@ -500,7 +515,8 @@ def forward( M2M100_ATTENTION_CLASSES = { "eager": M2M100Attention, - "flash_attention_2": M2M100FlashAttention2, + "flash_attention_2": M2M100FlashAttention, + "flash_attention_3": M2M100FlashAttention, } @@ -631,6 +647,7 @@ class M2M100PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module): std = self.config.init_std @@ -804,6 +821,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = self.layers = nn.ModuleList([M2M100EncoderLayer(config) for _ in range(config.encoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -884,7 +902,7 @@ def forward( # expand attention_mask if attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -980,6 +998,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = ) self.layers = nn.ModuleList([M2M100DecoderLayer(config) for _ in range(config.decoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1090,7 +1109,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers combined_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1101,7 +1120,7 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -1232,7 +1251,7 @@ def __init__(self, config: M2M100Config): self.encoder = M2M100Encoder(config, self.shared) self.decoder = M2M100Decoder(config, self.shared) - if config._attn_implementation == "flash_attention_2": + if config._attn_implementation == "flash_attention_2" or config._attn_implementation == "flash_attention_3": logger.warning_once( "Attention with Flash Attention 2 does not support `layer_head_mask`. If you need this feature, please use standard attention." 
) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 9455f21b2073ff..9a9f7115769823 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -46,6 +46,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -56,6 +57,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -282,15 +285,15 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MBart -class MBartFlashAttention2(MBartAttention): +# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->MBart +class MBartFlashAttention(MBartAttention): """ MBart flash attention module. This module inherits from `MBartAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -298,6 +301,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -390,16 +394,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) @@ -520,7 +534,8 @@ def forward( MBART_ATTENTION_CLASSES = { "eager": MBartAttention, "sdpa": MBartSdpaAttention, - "flash_attention_2": MBartFlashAttention2, + "flash_attention_2": MBartFlashAttention, + "flash_attention_3": MBartFlashAttention, } @@ -745,6 +760,7 @@ class MBartPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MBartDecoderLayer", "MBartAttention"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): @@ -1043,7 +1059,10 @@ def forward( # expand attention_mask if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): attention_mask = attention_mask if 0 in attention_mask else None elif self.config._attn_implementation == "sdpa" and head_mask is None and not output_attentions: # output_attentions=True & head_mask can not be supported when using SDPA, fall back to @@ -1260,7 +1279,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self.config._attn_implementation == "sdpa" and not output_attentions and cross_attn_head_mask is None: @@ -1280,7 +1302,10 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self.config._attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index d91b057ef28ec4..8559b29df0ee71 100644 --- 
a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -32,6 +32,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -42,6 +43,10 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + + logger = logging.get_logger(__name__) @@ -613,8 +618,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention2 with Gemma->Mimi -class MimiFlashAttention2(MimiAttention): +# Copied from transformers.models.gemma.modeling_gemma.GemmaFlashAttention with Gemma->Mimi +class MimiFlashAttention(MimiAttention): """ Mimi flash attention module. This module inherits from `MimiAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of @@ -628,6 +633,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -702,18 +708,29 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -815,7 +832,8 @@ def forward( MIMI_ATTENTION_CLASSES = { "eager": MimiAttention, - "flash_attention_2": MimiFlashAttention2, + "flash_attention_2": MimiFlashAttention, + "flash_attention_3": MimiFlashAttention, "sdpa": MimiSdpaAttention, } @@ -1090,7 +1108,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None @@ -1370,6 +1391,7 @@ class 
MimiPreTrainedModel(PreTrainedModel): _no_split_modules = ["MimiDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 80992734046ad6..6c27985c6e7465 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -41,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -52,6 +53,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "MistralConfig" @@ -267,14 +271,14 @@ def forward( return attn_output, attn_weights, past_key_value -class MistralFlashAttention2(MistralAttention): +class MistralFlashAttention(MistralAttention): """ Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -282,6 +286,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
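The `_supports_flash_attn_3 = True` flags added to each `*PreTrainedModel` above advertise the new backend to the attention-implementation check that runs before a config may request it. A hypothetical sketch of how such a flag is consumed (the real gate lives in `PreTrainedModel` and is not part of these hunks):

```py
class PreTrainedModelStub:
    _supports_flash_attn_2 = False
    _supports_flash_attn_3 = False


class MimiPreTrainedModelStub(PreTrainedModelStub):
    _supports_flash_attn_2 = True
    _supports_flash_attn_3 = True  # the flag these hunks add


def check_attn_implementation(cls, requested: str) -> str:
    # hypothetical helper; illustrates the gate, not the library API
    if requested == "flash_attention_2" and not cls._supports_flash_attn_2:
        raise ValueError(f"{cls.__name__} does not support flash_attention_2")
    if requested == "flash_attention_3" and not cls._supports_flash_attn_3:
        raise ValueError(f"{cls.__name__} does not support flash_attention_3")
    return requested


print(check_attn_implementation(MimiPreTrainedModelStub, "flash_attention_3"))
```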
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -380,18 +385,29 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self.config, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() attn_output = self.o_proj(attn_output) @@ -494,7 +510,8 @@ def forward( MISTRAL_ATTENTION_CLASSES = { "eager": MistralAttention, - "flash_attention_2": MistralFlashAttention2, + "flash_attention_2": MistralFlashAttention, + "flash_attention_3": MistralFlashAttention, "sdpa": MistralSdpaAttention, } @@ -604,6 +621,7 @@ class MistralPreTrainedModel(PreTrainedModel): _no_split_modules = ["MistralDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True @@ -858,7 +876,7 @@ def _update_causal_mask( use_cache: bool, output_attentions: bool, ): - if self._attn_implementation == "flash_attention_2": + if self._attn_implementation == "flash_attention_2" or self._attn_implementation == "flash_attention_3": if attention_mask is not None and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index fcc0d66e19c4a4..3d9fad2cfdc225 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -43,6 +43,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_torchdynamo_compiling, logging, replace_return_docstrings, @@ -54,6 +55,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. 
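Both flash-attention backends keep the existing batched-generation constraint visible in the JetMoe and Mistral hunks above: with `use_cache=True`, inputs must be left-padded. A minimal sketch of that check (not the library code):

```py
import torch


def assert_left_padded(attention_mask: torch.Tensor, use_cache: bool) -> None:
    if not use_cache or attention_mask is None:
        return
    batch_size = attention_mask.shape[0]
    is_padding_right = attention_mask[:, -1].sum().item() != batch_size
    if is_padding_right:
        raise ValueError(
            "Right padding detected: call the tokenizer with padding_side='left' "
            "when generating with flash_attention_2 / flash_attention_3."
        )


assert_left_padded(torch.tensor([[0, 1, 1], [1, 1, 1]]), use_cache=True)  # left padding: ok
# assert_left_padded(torch.tensor([[1, 1, 0]]), use_cache=True)           # would raise
```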
if is_torch_fx_available(): @@ -430,9 +434,9 @@ def forward( return attn_output, attn_weights, past_key_value -# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral +# copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention with Mistral->Mixtral # TODO @longjie no longer copied from Mistral after static cache -class MixtralFlashAttention2(MixtralAttention): +class MixtralFlashAttention(MixtralAttention): """ Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of @@ -540,17 +544,28 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self.config, "sliding_window", None), + is_causal=self.is_causal, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) @@ -655,7 +670,8 @@ def forward( MIXTRAL_ATTENTION_CLASSES = { "eager": MixtralAttention, - "flash_attention_2": MixtralFlashAttention2, + "flash_attention_2": MixtralFlashAttention, + "flash_attention_3": MixtralFlashAttention, "sdpa": MixtralSdpaAttention, } @@ -857,6 +873,7 @@ class MixtralPreTrainedModel(PreTrainedModel): _no_split_modules = ["MixtralDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True @@ -1127,7 +1144,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index f720faac038e51..42982eacf83fdf 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -47,6 +47,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -59,6 +60,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + if TYPE_CHECKING: from ...generation.streamers import BaseStreamer @@ -311,15 +315,15 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from 
transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Musicgen -class MusicgenFlashAttention2(MusicgenAttention): +# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->Musicgen +class MusicgenFlashAttention(MusicgenAttention): """ Musicgen flash attention module. This module inherits from `MusicgenAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -327,6 +331,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -419,16 +424,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) @@ -565,7 +580,8 @@ def forward( MUSICGEN_ATTENTION_CLASSES = { "eager": MusicgenAttention, "sdpa": MusicgenSdpaAttention, - "flash_attention_2": MusicgenFlashAttention2, + "flash_attention_2": MusicgenFlashAttention, + "flash_attention_3": MusicgenFlashAttention, } @@ -703,6 +719,7 @@ class MusicgenPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): @@ -998,7 +1015,7 @@ def forward( if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) - if self.attn_implementation == "flash_attention_2": + if self.attn_implementation == "flash_attention_2" or self.attn_implementation == "flash_attention_3": attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions: # output_attentions=True & 
cross_attn_head_mask can not be supported when using SDPA, and we fall back on @@ -1016,7 +1033,7 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.attn_implementation == "flash_attention_2": + if self.attn_implementation == "flash_attention_2" or self.attn_implementation == "flash_attention_3": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on @@ -1664,6 +1681,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): main_input_name = "input_ids" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__( diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index a8a8fe96098952..d4ea4e2254be76 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -39,6 +39,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -51,6 +52,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + if TYPE_CHECKING: from ...generation.streamers import BaseStreamer @@ -327,15 +331,15 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MusicgenMelody -class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention): +# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->MusicgenMelody +class MusicgenMelodyFlashAttention(MusicgenMelodyAttention): """ MusicgenMelody flash attention module. This module inherits from `MusicgenMelodyAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -343,6 +347,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
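For the encoder-decoder models in this patch, mask handling under FA3 is identical to FA2: the Musicgen hunks above keep only raw 2D masks (the self-attention padding mask and the cross-attention encoder mask), or drop them entirely when there is no padding. A small illustrative sketch with made-up helper names follows; the real code inlines these checks in the decoder `forward`.

```py
from typing import Optional
import torch

# Illustrative helpers mirroring the Musicgen hunks above.
def flash_self_attention_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Keep the 2D padding mask only when it actually contains padding; otherwise pass None
    # and let the kernel apply causality on a fully dense batch.
    return attention_mask if (attention_mask is not None and 0 in attention_mask) else None

def flash_cross_attention_mask(encoder_attention_mask: torch.Tensor) -> Optional[torch.Tensor]:
    # `0 in tensor` is an elementwise membership test, i.e. (tensor == 0).any().
    return encoder_attention_mask if 0 in encoder_attention_mask else None
```

The same passthrough shows up again for OPT and PLBart further down in the patch.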
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -435,16 +440,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) @@ -565,7 +580,8 @@ def forward( MUSICGEN_MELODY_ATTENTION_CLASSES = { "eager": MusicgenMelodyAttention, "sdpa": MusicgenMelodySdpaAttention, - "flash_attention_2": MusicgenMelodyFlashAttention2, + "flash_attention_2": MusicgenMelodyFlashAttention, + "flash_attention_3": MusicgenMelodyFlashAttention, } @@ -662,6 +678,7 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): @@ -945,7 +962,7 @@ def forward( input_shape = inputs_embeds.size()[:-1] - if self.attn_implementation == "flash_attention_2": + if self.attn_implementation == "flash_attention_2" or self.attn_implementation == "flash_attention_3": attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self.attn_implementation == "sdpa" and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on @@ -1590,6 +1607,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): main_input_name = "input_ids" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__( diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 4d079b4dde104d..46b61d56810e68 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -354,8 +355,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron -class NemotronFlashAttention2(NemotronAttention): +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron +class 
NemotronFlashAttention(NemotronAttention): """ Nemotron flash attention module. This module inherits from `NemotronAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of @@ -369,6 +370,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" # Ignore copy def forward( @@ -446,18 +448,29 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) @@ -562,7 +575,8 @@ def forward( NEMOTRON_ATTENTION_CLASSES = { "eager": NemotronAttention, - "flash_attention_2": NemotronFlashAttention2, + "flash_attention_2": NemotronFlashAttention, + "flash_attention_3": NemotronFlashAttention, "sdpa": NemotronSdpaAttention, } @@ -677,6 +691,7 @@ class NemotronPreTrainedModel(PreTrainedModel): _no_split_modules = ["NemotronDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -920,7 +935,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 03fa524532a02e..d48575b03dd0f3 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -41,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -52,6 +53,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if 
is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -388,14 +392,14 @@ def forward( return attn_output, attn_weights, past_key_value -class OlmoFlashAttention2(OlmoAttention): +class OlmoFlashAttention(OlmoAttention): """ OLMo flash attention module. This module inherits from `OlmoAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -403,6 +407,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -477,17 +482,28 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) @@ -593,7 +609,8 @@ def forward( OLMO_ATTENTION_CLASSES = { "eager": OlmoAttention, - "flash_attention_2": OlmoFlashAttention2, + "flash_attention_2": OlmoFlashAttention, + "flash_attention_3": OlmoFlashAttention, "sdpa": OlmoSdpaAttention, } @@ -704,6 +721,7 @@ class OlmoPreTrainedModel(PreTrainedModel): _no_split_modules = ["OlmoDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -962,7 +980,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 
2cbde7dc863169..56565a7f1acb72 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -34,6 +34,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -44,6 +45,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -463,14 +467,14 @@ def forward( return attn_output, attn_weights, past_key_value -class OlmoeFlashAttention2(OlmoeAttention): +class OlmoeFlashAttention(OlmoeAttention): """ OLMoE flash attention module. This module inherits from `OlmoeAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -478,6 +482,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
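The `_update_causal_mask` change, seen above for OLMo and repeated for most of the decoder-only models, simply widens the early return to cover both flash backends. A minimal sketch of that early return is below; it is a hypothetical standalone function, not the method itself, and it models only the flash branch.

```py
from typing import Optional
import torch

FLASH_IMPLEMENTATIONS = ("flash_attention_2", "flash_attention_3")

# Sketch of the widened early return only; for every other backend the real method keeps
# going and builds the expanded 4D causal mask as before.
def flash_early_return(attn_implementation: str,
                       attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    if attn_implementation not in FLASH_IMPLEMENTATIONS:
        raise NotImplementedError("this sketch covers the flash branch only")
    if attention_mask is not None and 0.0 in attention_mask:
        return attention_mask  # hand the raw 2D padding mask straight to the kernel
    return None  # no padding anywhere: the kernel handles causality on its own
```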
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -552,16 +557,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) @@ -669,7 +684,8 @@ def forward( OLMOE_ATTENTION_CLASSES = { "eager": OlmoeAttention, - "flash_attention_2": OlmoeFlashAttention2, + "flash_attention_2": OlmoeFlashAttention, + "flash_attention_3": OlmoeFlashAttention, "sdpa": OlmoeSdpaAttention, } @@ -838,6 +854,7 @@ class OlmoePreTrainedModel(PreTrainedModel): _no_split_modules = ["OlmoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1114,7 +1131,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 8f058171778efc..ed95fa435c08c4 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -35,6 +35,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -45,6 +46,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -240,14 +244,14 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -class OptFlashAttention2(OPTAttention): +class OptFlashAttention(OPTAttention): """ OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -255,6 +259,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -339,16 +344,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - query_length, - dropout=attn_dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + query_length, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=attn_dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) attn_output = self.out_proj(attn_weights_reshaped) @@ -361,7 +376,8 @@ def forward( OPT_ATTENTION_CLASSES = { "eager": OPTAttention, - "flash_attention_2": OptFlashAttention2, + "flash_attention_2": OptFlashAttention, + "flash_attention_3": OptFlashAttention, } @@ -488,6 +504,7 @@ class OPTPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["OPTDecoderLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module): std = self.config.init_std @@ -604,6 +621,7 @@ def __init__(self, config: OPTConfig): self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -702,7 +720,7 @@ def forward( mask_seq_length = past_key_values_length + seq_length # embed positions - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None attention_mask = ( diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 48fffb6b428df7..03c5f8ed7c58c9 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -192,6 +192,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): _no_split_modules = ["PaliGemmaMultiModalProjector"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 
= False + _supports_flash_attn_3 = False _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index a6fd2284afb691..ac6e8c300f507b 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -788,7 +788,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 4d0a076b5f9a33..833308d8afc08d 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -41,6 +41,7 @@ add_start_docstrings_to_model_forward, get_torch_version, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -52,6 +53,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -435,14 +438,14 @@ def forward( return attn_output, attn_weights, past_key_value -class PhiFlashAttention2(PhiAttention): +class PhiFlashAttention(PhiAttention): """ Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -450,6 +453,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
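Not every architecture opts in: PaliGemma above keeps `_supports_flash_attn_3 = False` next to its existing `_supports_flash_attn_2 = False`, RecurrentGemma does the same near the end of the patch, and the decoder-only LLMs flip the new flag to True. A tiny, hypothetical helper for inspecting that flag follows; nothing like it ships in the library, and by analogy with the FA2 flag a False value should presumably cause `from_pretrained(..., attn_implementation="flash_attention_3")` to be rejected rather than silently fall back.

```py
# Hypothetical inspection helper; it only reads the class attribute introduced in this patch.
def class_supports_fa3(model_cls) -> bool:
    return bool(getattr(model_cls, "_supports_flash_attn_3", False))

# Example usage (requires a Transformers build that includes this patch):
# from transformers import MistralForCausalLM, PaliGemmaForConditionalGeneration
# class_supports_fa3(MistralForCausalLM)                 # True
# class_supports_fa3(PaliGemmaForConditionalGeneration)  # False
```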
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -553,18 +557,30 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=attn_dropout, - softmax_scale=None, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + softmax_scale=None, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=attn_dropout, + softmax_scale=None, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.dense(attn_output) @@ -703,7 +719,8 @@ def forward( PHI_ATTENTION_CLASSES = { "eager": PhiAttention, - "flash_attention_2": PhiFlashAttention2, + "flash_attention_2": PhiFlashAttention, + "flash_attention_3": PhiFlashAttention, "sdpa": PhiSdpaAttention, } @@ -812,6 +829,7 @@ class PhiPreTrainedModel(PreTrainedModel): _no_split_modules = ["PhiDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True @@ -930,6 +948,7 @@ def __init__(self, config: PhiConfig): self.rotary_emb = PhiRotaryEmbedding(config=config) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.gradient_checkpointing = False @@ -1080,7 +1099,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index e0ca84be184843..ecdad9598b303a 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -39,6 +39,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -50,6 +51,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct" @@ -481,14 +485,14 @@ def forward( return attn_output, attn_weights, past_key_value -class Phi3FlashAttention2(Phi3Attention): +class Phi3FlashAttention(Phi3Attention): """ Phi-3 flash attention module. 
This module inherits from `Phi3Attention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -496,6 +500,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -610,18 +615,29 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=attn_dropout, - sliding_window=getattr(self.config, "sliding_window", None), - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=attn_dropout, + sliding_window=getattr(self.config, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) @@ -728,7 +744,8 @@ def forward( PHI3_ATTENTION_CLASSES = { "eager": Phi3Attention, - "flash_attention_2": Phi3FlashAttention2, + "flash_attention_2": Phi3FlashAttention, + "flash_attention_3": Phi3FlashAttention, "sdpa": Phi3SdpaAttention, } @@ -842,6 +859,7 @@ class Phi3PreTrainedModel(PreTrainedModel): _no_split_modules = ["Phi3DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True @@ -1101,7 +1119,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 93d91e160089e1..03a24a4676261f 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -685,6 +685,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = ) self.layers = 
nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) @@ -773,7 +774,7 @@ def forward( # expand attention_mask if attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None elif self._use_sdpa and head_mask is None and not output_attentions: # output_attentions=True & head_mask can not be supported when using SDPA, fall back to @@ -870,6 +871,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = ) self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -990,7 +992,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: @@ -1010,7 +1012,7 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index aafecb95b6aafe..b151770293dedb 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -42,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -53,6 +54,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -394,7 +398,7 @@ def forward( return attn_output, attn_weights, past_key_value -class Qwen2FlashAttention2(Qwen2Attention): +class Qwen2FlashAttention(Qwen2Attention): """ Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` as the weights of the module stays untouched. The only required change would be on the forward pass @@ -403,7 +407,7 @@ class Qwen2FlashAttention2(Qwen2Attention): config.max_window_layers layers. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -411,6 +415,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -517,18 +522,29 @@ def forward( else: sliding_window = None - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) @@ -637,7 +653,8 @@ def forward( QWEN2_ATTENTION_CLASSES = { "eager": Qwen2Attention, - "flash_attention_2": Qwen2FlashAttention2, + "flash_attention_2": Qwen2FlashAttention, + "flash_attention_3": Qwen2FlashAttention, "sdpa": Qwen2SdpaAttention, } @@ -647,7 +664,11 @@ def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - if config.sliding_window and config._attn_implementation != "flash_attention_2": + if ( + config.sliding_window + and config._attn_implementation != "flash_attention_2" + and config._attn_implementation != "flash_attention_3" + ): logger.warning_once( f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." 
@@ -754,6 +775,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1018,7 +1040,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 14235bf0aaf64a..a9da44c991e517 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -30,6 +30,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -41,6 +42,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -214,15 +218,15 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention2 with Whisper->Qwen2Audio -class Qwen2AudioFlashAttention2(Qwen2AudioAttention): +# Copied from transformers.models.whisper.modeling_whisper.WhisperFlashAttention with Whisper->Qwen2Audio +class Qwen2AudioFlashAttention(Qwen2AudioAttention): """ Qwen2Audio flash attention module. This module inherits from `Qwen2AudioAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -230,6 +234,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -318,16 +323,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - causal_mask, - tgt_len, - dropout=self.dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + causal_mask, + tgt_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + causal_mask, + tgt_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, tgt_len, -1) attn_output = self.out_proj(attn_output) @@ -439,7 +454,8 @@ def forward( QWEN2AUDIO_ATTENTION_CLASSES = { "eager": Qwen2AudioAttention, - "flash_attention_2": Qwen2AudioFlashAttention2, + "flash_attention_2": Qwen2AudioFlashAttention, + "flash_attention_3": Qwen2AudioFlashAttention, "sdpa": Qwen2AudioSdpaAttention, } @@ -543,6 +559,7 @@ class Qwen2AudioPreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2AudioAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module): # important: this ported version of Qwen2Audio isn't meant for training from scratch - only diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index bc06b406bf43ed..80d46feaeffe47 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -43,6 +43,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -54,6 +55,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "Qwen/Qwen1.5-MoE-A2.7B" @@ -475,8 +479,8 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention2 with Qwen2->Qwen2Moe -class Qwen2MoeFlashAttention2(Qwen2MoeAttention): +# Copied from transformers.models.qwen2.modeling_qwen2.Qwen2FlashAttention with Qwen2->Qwen2Moe +class Qwen2MoeFlashAttention(Qwen2MoeAttention): """ Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention` as the weights of the module stays untouched. The only required change would be on the forward pass @@ -485,7 +489,7 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention): config.max_window_layers layers. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -493,6 +497,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -599,18 +604,29 @@ def forward( else: sliding_window = None - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) @@ -720,7 +736,8 @@ def forward( QWEN2MOE_ATTENTION_CLASSES = { "eager": Qwen2MoeAttention, - "flash_attention_2": Qwen2MoeFlashAttention2, + "flash_attention_2": Qwen2MoeFlashAttention, + "flash_attention_3": Qwen2MoeFlashAttention, "sdpa": Qwen2MoeSdpaAttention, } @@ -913,6 +930,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2MoeDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True @@ -1193,7 +1211,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 8e46f6840ad187..6fefa4999bf873 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -44,6 +44,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -58,6 +59,12 @@ else: flash_attn_varlen_func = None +if is_flash_attn_3_available(): + from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func + + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward +else: + 
flash_attn_3_varlen_func = None logger = logging.get_logger(__name__) @@ -322,8 +329,9 @@ def forward(self, x) -> torch.Tensor: class VisionAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: + def __init__(self, dim: int, num_heads: int = 16, config: Optional[Qwen2VLVisionConfig] = None) -> None: super().__init__() + self.config = config self.num_heads = num_heads self.head_dim = dim // num_heads self.qkv = nn.Linear(dim, dim * 3, bias=True) @@ -356,12 +364,10 @@ def forward( return attn_output -class VisionFlashAttention2(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) +class VisionFlashAttention(VisionAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None @@ -372,19 +378,21 @@ def forward( k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() - attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( - seq_length, -1 - ) + if self._flash_attn_3: + attn_output = flash_attn_3_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + else: + attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) attn_output = self.proj(attn_output) return attn_output -class VisionSdpaAttention(nn.Module): - def __init__(self, dim: int, num_heads: int = 16) -> None: - super().__init__() - self.num_heads = num_heads - self.qkv = nn.Linear(dim, dim * 3, bias=True) - self.proj = nn.Linear(dim, dim) +class VisionSdpaAttention(VisionAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None @@ -409,7 +417,8 @@ def forward( QWEN2_VL_VISION_ATTENTION_CLASSES = { "eager": VisionAttention, - "flash_attention_2": VisionFlashAttention2, + "flash_attention_2": VisionFlashAttention, + "flash_attention_3": VisionFlashAttention, "sdpa": VisionSdpaAttention, } @@ -422,7 +431,7 @@ def __init__(self, config, attn_implementation: str = "sdpa") -> None: mlp_hidden_dim = int(config.embed_dim * config.mlp_ratio) self.attn = QWEN2_VL_VISION_ATTENTION_CLASSES[attn_implementation]( - config.embed_dim, num_heads=config.num_heads + config.embed_dim, num_heads=config.num_heads, config=config ) self.mlp = VisionMlp(dim=config.embed_dim, hidden_dim=mlp_hidden_dim, hidden_act=config.hidden_act) @@ -656,7 +665,7 @@ def forward( return attn_output, attn_weights, past_key_value -class Qwen2VLFlashAttention2(Qwen2VLAttention): +class Qwen2VLFlashAttention(Qwen2VLAttention): """ Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention` as the weights of the module stays untouched. The only required change would be on the forward pass @@ -672,6 +681,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. 
Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -791,17 +801,27 @@ def forward( else: sliding_window = None - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - sliding_window=sliding_window, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + sliding_window=sliding_window, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) @@ -916,7 +936,8 @@ def forward( QWEN2_VL_ATTENTION_CLASSES = { "eager": Qwen2VLAttention, - "flash_attention_2": Qwen2VLFlashAttention2, + "flash_attention_2": Qwen2VLFlashAttention, + "flash_attention_3": Qwen2VLFlashAttention, "sdpa": Qwen2VLSdpaAttention, } @@ -926,7 +947,11 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + if ( + config.use_sliding_window + and config._attn_implementation != "flash_attention_2" + and config._attn_implementation != "flash_attention_3" + ): logger.warning_once( f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." 
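> Editor's note: the hunks above and the remaining model files in this patch all repeat the same FA2/FA3 dispatch pattern. The condensed sketch below uses generic class and argument names rather than the exact library code, and it assumes the `modeling_flash_attention_3_utils` helper and the `is_flash_attn_3_available` check that this patch introduces; it only runs with the patch applied.

```py
import torch.nn as nn

from transformers.utils import (
    is_flash_attn_2_available,
    is_flash_attn_3_available,
    is_flash_attn_greater_or_equal_2_10,
)

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward
if is_flash_attn_3_available():
    # New module introduced by this patch.
    from transformers.modeling_flash_attention_3_utils import _flash_attention_3_forward


class GenericFlashAttention(nn.Module):
    """Condensed stand-in for the per-model FlashAttention classes touched in this patch."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.is_causal = True
        # Handle bottom-right vs. top-left causal-mask alignment, as in the patch.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()
        # One class serves both implementations and branches at runtime.
        self._flash_attn_3 = config._attn_implementation == "flash_attention_3"

    def _dispatch(self, query_states, key_states, value_states, attention_mask, q_len, dropout_rate=0.0):
        if self._flash_attn_3:
            # The FA3 helper takes no dropout / sliding-window / top-left-mask arguments.
            return _flash_attention_3_forward(
                query_states, key_states, value_states, attention_mask, q_len, is_causal=self.is_causal
            )
        return _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )
```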
@@ -1033,6 +1058,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True @@ -1276,7 +1302,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index a8f076fad79c76..313471a599651f 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -540,6 +540,7 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel): _no_split_modules = ["RecurrentGemmaDecoderLayer"] _skip_keys_device_placement = ["cache"] _supports_flash_attn_2 = False + _supports_flash_attn_3 = False _supports_sdpa = False # we can't compare with eager for now _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index c9a3494b88b486..c61677138322af 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -33,6 +33,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, ) @@ -43,6 +44,9 @@ from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -554,15 +558,15 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->SEW -class SEWFlashAttention2(SEWAttention): +# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->SEW +class SEWFlashAttention(SEWAttention): """ SEW flash attention module. This module inherits from `SEWAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -570,6 +574,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -662,16 +667,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) @@ -792,7 +807,8 @@ def forward( SEW_ATTENTION_CLASSES = { "eager": SEWAttention, "sdpa": SEWSdpaAttention, - "flash_attention_2": SEWFlashAttention2, + "flash_attention_2": SEWFlashAttention, + "flash_attention_3": SEWFlashAttention, } @@ -869,6 +885,7 @@ def __init__(self, config): self.upsample = SEWUpsampling(config) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -882,7 +899,7 @@ def forward( all_self_attentions = () if output_attentions else None if attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # make sure padded tokens output 0 hidden_states[~attention_mask] = 0.0 # 2d mask is passed through the layers @@ -979,6 +996,7 @@ class SEWPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index 1d35d1d44cfd97..31daeec45b0faa 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -35,6 +35,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -46,6 +47,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -355,8 +359,8 @@ def forward( class SiglipAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ - def __init__(self, config): + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ with CLIP->Siglip + def __init__(self, config: SiglipConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -428,7 +432,7 @@ def forward( return attn_output, attn_weights -class 
SiglipFlashAttention2(SiglipAttention): +class SiglipFlashAttention(SiglipAttention): """ SiglipAttention flash attention module. This module inherits from `SiglipAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of @@ -437,7 +441,7 @@ class SiglipFlashAttention2(SiglipAttention): is_causal = False - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -445,8 +449,9 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" - # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward + # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention.forward def forward( self, hidden_states: torch.Tensor, @@ -502,16 +507,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=dropout_rate, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() attn_output = self.out_proj(attn_output) @@ -590,7 +605,8 @@ def forward( SIGLIP_ATTENTION_CLASSES = { "eager": SiglipAttention, - "flash_attention_2": SiglipFlashAttention2, + "flash_attention_2": SiglipFlashAttention, + "flash_attention_3": SiglipFlashAttention, "sdpa": SiglipSdpaAttention, } @@ -677,6 +693,7 @@ class SiglipPreTrainedModel(PreTrainedModel): "SiglipMultiheadAttentionPoolingHead", ] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): @@ -930,6 +947,7 @@ def __init__(self, config: SiglipTextConfig): self.head = nn.Linear(embed_dim, embed_dim) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) @@ -962,7 +980,7 @@ def forward( # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. 
# expand attention_mask - if attention_mask is not None and not self._use_flash_attention_2: + if attention_mask is not None and not self._use_flash_attention_2 and not self._use_flash_attention_3: # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 13641ecb37f255..8c037a36d0ebe6 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -42,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -52,6 +53,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -559,14 +563,14 @@ def forward( return attn_output, None, past_key_value -class StableLmFlashAttention2(StableLmAttention): +class StableLmFlashAttention(StableLmAttention): """ StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -574,6 +578,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -651,17 +656,28 @@ def forward( dropout_rate = self.attention_dropout.p if self.training else 0.0 - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - is_causal=self.is_causal, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) @@ -675,7 +691,8 @@ def forward( ATTENTION_CLASSES = { "eager": StableLmAttention, "sdpa": StableLmSdpaAttention, - "flash_attention_2": StableLmFlashAttention2, + "flash_attention_2": StableLmFlashAttention, + "flash_attention_3": StableLmFlashAttention, } @@ -799,6 +816,7 @@ class StableLmPreTrainedModel(PreTrainedModel): _no_split_modules = ["StableLmDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_sdpa = True _supports_quantized_cache = True @@ -1063,7 +1081,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 5eaf50f090fa49..c248d3346b0911 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -42,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -53,6 +54,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -376,14 +380,14 @@ def forward( return attn_output, attn_weights, past_key_value -class Starcoder2FlashAttention2(Starcoder2Attention): +class Starcoder2FlashAttention(Starcoder2Attention): """ Starcoder2 flash attention module. This module inherits from `Starcoder2Attention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. 
""" - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -391,6 +395,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" # Ignore copy def forward( @@ -489,18 +494,29 @@ def forward( key_states = key_states.transpose(1, 2) value_states = value_states.transpose(1, 2) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - position_ids=position_ids, - dropout=dropout_rate, - sliding_window=getattr(self.config, "sliding_window", None), - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=dropout_rate, + sliding_window=getattr(self.config, "sliding_window", None), + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() attn_output = self.o_proj(attn_output) @@ -613,7 +629,8 @@ def forward( STARCODER2_ATTENTION_CLASSES = { "eager": Starcoder2Attention, - "flash_attention_2": Starcoder2FlashAttention2, + "flash_attention_2": Starcoder2FlashAttention, + "flash_attention_3": Starcoder2FlashAttention, "sdpa": Starcoder2SdpaAttention, } @@ -728,6 +745,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Starcoder2DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -993,7 +1011,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 4202f680437c53..c8bd2485b320c6 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -35,6 +35,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -45,6 +46,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils 
import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -586,15 +590,15 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->UniSpeech -class UniSpeechFlashAttention2(UniSpeechAttention): +# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->UniSpeech +class UniSpeechFlashAttention(UniSpeechAttention): """ UniSpeech flash attention module. This module inherits from `UniSpeechAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -602,6 +606,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -694,16 +699,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) @@ -824,7 +839,8 @@ def forward( UNISPEECH_ATTENTION_CLASSES = { "eager": UniSpeechAttention, "sdpa": UniSpeechSdpaAttention, - "flash_attention_2": UniSpeechFlashAttention2, + "flash_attention_2": UniSpeechFlashAttention, + "flash_attention_3": UniSpeechFlashAttention, } @@ -972,6 +988,7 @@ def __init__(self, config): self.layers = nn.ModuleList([UniSpeechEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -988,7 +1005,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) 
hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1060,6 +1077,7 @@ def __init__(self, config): ) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1076,7 +1094,7 @@ def forward( # make sure padded tokens are not attended to expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1218,6 +1236,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index bfb2cbfa4f55da..fd66caa5530f04 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -42,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_peft_available, logging, @@ -53,6 +54,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -603,15 +607,15 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->UniSpeechSat -class UniSpeechSatFlashAttention2(UniSpeechSatAttention): +# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->UniSpeechSat +class UniSpeechSatFlashAttention(UniSpeechSatAttention): """ UniSpeechSat flash attention module. This module inherits from `UniSpeechSatAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -619,6 +623,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -711,16 +716,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) @@ -841,7 +856,8 @@ def forward( UNISPEECHSAT_ATTENTION_CLASSES = { "eager": UniSpeechSatAttention, "sdpa": UniSpeechSatSdpaAttention, - "flash_attention_2": UniSpeechSatFlashAttention2, + "flash_attention_2": UniSpeechSatFlashAttention, + "flash_attention_3": UniSpeechSatFlashAttention, } @@ -989,6 +1005,7 @@ def __init__(self, config): self.layers = nn.ModuleList([UniSpeechSatEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1005,7 +1022,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1077,6 +1094,7 @@ def __init__(self, config): ) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1093,7 +1111,7 @@ def forward( # make sure padded tokens are not attended to expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1235,6 +1253,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 9ae80be65ae4b6..7fbc36945d9a0a 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ 
-126,6 +126,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): _no_split_modules = ["VideoLlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True def _init_weights(self, module): diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 53a3213697193e..f42e9333c39bec 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ b/src/transformers/models/vipllava/modeling_vipllava.py @@ -132,6 +132,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel): _no_split_modules = ["VipLlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True def _init_weights(self, module): diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index f1d021b58ee538..a0270d852fad52 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -45,6 +45,7 @@ add_start_docstrings_to_model_forward, cached_file, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_peft_available, is_safetensors_available, @@ -64,6 +65,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -650,15 +654,15 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Wav2Vec2 -class Wav2Vec2FlashAttention2(Wav2Vec2Attention): +# Copied from transformers.models.bart.modeling_bart.BartFlashAttention with Bart->Wav2Vec2 +class Wav2Vec2FlashAttention(Wav2Vec2Attention): """ Wav2Vec2 flash attention module. This module inherits from `Wav2Vec2Attention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -666,6 +670,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) @@ -758,16 +763,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) @@ -888,7 +903,8 @@ def forward( WAV2VEC2_ATTENTION_CLASSES = { "eager": Wav2Vec2Attention, "sdpa": Wav2Vec2SdpaAttention, - "flash_attention_2": Wav2Vec2FlashAttention2, + "flash_attention_2": Wav2Vec2FlashAttention, + "flash_attention_3": Wav2Vec2FlashAttention, } @@ -1006,6 +1022,7 @@ def __init__(self, config): self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1022,7 +1039,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1093,6 +1110,7 @@ def __init__(self, config): ) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1109,7 +1127,7 @@ def forward( # make sure padded tokens are not attended to expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1331,6 +1349,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index b82b978e5e6d95..de694cf8858cf3 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -39,6 +39,7 @@ add_start_docstrings, 
add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -50,6 +51,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -400,14 +404,14 @@ def forward( return attn_output, attn_weights, past_key_value -class WhisperFlashAttention2(WhisperAttention): +class WhisperFlashAttention(WhisperAttention): """ Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stays untouched. The only required change would be on the forward pass where it needs to correctly call the public API of flash attention and deal with padding tokens in case the input contains any of them. """ - # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention.__init__ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -415,6 +419,7 @@ def __init__(self, *args, **kwargs): # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_3 = self.config._attn_implementation == "flash_attention_3" def forward( self, @@ -503,16 +508,26 @@ def forward( key_states = key_states.to(target_dtype) value_states = value_states.to(target_dtype) - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - causal_mask, - tgt_len, - dropout=self.dropout, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) + if self._flash_attn_3: + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + causal_mask, + tgt_len, + is_causal=self.is_causal, + ) + else: + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + causal_mask, + tgt_len, + dropout=self.dropout, + is_causal=self.is_causal, + use_top_left_mask=self._flash_attn_uses_top_left_mask, + ) attn_output = attn_output.reshape(bsz, tgt_len, -1) attn_output = self.out_proj(attn_output) @@ -623,7 +638,8 @@ def forward( WHISPER_ATTENTION_CLASSES = { "eager": WhisperAttention, - "flash_attention_2": WhisperFlashAttention2, + "flash_attention_2": WhisperFlashAttention, + "flash_attention_3": WhisperFlashAttention, "sdpa": WhisperSdpaAttention, } @@ -823,6 +839,7 @@ class WhisperPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True @@ -1166,6 +1183,7 @@ def __init__(self, config: WhisperConfig): [WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)] ) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation 
== "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layer_norm = nn.LayerNorm(config.d_model) @@ -1428,7 +1446,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/x_clip/modeling_x_clip.py b/src/transformers/models/x_clip/modeling_x_clip.py index 791e501d173721..c215af48eb056f 100644 --- a/src/transformers/models/x_clip/modeling_x_clip.py +++ b/src/transformers/models/x_clip/modeling_x_clip.py @@ -171,7 +171,7 @@ def forward( class XCLIPAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: XCLIPConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index b86e3af91ca727..dd9294adff3b17 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -71,6 +71,7 @@ is_faiss_available, is_fbgemm_gpu_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_flax_available, is_fsdp_available, is_ftfy_available, @@ -519,6 +520,16 @@ def require_flash_attn(test_case): return unittest.skipUnless(is_flash_attn_2_available(), "test requires Flash Attention")(test_case) +def require_flash_attn_3(test_case): + """ + Decorator marking a test that requires Flash Attention 3. + + These tests are skipped when Flash Attention 3 isn't installed. + + """ + return unittest.skipUnless(is_flash_attn_3_available(), "test requires Flash Attention 3")(test_case) + + def require_torch_sdpa(test_case): """ Decorator marking a test that requires PyTorch's SDPA. 
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index eee350349f5565..92a9b1fb675623 100755 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -130,6 +130,7 @@ is_faiss_available, is_fbgemm_gpu_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, is_flax_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index ad8b649aaa4e84..4f4e70b122be6f 100755 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -62,6 +62,14 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[ except ImportError: # If the package can't be imported, it's not available package_exists = False + elif pkg_name == "flash_attn_interface": + try: + package = importlib.import_module(pkg_name) + package_version = getattr(package, "__version__", "N/A") + package_exists = True + except ImportError: + # If the package can't be imported, it's not available + package_exists = False else: # For packages other than "torch", don't attempt the fallback and set as not available package_exists = False @@ -899,6 +907,16 @@ def is_flash_attn_greater_or_equal(library_version: str): return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version) +def is_flash_attn_3_available(): + if not is_flash_attn_2_available(): + return False + + if not _is_package_available("flash_attn_interface"): + return False + + return True + + def is_torchdistx_available(): return _torchdistx_available diff --git a/tests/models/bark/test_modeling_bark.py b/tests/models/bark/test_modeling_bark.py index 9bb8ef33d75998..9c6064cf55a294 100644 --- a/tests/models/bark/test_modeling_bark.py +++ b/tests/models/bark/test_modeling_bark.py @@ -35,6 +35,7 @@ ) from transformers.testing_utils import ( require_flash_attn, + require_flash_attn_3, require_torch, require_torch_fp16, require_torch_gpu, @@ -981,6 +982,63 @@ def test_flash_attn_2_inference_equivalence(self): model.train() _ = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(reason="Model does not support flash_attention_3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + dummy_input = inputs_dict["input_ids"][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + dummy_attention_mask = dummy_attention_mask[:1] + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + + outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True) + outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True) + + logits = 
outputs.hidden_states[-1] + logits_fa = outputs_fa.hidden_states[-1] + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + other_inputs = {"output_hidden_states": True} + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs) + outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs) + + logits = outputs.hidden_states[-1] + logits_fa = outputs_fa.hidden_states[-1] + + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + _ = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -1039,6 +1097,64 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence_right_padding(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(reason="Model does not support flash_attention_3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.bfloat16, + ) + model.to(torch_device) + + dummy_input = inputs_dict["input_ids"][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + dummy_attention_mask = dummy_attention_mask[:1] + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + outputs = model(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True) + outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, output_hidden_states=True) + + logits = outputs.hidden_states[-1] + logits_fa = outputs_fa.hidden_states[-1] + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(inputs_dict["codebook_idx"], dummy_input, **other_inputs) + outputs_fa = model_fa(inputs_dict["codebook_idx"], dummy_input, **other_inputs) + + logits = outputs.hidden_states[-1] + logits_fa = outputs_fa.hidden_states[-1] + + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_torch class BarkModelIntegrationTests(unittest.TestCase): diff --git a/tests/models/chameleon/test_modeling_chameleon.py b/tests/models/chameleon/test_modeling_chameleon.py index 16e0a548e6dc47..6635c1f91e9aaf 100644 --- a/tests/models/chameleon/test_modeling_chameleon.py +++ b/tests/models/chameleon/test_modeling_chameleon.py @@ -24,6 +24,7 @@ from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, + require_flash_attn_3, require_read_token, require_torch, require_torch_gpu, @@ -366,6 +367,43 @@ def test_flash_attn_2_generate_padding_right(self): self.assertListEqual(output_native, 
output_fa_2) + @require_flash_attn_3 + @require_read_token + @require_torch_gpu + @require_bitsandbytes + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_padding_right(self): + """ + Overwritting the common test as the test is flaky on tiny models + """ + model = ChameleonForConditionalGeneration.from_pretrained( + "facebook/chameleon-7b", + load_in_4bit=True, + device_map={"": 0}, + ) + + processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b") + texts = ["hi", "Hello this is a very long sentence"] + + processor.tokenizer.padding_side = "right" + + inputs = processor(texts, return_tensors="pt", padding=True).to(0) + + output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_native = processor.tokenizer.batch_decode(output_native) + + model = ChameleonForConditionalGeneration.from_pretrained( + "facebook/chameleon-7b", + load_in_4bit=True, + attn_implementation="flash_attention_3", + ) + + output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_fa_3 = processor.tokenizer.batch_decode(output_fa_3) + + self.assertListEqual(output_native, output_fa_3) + @unittest.skip("Chameleon forces some token ids to be -inf!") def test_batching_equivalence(self): pass diff --git a/tests/models/clip/test_modeling_clip.py b/tests/models/clip/test_modeling_clip.py index 3b6994428088a2..a63a8236a087c3 100644 --- a/tests/models/clip/test_modeling_clip.py +++ b/tests/models/clip/test_modeling_clip.py @@ -31,6 +31,7 @@ is_flax_available, is_pt_flax_cross_test, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_gpu, require_torch_sdpa, @@ -1019,6 +1020,45 @@ def test_flash_attn_2_inference_equivalence(self): f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}", ) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16) + dummy_input_ids = inputs_dict["input_ids"] + + outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True) + outputs_fa = model_fa( + pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True + ) + + self.assertTrue( + torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2), + f"Image logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}", + ) + self.assertTrue( + torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2), + f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}", + ) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -1070,6 +1110,57 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): f"Text logits 
max diff: {torch.max(torch.abs(logits_per_text_eager - logits_per_text_sdpa))}", ) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + def test_flash_attn_3_inference_equivalence_right_padding(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="eager" + ) + model.to(torch_device) + + dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16) + dummy_input_ids = inputs_dict["input_ids"] + dummy_pixel_mask = inputs_dict["attention_mask"] + + # right padding + dummy_pixel_mask[:] = 1 + dummy_pixel_mask[:, -1:] = 0 + + outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True) + outputs_fa = model_fa( + pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True + ) + + logits_per_image_eager = outputs.logits_per_image[:, :-1] + logits_per_text_eager = outputs.logits_per_text[:, :-1] + + logits_per_image_sdpa = outputs_fa.logits_per_image[:, :-1] + logits_per_text_sdpa = outputs_fa.logits_per_text[:, :-1] + + self.assertTrue( + torch.allclose(logits_per_image_eager, logits_per_image_sdpa, atol=4e-2, rtol=4e-2), + f"Image logits max diff: {torch.max(torch.abs(logits_per_image_eager - logits_per_image_sdpa))}", + ) + self.assertTrue( + torch.allclose(logits_per_text_eager, logits_per_text_sdpa, atol=4e-2, rtol=4e-2), + f"Text logits max diff: {torch.max(torch.abs(logits_per_text_eager - logits_per_text_sdpa))}", + ) + class CLIPForImageClassificationModelTester(CLIPModelTester): def __init__(self, parent): diff --git a/tests/models/distilbert/test_modeling_distilbert.py b/tests/models/distilbert/test_modeling_distilbert.py index 3a74a1557cf9ba..00c5eb2a0e1ec0 100644 --- a/tests/models/distilbert/test_modeling_distilbert.py +++ b/tests/models/distilbert/test_modeling_distilbert.py @@ -19,7 +19,14 @@ import pytest from transformers import DistilBertConfig, is_torch_available -from transformers.testing_utils import require_flash_attn, require_torch, require_torch_accelerator, slow, torch_device +from transformers.testing_utils import ( + require_flash_attn, + require_flash_attn_3, + require_torch, + require_torch_accelerator, + slow, + torch_device, +) from ...test_configuration_common import ConfigTester from ...test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask @@ -348,6 +355,58 @@ def test_flash_attn_2_inference_equivalence(self): self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)) + # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test. 
+ @require_flash_attn_3 + @require_torch_accelerator + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence(self): + import torch + + for model_class in self.all_model_classes: + dummy_input = torch.LongTensor( + [ + [1, 2, 3, 4], + [1, 2, 8, 9], + [1, 2, 11, 12], + [1, 2, 13, 14], + ] + ).to(torch_device) + dummy_attention_mask = torch.LongTensor( + [ + [0, 1, 1, 1], + [0, 1, 1, 1], + [0, 1, 1, 1], + [0, 1, 1, 1], + ] + ).to(torch_device) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + logits = model(dummy_input, output_hidden_states=True).hidden_states[-1] + logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)) + + output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits_fa = output_fa.hidden_states[-1] + + output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits = output.hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2)) + # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test. @require_flash_attn @require_torch_accelerator @@ -403,6 +462,61 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)) + # Because DistilBertForMultipleChoice requires inputs with different shapes we need to override this test. 
+ @require_flash_attn_3 + @require_torch_accelerator + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence_right_padding(self): + import torch + + for model_class in self.all_model_classes: + dummy_input = torch.LongTensor( + [ + [1, 2, 3, 4], + [1, 2, 8, 9], + [1, 2, 11, 12], + [1, 2, 13, 14], + ] + ).to(torch_device) + dummy_attention_mask = torch.LongTensor( + [ + [0, 1, 1, 1], + [0, 1, 1, 1], + [0, 1, 1, 1], + [0, 1, 1, 1], + ] + ).to(torch_device) + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.bfloat16, + ) + model.to(torch_device) + + logits = model(dummy_input, output_hidden_states=True).hidden_states[-1] + logits_fa = model_fa(dummy_input, output_hidden_states=True).hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2)) + + output_fa = model_fa(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits_fa = output_fa.hidden_states[-1] + + output = model(dummy_input, attention_mask=dummy_attention_mask, output_hidden_states=True) + logits = output.hidden_states[-1] + + self.assertTrue(torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2)) + @require_torch class DistilBertModelIntergrationTest(unittest.TestCase): diff --git a/tests/models/gemma/test_modeling_gemma.py b/tests/models/gemma/test_modeling_gemma.py index a02541d585447c..fb4a11ced8d53c 100644 --- a/tests/models/gemma/test_modeling_gemma.py +++ b/tests/models/gemma/test_modeling_gemma.py @@ -25,6 +25,7 @@ is_flaky, require_bitsandbytes, require_flash_attn, + require_flash_attn_3, require_read_token, require_torch, require_torch_accelerator, @@ -453,6 +454,51 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_use_cache(self): + import torch + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: Gemma apparently does not support right padding + use_cache with FA3. 
+ dummy_attention_mask[:, -1] = 1 + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -460,6 +506,13 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest(reason="Gemma flash attention does not support right padding") + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence_right_padding(self): + self.skipTest(reason="Gemma flash attention 3 does not support right padding") + @require_torch_sdpa @require_torch_accelerator @slow @@ -526,6 +579,40 @@ def test_flash_attn_2_equivalence(self): # gemma flash attention 2 needs a high tolerance assert torch.allclose(logits_fa, logits, atol=3e-3) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @is_flaky() + @slow + def test_flash_attn_3_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(reason="Model does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager") + model.to(torch_device) + + dummy_input = inputs_dict[model_class.main_input_name] + dummy_input = dummy_input.to(torch_device) + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = outputs.hidden_states[-1] + logits_fa = outputs_fa.hidden_states[-1] + + # gemma flash attention 3 needs a high tolerance + assert torch.allclose(logits_fa, logits, atol=3e-3) + @slow @require_torch_accelerator @@ -652,6 +739,29 @@ def test_model_2b_flash_attn(self): self.assertEqual(output_text, EXPECTED_TEXTS) + @require_flash_attn_3 + @require_read_token + @pytest.mark.flash_attn_3_test + def test_model_2b_flash_attn_fa3(self): + model_id = "google/gemma-2b" + EXPECTED_TEXTS = [ + "Hello I am doing a project on the 1990s and I need to know what the most popular music", + "Hi today I am going to share with you a very easy and simple recipe of Kaju Kat", + ] + + model = AutoModelForCausalLM.from_pretrained( + model_id, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model.to(torch_device) + + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + + self.assertEqual(output_text, EXPECTED_TEXTS) + @require_bitsandbytes @require_read_token def test_model_2b_4bit(self): diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py index 
4e7b3553460f89..75a09f1a83e73d 100644 --- a/tests/models/gemma2/test_modeling_gemma2.py +++ b/tests/models/gemma2/test_modeling_gemma2.py @@ -22,6 +22,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, Gemma2Config, HybridCache, is_torch_available, pipeline from transformers.testing_utils import ( require_flash_attn, + require_flash_attn_3, require_read_token, require_torch, require_torch_gpu, @@ -306,3 +307,27 @@ def test_model_9b_flash_attn(self): output_text = tokenizer.batch_decode(output, skip_special_tokens=False) self.assertEqual(output_text, EXPECTED_TEXTS) + + @require_read_token + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + def test_model_9b_flash_attn_3(self): + # See https://github.com/huggingface/transformers/issues/31953 --- flash attn was generating garbage for gemma2, especially in long context + model_id = "google/gemma-2-9b" + EXPECTED_TEXTS = [ + 'Hello I am doing a project on the 1918 flu pandemic and I am trying to find out how many people died in the United States. I have found a few sites that say 500,000 but I am not sure if that is correct. I have also found a site that says 675,000 but I am not sure if that is correct either. I am trying to find out how many people died in the United States. I have found a few', + "Hi today I'm going to be talking about the history of the United States. The United States of America is a country in North America. It is the third largest country in the world by total area and the third most populous country with over 320 million people. The United States is a federal republic consisting of 50 states and a federal district. The 48 contiguous states and the district of Columbia are in central North America between Canada and Mexico. The state of Alaska is in the" + ] # fmt: skip + + model = AutoModelForCausalLM.from_pretrained( + model_id, attn_implementation="flash_attention_3", torch_dtype="float16" + ).to(torch_device) + tokenizer = AutoTokenizer.from_pretrained(model_id) + inputs = tokenizer(self.input_text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=100, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=False) + + self.assertEqual(output_text, EXPECTED_TEXTS) diff --git a/tests/models/gpt2/test_modeling_gpt2.py b/tests/models/gpt2/test_modeling_gpt2.py index 3f96c20ab2dbd9..30d3c4d1b9bab3 100644 --- a/tests/models/gpt2/test_modeling_gpt2.py +++ b/tests/models/gpt2/test_modeling_gpt2.py @@ -25,6 +25,7 @@ from transformers.testing_utils import ( backend_empty_cache, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_gpu, slow, @@ -939,3 +940,40 @@ def test_flash_attn_2_generate_padding_left(self): self.assertListEqual(output_native, output_fa_2) self.assertListEqual(output_native, expected_output) + + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_padding_left(self): + """ + Overwriting the common test as the test is flaky on tiny models + """ + model = GPT2LMHeadModel.from_pretrained("gpt2", torch_dtype=torch.float16).to(0) + + tokenizer = GPT2Tokenizer.from_pretrained("gpt2") + + texts = ["hi", "Hello this is a very long sentence"] + + tokenizer.padding_side = "left" + tokenizer.pad_token = tokenizer.eos_token + + inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) + + output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_native = 
tokenizer.batch_decode(output_native) + + model = GPT2LMHeadModel.from_pretrained( + "gpt2", device_map={"": 0}, attn_implementation="flash_attention_3", torch_dtype=torch.float16 + ) + + output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_fa_3 = tokenizer.batch_decode(output_fa_3) + + expected_output = [ + "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>hi, who was born in the city of Kolkata, was a member of the Kolkata", + "Hello this is a very long sentence. I'm sorry. I'm sorry. I'm sorry. I'm sorry. I'm sorry", + ] + + self.assertListEqual(output_native, output_fa_3) + self.assertListEqual(output_native, expected_output) diff --git a/tests/models/gptj/test_modeling_gptj.py b/tests/models/gptj/test_modeling_gptj.py index 2ef2e391215e7b..02fbb3703d8403 100644 --- a/tests/models/gptj/test_modeling_gptj.py +++ b/tests/models/gptj/test_modeling_gptj.py @@ -23,6 +23,7 @@ from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_gpu, slow, @@ -565,6 +566,44 @@ def test_flash_attn_2_generate_padding_right(self): self.assertListEqual(expected_outputs, output_fa_2) + @require_flash_attn_3 + @require_torch_gpu + @require_bitsandbytes + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_padding_right(self): + """ + Overwriting the common test as the test is flaky on tiny models + """ + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6b") + + texts = ["hi", "Hello this is a very long sentence"] + expected_outputs = [ + "hi<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>Q: I have a question about the new version of the game. I have a question about the", + "Hello this is a very long sentence.\n\nA:\n\nI think the best way to understand this is to think of it", + ] + + tokenizer.padding_side = "right" + tokenizer.pad_token = tokenizer.eos_token + + inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) + + quantization_config = BitsAndBytesConfig(load_in_4bit=True) + + model = GPTJForCausalLM.from_pretrained( + "EleutherAI/gpt-j-6b", + device_map={"": 0}, + attn_implementation="flash_attention_3", + revision="float16", + torch_dtype=torch.float16, + quantization_config=quantization_config, + ) + + output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_fa_3 = tokenizer.batch_decode(output_fa_3) + + self.assertListEqual(expected_outputs, output_fa_3) + @require_torch class GPTJModelLanguageGenerationTest(unittest.TestCase): diff --git a/tests/models/granite/test_modeling_granite.py b/tests/models/granite/test_modeling_granite.py index 8771cd50978a7f..933e3b798c7232 100644 --- a/tests/models/granite/test_modeling_granite.py +++ b/tests/models/granite/test_modeling_granite.py @@ -24,6 +24,7 @@ from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, + require_flash_attn_3, require_read_token, require_torch, require_torch_gpu, @@ -518,6 +519,46 @@ def test_flash_attn_2_generate_padding_right(self): self.assertListEqual(output_native, output_fa_2) + @require_flash_attn_3 + @require_torch_gpu + @require_bitsandbytes + @pytest.mark.flash_attn_3_test + @require_read_token + @slow + def test_flash_attn_3_generate_padding_right(self): + """ + Overwriting the common test as the test is flaky on tiny models + """ + model = GraniteForCausalLM.from_pretrained( + "ibm/PowerLM-3b", + load_in_4bit=True, + device_map={"": 0}, + ) + + tokenizer = 
AutoTokenizer.from_pretrained("ibm/PowerLM-3b") + + texts = ["hi", "Hello this is a very long sentence"] + + tokenizer.padding_side = "right" + tokenizer.pad_token = tokenizer.eos_token + + inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) + + output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_native = tokenizer.batch_decode(output_native) + + model = GraniteForCausalLM.from_pretrained( + "ibm/PowerLM-3b", + load_in_4bit=True, + device_map={"": 0}, + attn_implementation="flash_attention_3", + ) + + output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_fa_3 = tokenizer.batch_decode(output_fa_3) + + self.assertListEqual(output_native, output_fa_3) + @require_flash_attn @require_torch_gpu @slow diff --git a/tests/models/idefics2/test_modeling_idefics2.py b/tests/models/idefics2/test_modeling_idefics2.py index e02c5b4c9f09c6..51f8604aad17b2 100644 --- a/tests/models/idefics2/test_modeling_idefics2.py +++ b/tests/models/idefics2/test_modeling_idefics2.py @@ -32,6 +32,7 @@ from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_gpu, require_torch_multi_gpu, @@ -608,3 +609,37 @@ def test_flash_attn_2_eager_equivalence(self): ) self.assertEqual(generated_texts_eager[0], generated_texts_flash_attention_2[0]) + + @require_flash_attn_3 + @require_torch_gpu + @require_bitsandbytes + def test_flash_attn_3_eager_equivalence(self): + # Create inputs + text = "In this image, we see" + images = self.image1 + inputs = self.processor(text=text, images=images, return_tensors="pt", padding=True) + inputs.to(torch_device) + + # Eager model + model_eager = Idefics2ForConditionalGeneration.from_pretrained( + "HuggingFaceM4/idefics2-8b-base", + attn_implementation="eager", + load_in_4bit=True, + ) + generated_ids_eager = model_eager.generate(**inputs, max_new_tokens=10) + generated_texts_eager = self.processor.batch_decode(generated_ids_eager, skip_special_tokens=True) + + del model_eager + + # Flash Attention 3 model + model_flash_attention_3 = Idefics2ForConditionalGeneration.from_pretrained( + "HuggingFaceM4/idefics2-8b-base", + attn_implementation="flash_attention_3", + load_in_4bit=True, + ) + generated_ids_flash_attention_3 = model_flash_attention_3.generate(**inputs, max_new_tokens=10) + generated_texts_flash_attention_3 = self.processor.batch_decode( + generated_ids_flash_attention_3, skip_special_tokens=True + ) + + self.assertEqual(generated_texts_eager[0], generated_texts_flash_attention_3[0]) diff --git a/tests/models/jamba/test_modeling_jamba.py b/tests/models/jamba/test_modeling_jamba.py index 6e1a2cf2cf9c44..28fa0e8749c827 100644 --- a/tests/models/jamba/test_modeling_jamba.py +++ b/tests/models/jamba/test_modeling_jamba.py @@ -25,6 +25,7 @@ from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_gpu, slow, @@ -539,6 +540,45 @@ def test_flash_attn_2_fp32_ln(self): # with attention mask _ = model(dummy_input, attention_mask=dummy_attention_mask) + @require_flash_attn_3 + @require_torch_gpu + @require_bitsandbytes + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_fp32_ln(self): + r""" + Overriding the test_flash_attn_3_fp32_ln test as the Jamba model, like Mixtral, doesn't support + right padding + use cache with FA3 + """ + for model_class in self.all_generative_model_classes: + config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_input = inputs_dict[model.main_input_name] + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: Jamba does not support right padding + use_cache with FA3. + dummy_attention_mask[:, -1] = 1 + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + load_in_4bit=True, + ) + + for _, param in model.named_parameters(): + # upcast only layer norms + if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16): + param.data = param.data.to(torch.float32) + + _ = model(dummy_input) + # with attention mask + _ = model(dummy_input, attention_mask=dummy_attention_mask) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -577,6 +617,44 @@ def test_flash_attn_2_generate_padding_right(self): dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_padding_right(self): + r""" + Overriding the test_flash_attn_3_generate_padding_right test as the Jamba model, like Mixtral, doesn't support + right padding + use cache with FA3 + """ + import torch + + for model_class in self.all_generative_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + + model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + with self.assertRaises(ValueError): + _ = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -626,6 +704,55 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_use_cache(self): + r""" + Overriding the test_flash_attn_3_generate_use_cache test as the Jamba model, like Mixtral, doesn't support + right padding + use cache with FA3 + """ + import torch + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + 
dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: Jamba does not support right padding + use_cache with FA3. + dummy_attention_mask[:, -1] = 1 + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -637,6 +764,17 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): """ self.skipTest(reason="Jamba flash attention does not support right padding") + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence_right_padding(self): + r""" + Overriding the test_flash_attn_3_inference_padding_right test as the Jamba model, like Mixtral, doesn't support + right padding + use cache with FA3 + """ + self.skipTest(reason="Jamba flash attention does not support right padding") + @unittest.skip(reason="Jamba has its own special cache type") @parameterized.expand([(1, False), (1, True), (4, False)]) def test_new_cache_format(self, num_beams, do_sample): diff --git a/tests/models/jetmoe/test_modeling_jetmoe.py b/tests/models/jetmoe/test_modeling_jetmoe.py index 50fd7a27e1e6d1..238d2d7219312f 100644 --- a/tests/models/jetmoe/test_modeling_jetmoe.py +++ b/tests/models/jetmoe/test_modeling_jetmoe.py @@ -26,6 +26,7 @@ backend_empty_cache, is_flaky, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_gpu, require_torch_sdpa, @@ -420,6 +421,40 @@ def test_flash_attn_2_generate_padding_right(self): dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_padding_right(self): + import torch + + for model_class in self.all_generative_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + + model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + with self.assertRaises(ValueError): + _ = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -465,6 +500,51 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_use_cache(self): + import torch + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = 
inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: JetMoe apparently does not support right padding + use_cache with FA3. + dummy_attention_mask[:, -1] = 1 + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -472,6 +552,13 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest(reason="JetMoe flash attention does not support right padding") + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence_right_padding(self): + self.skipTest(reason="JetMoe flash attention does not support right padding") + @require_torch class JetMoeIntegrationTest(unittest.TestCase): diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index c99357ff99b257..f03352b10d4f09 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -26,6 +26,7 @@ from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, + require_flash_attn_3, require_read_token, require_torch, require_torch_gpu, @@ -617,6 +618,43 @@ def test_flash_attn_2_generate_padding_right(self): self.assertListEqual(output_native, output_fa_2) + @require_flash_attn_3 + @require_torch_gpu + @require_bitsandbytes + @pytest.mark.flash_attn_3_test + @require_read_token + @slow + def test_flash_attn_3_generate_padding_right(self): + """ + Overwriting the common test as the test is flaky on tiny models + """ + model = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", + load_in_4bit=True, + device_map={"": 0}, + ) + + tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + + texts = ["hi", "Hello this is a very long sentence"] + + tokenizer.padding_side = "right" + tokenizer.pad_token = tokenizer.eos_token + + inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) + + output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_native = tokenizer.batch_decode(output_native) + + model = LlamaForCausalLM.from_pretrained( + "meta-llama/Llama-2-7b-hf", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_3" + ) + + output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_fa_3 = tokenizer.batch_decode(output_fa_3) + + self.assertListEqual(output_native, output_fa_3) + @require_flash_attn @require_torch_gpu @slow diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py index a29a9c8a9ec0dc..35701f0c0eb9d8 --- 
a/tests/models/m2m_100/test_modeling_m2m_100.py +++ b/tests/models/m2m_100/test_modeling_m2m_100.py @@ -23,6 +23,7 @@ from transformers import M2M100Config, is_torch_available from transformers.testing_utils import ( require_flash_attn, + require_flash_attn_3, require_sentencepiece, require_tokenizers, require_torch, @@ -465,3 +466,48 @@ def test_flash_attn_2_seq_to_seq_generation(self): hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True ) assert generated == expected_en + + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_seq_to_seq_generation(self): + """ + Overwriting the common test as the test is flaky on tiny models + """ + model = M2M100ForConditionalGeneration.from_pretrained( + "facebook/m2m100_418M", attn_implementation="flash_attention_3" + ).to(torch_device) + + tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en") + + src_fr = [ + "L'affaire NSA souligne l'absence totale de débat sur le renseignement", + "Selon moi, il y a deux niveaux de réponse de la part du gouvernement français.", + "Lorsque François Hollande téléphone à Barack Obama ou quand le ministre des affaires étrangères Laurent" + " Fabius convoque l'ambassadeur des Etats-Unis, ils réagissent à une vraie découverte, qui est celle de" + " l'ampleur de la surveillance américaine sur l'ensemble des communications en France.", + ] + + # The below article tests that we don't add any hypotheses outside of the top n_beams + dct = tokenizer(src_fr, padding=True, return_tensors="pt") + + hypotheses_batch = model.generate( + input_ids=dct["input_ids"].to(torch_device), + attention_mask=dct["attention_mask"].to(torch_device), + num_beams=5, + forced_bos_token_id=tokenizer.get_lang_id("en"), + ) + + expected_en = [ + "The NSA case highlights the total absence of intelligence debate", + "I think there are two levels of response from the French government.", + "When François Hollande calls Barack Obama or when Foreign Minister Laurent Fabius calls the U.S." + " Ambassador, they respond to a real discovery, which is that of the scale of U.S. 
surveillance on all" + " communications in France.", + ] + + generated = tokenizer.batch_decode( + hypotheses_batch.tolist(), clean_up_tokenization_spaces=True, skip_special_tokens=True + ) + assert generated == expected_en diff --git a/tests/models/mimi/test_modeling_mimi.py b/tests/models/mimi/test_modeling_mimi.py index dd0f77421be728..137d64279b4695 100644 --- a/tests/models/mimi/test_modeling_mimi.py +++ b/tests/models/mimi/test_modeling_mimi.py @@ -30,6 +30,7 @@ is_flaky, is_torch_available, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_gpu, require_torch_sdpa, @@ -738,10 +739,46 @@ def test_flash_attn_2_inference_equivalence(self): assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + @is_flaky() + def test_flash_attn_3_inference_equivalence(self): + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + outputs = model(dummy_input) + outputs_fa = model_fa(dummy_input) + + logits = outputs[1] + logits_fa = outputs_fa[1] + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + @unittest.skip(reason="The MimiModel does not support right padding") def test_flash_attn_2_inference_equivalence_right_padding(self): pass + @unittest.skip(reason="The MimiModel does not support right padding") + def test_flash_attn_3_inference_equivalence_right_padding(self): + pass + @unittest.skip(reason="The MimiModel does not have support dynamic compile yet") def test_sdpa_can_compile_dynamic(self): pass diff --git a/tests/models/mistral/test_modeling_mistral.py b/tests/models/mistral/test_modeling_mistral.py index 0da7ae72add7bd..c1d088da723369 100644 --- a/tests/models/mistral/test_modeling_mistral.py +++ b/tests/models/mistral/test_modeling_mistral.py @@ -27,6 +27,7 @@ is_flaky, require_bitsandbytes, require_flash_attn, + require_flash_attn_3, require_read_token, require_torch, require_torch_gpu, @@ -439,6 +440,40 @@ def test_flash_attn_2_generate_padding_right(self): dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_padding_right(self): + import torch + + for model_class in self.all_generative_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + + model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) + + model = 
model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + with self.assertRaises(ValueError): + _ = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -484,6 +519,51 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_use_cache(self): + import torch + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: Mistral apparently does not support right padding + use_cache with FA3. + dummy_attention_mask[:, -1] = 1 + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -491,6 +571,13 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest(reason="Mistral flash attention does not support right padding") + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence_right_padding(self): + self.skipTest(reason="Mistral flash attention does not support right padding") + @require_torch_gpu class MistralIntegrationTest(unittest.TestCase): @@ -601,6 +688,31 @@ def test_model_7b_long_prompt(self): generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + @require_flash_attn_3 + @require_bitsandbytes + @slow + @pytest.mark.flash_attn_3_test + def test_model_7b_long_prompt_fa3(self): + EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + # An input with 4097 tokens that is above the size of the sliding window + input_ids = [1] + [306, 338] * 2048 + model = MistralForCausalLM.from_pretrained( + "mistralai/Mistral-7B-v0.1", + device_map={"": torch_device}, + load_in_4bit=True, + attn_implementation="flash_attention_3", + ) + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + # Assisted generation + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + 
generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + @slow @require_torch_sdpa def test_model_7b_long_prompt_sdpa(self): diff --git a/tests/models/mixtral/test_modeling_mixtral.py b/tests/models/mixtral/test_modeling_mixtral.py index db9641e3dcb2a9..6e074dedbe134e 100644 --- a/tests/models/mixtral/test_modeling_mixtral.py +++ b/tests/models/mixtral/test_modeling_mixtral.py @@ -23,6 +23,7 @@ from transformers.testing_utils import ( is_flaky, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_gpu, require_torch_sdpa, @@ -440,6 +441,40 @@ def test_flash_attn_2_generate_padding_right(self): dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_padding_right(self): + import torch + + for model_class in self.all_generative_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + + model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + with self.assertRaises(ValueError): + _ = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -485,6 +520,51 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_use_cache(self): + import torch + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: Mixtral apparently does not support right padding + use_cache with FA3. 
+ dummy_attention_mask[:, -1] = 1 + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -492,6 +572,13 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest(reason="Mixtral flash attention does not support right padding") + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence_right_padding(self): + self.skipTest(reason="Mixtral flash attention does not support right padding") + # Ignore copy def test_load_balancing_loss(self): r""" diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index e143b8ac3c8658..8caf109bdd01e6 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -35,6 +35,7 @@ from transformers.testing_utils import ( is_torch_available, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_accelerator, require_torch_fp16, @@ -404,6 +405,86 @@ def test_flash_attn_2_inference_equivalence(self): model.train() _ = model_fa(dummy_input, **other_inputs) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence + def test_flash_attn_3_inference_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + + # Ignore copy + outputs = model(dummy_input, output_hidden_states=True) + # Ignore copy + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + # Ignore copy + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + + 
logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + _ = model_fa(dummy_input, **other_inputs) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -483,6 +564,85 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence_right_padding + def test_flash_attn_3_inference_equivalence_right_padding(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + if model.config.is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input) + + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + else: + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + # Ignore copy + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -529,6 +689,52 @@ def test_flash_attn_2_generate_left_padding(self): self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + 
@slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_left_padding + def test_flash_attn_3_generate_left_padding(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # make sure we do left padding + dummy_attention_mask[:, :-1] = 0 + dummy_attention_mask[:, -1:] = 1 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -575,6 +781,52 @@ def test_flash_attn_2_generate_padding_right(self): self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_padding_right + def test_flash_attn_3_generate_padding_right(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # make sure we do right padding + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -621,6 +873,52 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from 
tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_use_cache + def test_flash_attn_3_generate_use_cache(self): + max_new_tokens = 30 + + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa @slow @@ -1599,36 +1897,121 @@ def test_greedy_generate_stereo_outputs(self): config, input_ids, attention_mask = self._get_input_ids_and_config() config.audio_channels = 2 - model = model_class(config).to(torch_device).eval() - output_generate = self._greedy_generate( - model=model, - input_ids=input_ids.to(torch_device), - attention_mask=attention_mask.to(torch_device), - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) + model = model_class(config).to(torch_device).eval() + output_generate = self._greedy_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) + + self.assertNotIn(config.pad_token_id, output_generate) + + @unittest.skip( + reason="MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model" + ) + def test_save_load_fast_init_from_base(self): + pass + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence + def test_flash_attn_2_inference_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in 
[torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + + # Ignore copy + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input) + # Ignore copy + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + # Ignore copy + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + # Ignore copy + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + # Ignore copy + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + # Ignore copy + outputs = model(dummy_input, **other_inputs) + # Ignore copy + outputs_fa = model_fa(dummy_input, **other_inputs) - self.assertIsInstance(output_generate, GenerateEncoderDecoderOutput) + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) - self.assertNotIn(config.pad_token_id, output_generate) + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) - @unittest.skip( - reason="MusicgenModel is actually not the base of MusicgenForCausalLM as the latter is a composit model" - ) - def test_save_load_fast_init_from_base(self): - pass + # check with inference + dropout + model.train() + _ = model_fa(dummy_input, **other_inputs) - @require_flash_attn + @require_flash_attn_3 @require_torch_gpu - @mark.flash_attn_test + @mark.flash_attn_3_test @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence - def test_flash_attn_2_inference_equivalence(self): + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence + def test_flash_attn_3_inference_equivalence(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -1636,7 +2019,7 @@ def test_flash_attn_2_inference_equivalence(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" ) model_fa.to(torch_device) @@ -1787,6 +2170,88 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from 
tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence_right_padding + def test_flash_attn_3_inference_equivalence_right_padding(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + # Ignore copy + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input) + # Ignore copy + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + # Ignore copy + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + # Ignore copy + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + # Ignore copy + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + # Ignore copy + outputs = model(dummy_input, **other_inputs) + # Ignore copy + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -1836,6 +2301,55 @@ def test_flash_attn_2_generate_left_padding(self): self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_left_padding + def test_flash_attn_3_generate_left_padding(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + 
if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask") + if dummy_attention_mask is None: + dummy_attention_mask = torch.ones_like(dummy_input) + + # make sure we do left padding + dummy_attention_mask[:, :-1] = 0 + dummy_attention_mask[:, -1:] = 1 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -1884,6 +2398,54 @@ def test_flash_attn_2_generate_padding_right(self): self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_padding_right + def test_flash_attn_3_generate_padding_right(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask") + if dummy_attention_mask is None: + dummy_attention_mask = torch.ones_like(dummy_input) + # make sure we do right padding + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -1930,6 +2492,52 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_use_cache + def test_flash_attn_3_generate_use_cache(self): + max_new_tokens = 30 + + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if 
hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa @slow diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index 28c2bf2f168ba9..138f9640902541 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -35,6 +35,7 @@ is_torch_available, is_torchaudio_available, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_accelerator, require_torch_fp16, @@ -406,6 +407,86 @@ def test_flash_attn_2_inference_equivalence(self): model.train() _ = model_fa(dummy_input, **other_inputs) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_3_inference_equivalence + def test_flash_attn_3_inference_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + + # Ignore copy + outputs = model(dummy_input, output_hidden_states=True) + # Ignore copy + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + # Ignore copy + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else 
outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + _ = model_fa(dummy_input, **other_inputs) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -485,6 +566,85 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_3_inference_equivalence_right_padding + def test_flash_attn_3_inference_equivalence_right_padding(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + if model.config.is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input) + + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + else: + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + # Ignore copy + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -531,6 +691,52 @@ def test_flash_attn_2_generate_left_padding(self): self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from 
tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_left_padding + def test_flash_attn_3_generate_left_padding(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # make sure we do left padding + dummy_attention_mask[:, :-1] = 0 + dummy_attention_mask[:, -1:] = 1 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -577,6 +783,52 @@ def test_flash_attn_2_generate_padding_right(self): self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_padding_right + def test_flash_attn_3_generate_padding_right(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # make sure we do right padding + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -623,6 +875,52 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from 
tests.models.musicgen.test_modeling_musicgen.MusicgenDecoderTest.test_flash_attn_3_generate_use_cache + def test_flash_attn_3_generate_use_cache(self): + max_new_tokens = 30 + + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa @slow @@ -1583,36 +1881,121 @@ def test_greedy_generate_stereo_outputs(self): config, input_ids, attention_mask = self._get_input_ids_and_config() config.audio_channels = 2 - model = model_class(config).to(torch_device).eval() - output_generate = self._greedy_generate( - model=model, - input_ids=input_ids.to(torch_device), - attention_mask=attention_mask.to(torch_device), - output_scores=True, - output_hidden_states=True, - output_attentions=True, - return_dict_in_generate=True, - ) + model = model_class(config).to(torch_device).eval() + output_generate = self._greedy_generate( + model=model, + input_ids=input_ids.to(torch_device), + attention_mask=attention_mask.to(torch_device), + output_scores=True, + output_hidden_states=True, + output_attentions=True, + return_dict_in_generate=True, + ) + + self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) + + self.assertNotIn(config.pad_token_id, output_generate) + + @unittest.skip( + reason="MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composit model" + ) + def test_save_load_fast_init_from_base(self): + pass + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence + def test_flash_attn_2_inference_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + 
if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + + # Ignore copy + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input) + # Ignore copy + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + # Ignore copy + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + # Ignore copy + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + # Ignore copy + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + # Ignore copy + outputs = model(dummy_input, **other_inputs) + # Ignore copy + outputs_fa = model_fa(dummy_input, **other_inputs) - self.assertIsInstance(output_generate, GenerateDecoderOnlyOutput) + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) - self.assertNotIn(config.pad_token_id, output_generate) + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) - @unittest.skip( - reason="MusicgenMelodyModel is actually not the base of MusicgenMelodyForCausalLM as the latter is a composit model" - ) - def test_save_load_fast_init_from_base(self): - pass + # check with inference + dropout + model.train() + _ = model_fa(dummy_input, **other_inputs) - @require_flash_attn + @require_flash_attn_3 @require_torch_gpu - @mark.flash_attn_test + @mark.flash_attn_3_test @slow - # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_2_inference_equivalence - def test_flash_attn_2_inference_equivalence(self): + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence + def test_flash_attn_3_inference_equivalence(self): for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -1620,7 +2003,7 @@ def test_flash_attn_2_inference_equivalence(self): with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) model_fa = model_class.from_pretrained( - tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" ) model_fa.to(torch_device) @@ -1771,6 +2154,88 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow 
+ # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_inference_equivalence_right_padding + def test_flash_attn_3_inference_equivalence_right_padding(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + # Ignore copy + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + # Ignore copy + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + # Ignore copy + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input) + # Ignore copy + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + # Ignore copy + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + # Ignore copy + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + # Ignore copy + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + # Ignore copy + outputs = model(dummy_input, **other_inputs) + # Ignore copy + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -1820,6 +2285,55 @@ def test_flash_attn_2_generate_left_padding(self): self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_left_padding + def test_flash_attn_3_generate_left_padding(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = 
inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask") + if dummy_attention_mask is None: + dummy_attention_mask = torch.ones_like(dummy_input) + + # make sure we do left padding + dummy_attention_mask[:, :-1] = 0 + dummy_attention_mask[:, -1:] = 1 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -1868,6 +2382,54 @@ def test_flash_attn_2_generate_padding_right(self): self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_padding_right + def test_flash_attn_3_generate_padding_right(self): + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask") + if dummy_attention_mask is None: + dummy_attention_mask = torch.ones_like(dummy_input) + # make sure we do right padding + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=8, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -1914,6 +2476,52 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + # Copied from tests.test_modeling_common.ModelTesterMixin.test_flash_attn_3_generate_use_cache + def test_flash_attn_3_generate_use_cache(self): + max_new_tokens = 30 + + # Ignore copy + for model_class in self.greedy_sample_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough 
positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa @slow diff --git a/tests/models/nemotron/test_modeling_nemotron.py b/tests/models/nemotron/test_modeling_nemotron.py index 4f8f4cc77fe8d0..450939bd0d36a1 100644 --- a/tests/models/nemotron/test_modeling_nemotron.py +++ b/tests/models/nemotron/test_modeling_nemotron.py @@ -25,6 +25,7 @@ from transformers.testing_utils import ( is_flaky, require_flash_attn, + require_flash_attn_3, require_read_token, require_torch, require_torch_gpu, @@ -178,6 +179,40 @@ def test_flash_attn_2_equivalence(self): # nemotron flash attention 2 needs a high tolerance assert torch.allclose(logits_fa, logits, atol=1e-2) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @is_flaky() + @slow + def test_flash_attn_3_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(reason="Model does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, attn_implementation="eager") + model.to(torch_device) + + dummy_input = inputs_dict[model_class.main_input_name] + dummy_input = dummy_input.to(torch_device) + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = outputs.hidden_states[-1] + logits_fa = outputs_fa.hidden_states[-1] + + # nemotron flash attention 3 needs a high tolerance + assert torch.allclose(logits_fa, logits, atol=1e-2) + @require_torch_gpu class NemotronIntegrationTest(unittest.TestCase): diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index 95b0b01c0a23d9..fca9c563361edd 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -24,6 +24,7 @@ from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_gpu, slow, @@ -498,6 +499,43 @@ def test_flash_attn_2_generate_padding_right(self): self.assertListEqual(output_native, output_fa_2) + @require_flash_attn_3 + @require_torch_gpu + @require_bitsandbytes + @pytest.mark.flash_attn_3_test + @slow + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_flash_attn_3_generate_padding_right with LlamaForCausalLM->PhiForCausalLM,LlamaTokenizer->AutoTokenizer,meta-llama/Llama-2-7b-hf->microsoft/phi-1 + 
def test_flash_attn_3_generate_padding_right(self): + """ + Overwritting the common test as the test is flaky on tiny models + """ + model = PhiForCausalLM.from_pretrained( + "microsoft/phi-1", + load_in_4bit=True, + device_map={"": 0}, + ) + + tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1") + + texts = ["hi", "Hello this is a very long sentence"] + + tokenizer.padding_side = "right" + tokenizer.pad_token = tokenizer.eos_token + + inputs = tokenizer(texts, return_tensors="pt", padding=True).to(0) + + output_native = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_native = tokenizer.batch_decode(output_native) + + model = PhiForCausalLM.from_pretrained( + "microsoft/phi-1", load_in_4bit=True, device_map={"": 0}, attn_implementation="flash_attention_3" + ) + + output_fa_3 = model.generate(**inputs, max_new_tokens=20, do_sample=False) + output_fa_3 = tokenizer.batch_decode(output_fa_3) + + self.assertListEqual(output_native, output_fa_3) + @slow @require_torch diff --git a/tests/models/qwen2/test_modeling_qwen2.py b/tests/models/qwen2/test_modeling_qwen2.py index 4d6c432f20424d..59745222cec592 100644 --- a/tests/models/qwen2/test_modeling_qwen2.py +++ b/tests/models/qwen2/test_modeling_qwen2.py @@ -25,6 +25,7 @@ backend_empty_cache, require_bitsandbytes, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_gpu, require_torch_sdpa, @@ -450,6 +451,40 @@ def test_flash_attn_2_generate_padding_right(self): dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_padding_right(self): + import torch + + for model_class in self.all_generative_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + + model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + with self.assertRaises(ValueError): + _ = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -495,6 +530,51 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_use_cache(self): + import torch + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) 
+ + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: Qwen2 apparently does not support right padding + use_cache with FA3. + dummy_attention_mask[:, -1] = 1 + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -502,6 +582,13 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest(reason="Qwen2 flash attention does not support right padding") + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence_right_padding(self): + self.skipTest(reason="Qwen2 flash attention does not support right padding") + @require_torch class Qwen2IntegrationTest(unittest.TestCase): @@ -571,6 +658,36 @@ def test_model_450m_long_prompt(self): backend_empty_cache(torch_device) gc.collect() + @require_bitsandbytes + @slow + @require_flash_attn_3 + @pytest.mark.flash_attn_3_test + def test_model_450m_long_prompt_fav3(self): + EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + # An input with 4097 tokens that is above the size of the sliding window + input_ids = [1] + [306, 338] * 2048 + model = Qwen2ForCausalLM.from_pretrained( + "Qwen/Qwen2-450m-beta", + device_map="auto", + load_in_4bit=True, + attn_implementation="flash_attention_3", + ) + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + # Assisted generation + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + del assistant_model + del model + backend_empty_cache(torch_device) + gc.collect() + @slow @require_torch_sdpa def test_model_450m_long_prompt_sdpa(self): diff --git a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py index 0425172a6fba4d..15eb85b5bcc351 100644 --- a/tests/models/qwen2_moe/test_modeling_qwen2_moe.py +++ b/tests/models/qwen2_moe/test_modeling_qwen2_moe.py @@ -25,6 +25,7 @@ backend_empty_cache, require_bitsandbytes, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_gpu, require_torch_sdpa, @@ -475,6 +476,40 @@ def test_flash_attn_2_generate_padding_right(self): dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_padding_right(self): + import torch + + for model_class in self.all_generative_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = 
model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + + model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + with self.assertRaises(ValueError): + _ = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -520,6 +555,51 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_use_cache(self): + import torch + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: Qwen2Moe apparently does not support right padding + use_cache with FA3. 
+ dummy_attention_mask[:, -1] = 1 + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -527,6 +607,13 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest(reason="Qwen2Moe flash attention does not support right padding") + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence_right_padding(self): + self.skipTest(reason="Qwen2Moe flash attention does not support right padding") + # Ignore copy def test_load_balancing_loss(self): r""" @@ -633,6 +720,36 @@ def test_model_a2_7b_long_prompt(self): backend_empty_cache(torch_device) gc.collect() + @require_bitsandbytes + @slow + @require_flash_attn_3 + @pytest.mark.flash_attn_3_test + def test_model_a2_7b_long_prompt_fav3(self): + EXPECTED_OUTPUT_TOKEN_IDS = [306, 338] + # An input with 4097 tokens that is above the size of the sliding window + input_ids = [1] + [306, 338] * 2048 + model = Qwen2MoeForCausalLM.from_pretrained( + "Qwen/Qwen1.5-MoE-A2.7B", + device_map="auto", + load_in_4bit=True, + attn_implementation="flash_attention_3", + ) + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + # Assisted generation + assistant_model = model + assistant_model.generation_config.num_assistant_tokens = 2 + assistant_model.generation_config.num_assistant_tokens_schedule = "constant" + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-2:].tolist()) + + del assistant_model + del model + backend_empty_cache(torch_device) + gc.collect() + @slow @require_torch_sdpa def test_model_a2_7b_long_prompt_sdpa(self): diff --git a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py index 956243dccebebf..312cd6df4a9d57 100644 --- a/tests/models/qwen2_vl/test_modeling_qwen2_vl.py +++ b/tests/models/qwen2_vl/test_modeling_qwen2_vl.py @@ -28,6 +28,7 @@ ) from transformers.testing_utils import ( require_flash_attn, + require_flash_attn_3, require_torch, require_torch_gpu, slow, @@ -478,6 +479,38 @@ def test_small_model_integration_test_batch_flashatt2(self): self.processor.batch_decode(output, skip_special_tokens=True)[1], ) + @slow + @require_flash_attn_3 + @require_torch_gpu + def test_small_model_integration_test_batch_flashatt3(self): + model = Qwen2VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2-VL-7B-Instruct", + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_3", + device_map="auto", + ) + text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) + inputs = self.processor(text=[text, text], images=[self.image, self.image], return_tensors="pt").to( + torch_device + ) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30) + + EXPECTED_DECODED_TEXT = [ + "system\nYou 
are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets", + "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets", + ] + + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True)[0], + self.processor.batch_decode(output, skip_special_tokens=True)[1], + ) + @slow @require_flash_attn @require_torch_gpu @@ -510,3 +543,36 @@ def test_small_model_integration_test_batch_wo_image_flashatt2(self): self.processor.batch_decode(output, skip_special_tokens=True), EXPECTED_DECODED_TEXT, ) + + @slow + @require_flash_attn_3 + @require_torch_gpu + def test_small_model_integration_test_batch_wo_image_flashatt3(self): + model = Qwen2VLForConditionalGeneration.from_pretrained( + "Qwen/Qwen2-VL-7B-Instruct", + torch_dtype=torch.bfloat16, + attn_implementation="flash_attention_3", + device_map="auto", + ) + text = self.processor.apply_chat_template(self.messages, tokenize=False, add_generation_prompt=True) + messages2 = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Who are you?"}, + ] + text2 = self.processor.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True) + inputs = self.processor(text=[text, text2], images=[self.image], padding=True, return_tensors="pt").to( + torch_device + ) + + # it should not matter whether two images are the same size or not + output = model.generate(**inputs, max_new_tokens=30) + + EXPECTED_DECODED_TEXT = [ + "system\nYou are a helpful assistant.\nuser\nWhat kind of dog is this?\nassistant\nThe dog in the picture appears to be a Labrador Retriever. Labradors are known for their friendly and intelligent nature, making them popular pets", + "system\nYou are a helpful assistant.\nuser\nWho are you?\nassistant\nI am Qwen, a large language model created by Alibaba Cloud. 
I am designed to answer a wide range of questions and provide information on various topics", + ] + + self.assertEqual( + self.processor.batch_decode(output, skip_special_tokens=True), + EXPECTED_DECODED_TEXT, + ) diff --git a/tests/models/siglip/test_modeling_siglip.py b/tests/models/siglip/test_modeling_siglip.py index 9d1e3109b313c3..a3c1cfcb66e40f 100644 --- a/tests/models/siglip/test_modeling_siglip.py +++ b/tests/models/siglip/test_modeling_siglip.py @@ -28,6 +28,7 @@ from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig from transformers.testing_utils import ( require_flash_attn, + require_flash_attn_3, require_torch, require_torch_gpu, require_torch_sdpa, @@ -834,12 +835,95 @@ def test_flash_attn_2_inference_equivalence(self): output_hidden_states=True, ) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence(self): + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + dummy_pixel_values = inputs_dict["pixel_values"].to(torch.bfloat16) + dummy_input_ids = inputs_dict["input_ids"] + + outputs = model(pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True) + outputs_fa = model_fa( + pixel_values=dummy_pixel_values, input_ids=dummy_input_ids, output_hidden_states=True + ) + + self.assertTrue( + torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2), + f"Image logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}", + ) + self.assertTrue( + torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2), + f"Text logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}", + ) + + # Test with attention mask + dummy_attention_mask = inputs_dict["attention_mask"] + + if dummy_attention_mask is not None: + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + + outputs = model( + pixel_values=dummy_pixel_values, + input_ids=dummy_input_ids, + attention_mask=dummy_attention_mask, + output_hidden_states=True, + ) + outputs_fa = model_fa( + pixel_values=dummy_pixel_values, + input_ids=dummy_input_ids, + attention_mask=dummy_attention_mask, + output_hidden_states=True, + ) + + self.assertTrue( + torch.allclose(outputs.logits_per_image, outputs_fa.logits_per_image, atol=4e-2, rtol=4e-2), + f"Logits max diff: {torch.max(torch.abs(outputs.logits_per_image - outputs_fa.logits_per_image))}", + ) + self.assertTrue( + torch.allclose(outputs.logits_per_text, outputs_fa.logits_per_text, atol=4e-2, rtol=4e-2), + f"Logits max diff: {torch.max(torch.abs(outputs.logits_per_text - outputs_fa.logits_per_text))}", + ) + + # check with inference + dropout + model.train() + _ = model_fa( + pixel_values=dummy_pixel_values, + input_ids=dummy_input_ids, + attention_mask=dummy_attention_mask, + output_hidden_states=True, + ) + @require_flash_attn @require_torch_gpu 
@mark.flash_attn_test def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest("SigLIP does not support right padding") + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + def test_flash_attn_3_inference_equivalence_right_padding(self): + self.skipTest("SigLIP does not support right padding") + @parameterized.expand([("float16",), ("bfloat16",), ("float32",)]) @require_torch_sdpa @slow diff --git a/tests/models/stablelm/test_modeling_stablelm.py b/tests/models/stablelm/test_modeling_stablelm.py index 36cad89bcfdf06..5437330bc96a59 100644 --- a/tests/models/stablelm/test_modeling_stablelm.py +++ b/tests/models/stablelm/test_modeling_stablelm.py @@ -24,6 +24,7 @@ is_flaky, require_bitsandbytes, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_sdpa, slow, @@ -559,6 +560,24 @@ def test_model_3b_long_prompt(self): generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-3:].tolist()) + @require_bitsandbytes + @slow + @require_flash_attn_3 + @pytest.mark.flash_attn_3_test + def test_model_3b_long_prompt_fav3(self): + EXPECTED_OUTPUT_TOKEN_IDS = [3, 3, 3] + input_ids = [306, 338] * 2047 + model = StableLmForCausalLM.from_pretrained( + "stabilityai/stablelm-3b-4e1t", + device_map="auto", + torch_dtype="auto", + load_in_4bit=True, + attn_implementation="flash_attention_3", + ) + input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device) + generated_ids = model.generate(input_ids, max_new_tokens=4, temperature=0) + self.assertEqual(EXPECTED_OUTPUT_TOKEN_IDS, generated_ids[0][-3:].tolist()) + # Copied from transformers.tests.models.llama.test_modeling_llama.LlamaModelTest.test_eager_matches_sdpa_generate with Llama->StableLm,saibo/llama-1B->stabilityai/stablelm-3b-4e1t # TODO: @Fxmarty @is_flaky(max_attempts=3, description="flaky on some models.") diff --git a/tests/models/starcoder2/test_modeling_starcoder2.py b/tests/models/starcoder2/test_modeling_starcoder2.py index c1c7d45d4f18d7..04a5f657f833b0 100644 --- a/tests/models/starcoder2/test_modeling_starcoder2.py +++ b/tests/models/starcoder2/test_modeling_starcoder2.py @@ -23,6 +23,7 @@ from transformers.testing_utils import ( require_bitsandbytes, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_gpu, slow, @@ -431,6 +432,40 @@ def test_flash_attn_2_generate_padding_right(self): dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_padding_right(self): + import torch + + for model_class in self.all_generative_model_classes: + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0]]).to(torch_device) + + model.generate(dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + with 
self.assertRaises(ValueError): + _ = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -476,6 +511,51 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_use_cache(self): + import torch + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # NOTE: Starcoder2 apparently does not support right padding + use_cache with FA3. + dummy_attention_mask[:, -1] = 1 + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -483,6 +563,13 @@ def test_flash_attn_2_generate_use_cache(self): def test_flash_attn_2_inference_equivalence_right_padding(self): self.skipTest(reason="Starcoder2 flash attention does not support right padding") + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence_right_padding(self): + self.skipTest(reason="Starcoder2 flash attention does not support right padding") + @slow @require_torch_gpu @@ -549,6 +636,28 @@ def test_starcoder2_batched_generation_fa2(self): output_text = tokenizer.batch_decode(output, skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT, output_text) + @require_flash_attn_3 + @pytest.mark.flash_attn_3_test + def test_starcoder2_batched_generation_fa3(self): + EXPECTED_TEXT = [ + "Hello my name is Younes and I am a student at the University of Liverpool. I am currently studying for my MSc in Computer Science. I am interested in the field of Machine Learning and I am currently working on", + "def hello_world():\n\treturn 'Hello World!'\n\n@app.route('/hello/')\ndef hello_name(name):\n\treturn 'Hello %s!' 
% name\n\n@app", + ] + model_id = "bigcode/starcoder2-7b" + + model = Starcoder2ForCausalLM.from_pretrained( + model_id, torch_dtype=torch.float16, device_map="auto", attn_implementation="flash_attention_3" + ) + tokenizer = AutoTokenizer.from_pretrained(model_id) + tokenizer.pad_token = tokenizer.eos_token + + text = ["Hello my name is Younes and", "def hello_world():"] + inputs = tokenizer(text, return_tensors="pt", padding=True).to(torch_device) + + output = model.generate(**inputs, max_new_tokens=40, do_sample=False) + output_text = tokenizer.batch_decode(output, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT, output_text) + @require_bitsandbytes def test_starcoder2_batched_generation_4bit(self): EXPECTED_TEXT = [ diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py index ff7a85218d3a00..1bce68348cc91b 100644 --- a/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -35,6 +35,7 @@ is_pyctcdecode_available, is_torchaudio_available, require_flash_attn, + require_flash_attn_3, require_pyctcdecode, require_soundfile, require_torch, @@ -2023,6 +2024,28 @@ def test_inference_ctc_fa2(self): EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"] self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + def test_inference_ctc_fa3(self): + model_fa = Wav2Vec2ForCTC.from_pretrained( + "facebook/wav2vec2-base-960h", attn_implementation="flash_attention_3", torch_dtype=torch.bfloat16 + ) + model_fa.to(torch_device) + processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True) + input_speech = self._load_datasamples(1) + + input_values = processor(input_speech, return_tensors="pt").input_values.to(torch_device) + + with torch.no_grad(): + logits = model_fa(input_values.to(torch.bfloat16)).logits + + predicted_ids = torch.argmax(logits, dim=-1) + predicted_trans = processor.batch_decode(predicted_ids) + + EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"] + self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -2049,3 +2072,30 @@ def test_inference_ctc_fa2_batched(self): "sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore", ] self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS) + + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + def test_inference_ctc_fa3_batched(self): + model_fa = Wav2Vec2ForCTC.from_pretrained( + "facebook/wav2vec2-base-960h", attn_implementation="flash_attention_3", torch_dtype=torch.bfloat16 + ) + model_fa.to(torch_device) + processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True) + + input_speech = self._load_datasamples(2) + + inputs = processor(input_speech, return_tensors="pt", padding=True, return_attention_mask=True) + inputs = inputs.to(torch_device) + + with torch.no_grad(): + logits = model_fa(inputs.input_values.to(torch.bfloat16), attention_mask=inputs.attention_mask).logits + + predicted_ids = torch.argmax(logits, dim=-1) + predicted_trans = processor.batch_decode(predicted_ids) + + EXPECTED_TRANSCRIPTIONS = [ + "a man said to the universe sir i exist", + "sweat covered brion's body trickling into the tight lowing cloth that was the only garment he wore", + ] + self.assertListEqual(predicted_trans, 
EXPECTED_TRANSCRIPTIONS) diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index 70b38d3bf38170..1b2ccd7505bce9 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -34,6 +34,7 @@ is_flaky, is_pt_flax_cross_test, require_flash_attn, + require_flash_attn_3, require_torch, require_torch_fp16, require_torch_gpu, @@ -974,6 +975,52 @@ def test_flash_attn_2_inference_equivalence(self): model.train() _ = model_fa(dummy_input, decoder_input_ids=decoder_input_ids) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence(self): + import torch + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(reason="Model does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.bfloat16, + ) + model.to(torch_device) + + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] + + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + + logits = outputs.decoder_hidden_states[-1] + logits_fa = outputs_fa.decoder_hidden_states[-1] + + # whisper FA3 needs very high tolerance + assert torch.allclose(logits_fa, logits, atol=4e-1) + + # check with inference + dropout + model.train() + _ = model_fa(dummy_input, decoder_input_ids=decoder_input_ids) + @require_flash_attn @require_torch_gpu @pytest.mark.flash_attn_test @@ -1030,6 +1077,62 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): # whisper FA2 needs very high tolerance assert torch.allclose(logits_fa[:, -2:], logits[:, -2:], atol=4e-1) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_inference_equivalence_right_padding(self): + import torch + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(reason="Model does not support flash_attention_3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16) + model.to(torch_device) + + dummy_input = inputs_dict[model.main_input_name][:1] + dummy_input = dummy_input.to(torch.float16) + + decoder_input_ids = torch.tensor([[0, 1, 2, 3, 4, 5]], device=dummy_input.device, dtype=torch.long) + decoder_attention_mask = torch.tensor( + [[0, 0, 0, 1, 1, 1]], device=dummy_input.device, dtype=torch.long + ) + + outputs = model(dummy_input, 
decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + + logits = outputs.decoder_hidden_states[-1] + logits_fa = outputs_fa.decoder_hidden_states[-1] + + # whisper FA3 needs very high tolerance + assert torch.allclose(logits_fa, logits, atol=4e-1) + + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": decoder_attention_mask, + "output_hidden_states": True, + } + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = outputs.decoder_hidden_states[-1] + logits_fa = outputs_fa.decoder_hidden_states[-1] + + # whisper FA3 needs very high tolerance + assert torch.allclose(logits_fa[:, -2:], logits[:, -2:], atol=4e-1) + def _create_and_check_torchscript(self, config, inputs_dict): if not self.test_torchscript: self.skipTest(reason="test_torchscript is set to False") @@ -1682,6 +1785,59 @@ def test_flash_attn_2_generate_reuse_cache(self): past_key_values=past_key_values, ) + @require_flash_attn_3 + @require_torch_gpu + @pytest.mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_reuse_cache(self): + max_new_tokens = 2 + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name][..., :10] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # run generate once to get filled cache + output = model.generate( + dummy_input, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + return_dict_in_generate=True, + ) + past_key_values = output.past_key_values + + # Try to continue generation from where we left, given that we have more than 1 new token to process + # e.g. 
this can happen in speculative decoding when feeding candidate tokens back to target model + _ = model.generate( + dummy_input, + decoder_input_ids=output.sequences, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + past_key_values=past_key_values, + ) + def test_labels_sequence_max_length_correct(self): config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1ad6e93b10ff84..0077425606efe5 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -76,6 +76,7 @@ require_accelerate, require_bitsandbytes, require_flash_attn, + require_flash_attn_3, require_non_xpu, require_read_token, require_safetensors, @@ -3487,6 +3488,34 @@ def test_flash_attn_2_conversion(self): self.assertTrue(False, "FlashAttention2 modules not found in model") + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + def test_flash_attn_3_conversion(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.float16, attn_implementation="flash_attention_3" + ).to(torch_device) + + for _, module in model.named_modules(): + if "FlashAttention" in module.__class__.__name__: + return + + self.assertTrue(False, "FlashAttention3 modules not found in model") + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -3584,6 +3613,103 @@ def test_flash_attn_2_inference_equivalence(self): model.train() _ = model_fa(dummy_input, **other_inputs) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + @is_flaky() + def test_flash_attn_3_inference_equivalence(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + dummy_attention_mask = dummy_attention_mask[:1] + dummy_attention_mask[:, 1:] = 1 + dummy_attention_mask[:, :1] = 0 + + if model.config.is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] + + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, 
output_hidden_states=True) + else: + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + if model.config.is_encoder_decoder: + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + else: + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[1:], logits[1:], atol=4e-2, rtol=4e-2) + + # check with inference + dropout + model.train() + _ = model_fa(dummy_input, **other_inputs) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -3677,6 +3803,99 @@ def test_flash_attn_2_inference_equivalence_right_padding(self): assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + @is_flaky() + def test_flash_attn_3_inference_equivalence_right_padding(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + for model_class in self.all_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_fa = model_class.from_pretrained( + tmpdirname, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_3" + ) + model_fa.to(torch_device) + + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.bfloat16) + model.to(torch_device) + + dummy_input = inputs_dict[model.main_input_name][:1] + if dummy_input.dtype in [torch.float32, torch.float16]: + dummy_input = dummy_input.to(torch.bfloat16) + + dummy_attention_mask = inputs_dict.get("attention_mask", None) + + if dummy_attention_mask is not None: + dummy_attention_mask = dummy_attention_mask[:1] + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + if model.config.is_encoder_decoder: + decoder_input_ids = inputs_dict.get("decoder_input_ids", dummy_input)[:1] + + outputs = model(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, decoder_input_ids=decoder_input_ids, output_hidden_states=True) + else: + outputs = model(dummy_input, output_hidden_states=True) + outputs_fa = model_fa(dummy_input, output_hidden_states=True) + + logits = ( +
outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa, logits, atol=4e-2, rtol=4e-2) + + if model.config.is_encoder_decoder: + other_inputs = { + "decoder_input_ids": decoder_input_ids, + "decoder_attention_mask": dummy_attention_mask, + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + else: + other_inputs = { + "output_hidden_states": True, + } + if dummy_attention_mask is not None: + other_inputs["attention_mask"] = dummy_attention_mask + + outputs = model(dummy_input, **other_inputs) + outputs_fa = model_fa(dummy_input, **other_inputs) + + logits = ( + outputs.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs.decoder_hidden_states[-1] + ) + logits_fa = ( + outputs_fa.hidden_states[-1] + if not model.config.is_encoder_decoder + else outputs_fa.decoder_hidden_states[-1] + ) + + assert torch.allclose(logits_fa[:-1], logits[:-1], atol=4e-2, rtol=4e-2) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -3725,18 +3944,114 @@ def test_flash_attn_2_generate_left_padding(self): self.assertTrue(torch.allclose(out, out_fa)) + + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + @is_flaky() + def test_flash_attn_3_generate_left_padding(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # make sure we do left padding + dummy_attention_mask[:, :-1] = 0 + dummy_attention_mask[:, -1:] = 1 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @is_flaky() @slow - def test_flash_attn_2_generate_padding_right(self): + def test_flash_attn_2_generate_padding_right(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_2: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = 
self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model = model_class.from_pretrained(tmpdirname, torch_dtype=torch.float16, low_cpu_mem_usage=True).to( + torch_device + ) + + dummy_input = inputs_dict[model.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + # make sure we do right padding + dummy_attention_mask[:, :-1] = 1 + dummy_attention_mask[:, -1:] = 0 + + out = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_2", + low_cpu_mem_usage=True, + ).to(torch_device) + + out_fa = model.generate( + dummy_input, attention_mask=dummy_attention_mask, max_new_tokens=1, do_sample=False + ) + + self.assertTrue(torch.allclose(out, out_fa)) + + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @is_flaky() + @slow + def test_flash_attn_3_generate_padding_right(self): if not self.has_attentions: self.skipTest(reason="Model architecture does not support attentions") for model_class in self.all_generative_model_classes: if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() model = model_class(config) @@ -3763,7 +4078,7 @@ def test_flash_attn_2_generate_padding_right(self): model = model_class.from_pretrained( tmpdirname, torch_dtype=torch.float16, - attn_implementation="flash_attention_2", + attn_implementation="flash_attention_3", low_cpu_mem_usage=True, ).to(torch_device) @@ -4352,6 +4667,65 @@ def test_flash_attn_2_generate_use_cache(self): use_cache=True, ) + + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_use_cache(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # Just test that a large cache works as expected + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, +
) + + # Generate with one batch only to test generation when attention mask will be None + # when real inputs are used, because there is no padding. See issue #32237 for more + dummy_input = dummy_input[:1, ...] + dummy_attention_mask = torch.ones_like(dummy_attention_mask[:1, ...]) + _ = model.generate( + dummy_input, + attention_mask=dummy_attention_mask, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + ) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -4408,6 +4782,62 @@ def test_flash_attn_2_generate_reuse_cache(self): past_key_values=past_key_values, ) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + def test_flash_attn_3_generate_reuse_cache(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + max_new_tokens = 2 + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = dummy_input.shape[1] * 2 + max_new_tokens * 2 + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ).to(torch_device) + + # run generate once to get filled cache + output = model.generate( + dummy_input, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + return_dict_in_generate=True, + ) + past_key_values = output.past_key_values + + # Try to continue generation from where we left, given that we have more than 1 new token to process + # e.g. 
this can happen in speculative decoding when feeding candidate tokens back to target model + dummy_input_updated = torch.cat([dummy_input, output.sequences], dim=-1) + _ = model.generate( + dummy_input_updated, + max_new_tokens=max_new_tokens, + do_sample=False, + use_cache=True, + past_key_values=past_key_values, + ) + @require_flash_attn @require_torch_gpu @require_bitsandbytes @@ -4465,6 +4895,63 @@ def test_flash_attn_2_fp32_ln(self): # with attention mask _ = model(dummy_input, attention_mask=dummy_attention_mask) + @require_flash_attn_3 + @require_torch_gpu + @require_bitsandbytes + @mark.flash_attn_3_test + @slow + def test_flash_attn_3_fp32_ln(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + model = model_class(config) + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + dummy_input = inputs_dict[model.main_input_name] + dummy_attention_mask = inputs_dict.get("attention_mask", torch.ones_like(dummy_input)) + batch_size = dummy_attention_mask.shape[0] + + is_padding_right = dummy_attention_mask[:, -1].sum().item() != batch_size + + # To avoid errors with padding_side=="right" + if is_padding_right: + dummy_attention_mask = torch.ones_like(dummy_input) + + model = model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + load_in_4bit=True, + ) + + for _, param in model.named_parameters(): + # upcast only layer norms + if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16): + param.data = param.data.to(torch.float32) + + if model.config.is_encoder_decoder: + dummy_decoder_input_ids = inputs_dict["decoder_input_ids"] + dummy_decoder_attention_mask = inputs_dict["decoder_attention_mask"] + + _ = model(dummy_input, decoder_input_ids=dummy_decoder_input_ids) + # with attention mask + _ = model( + dummy_input, + attention_mask=dummy_attention_mask, + decoder_input_ids=dummy_decoder_input_ids, + decoder_attention_mask=dummy_decoder_attention_mask, + ) + else: + _ = model(dummy_input) + # with attention mask + _ = model(dummy_input, attention_mask=dummy_attention_mask) + @require_flash_attn @require_torch_gpu @mark.flash_attn_test @@ -4538,6 +5025,79 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): tol = torch.finfo(torch.float16).eps torch.testing.assert_close(logits_padded, logits_padfree, atol=tol, rtol=tol) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + def test_flash_attention_3_padding_matches_padding_free_with_position_ids(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict: + self.skipTest("Model dummy inputs should contain padding in their attention mask") + + dummy_input = inputs_dict[model_class.main_input_name] + if 
dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + # ensure left padding, to adapt for some models + if 0 in inputs_dict["attention_mask"][:, -1]: + inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1) + dummy_attention_mask = inputs_dict["attention_mask"] + inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.pad_token_id + + model = ( + model_class.from_pretrained( + tmpdirname, + torch_dtype=torch.float16, + attn_implementation="flash_attention_3", + low_cpu_mem_usage=True, + ) + .to(torch_device) + .eval() + ) + + # flatten + padfree_inputs_dict = { + k: v[dummy_attention_mask.bool()].unsqueeze(0) + for k, v in inputs_dict.items() + if not k == "attention_mask" + } + # add position_ids + padfree_inputs_dict["position_ids"] = ( + torch.cat([torch.arange(length) for length in dummy_attention_mask.sum(1).tolist()]) + .long() + .unsqueeze(0) + .to(torch_device) + ) + + res_padded = model(**inputs_dict) + res_padfree = model(**padfree_inputs_dict) + + logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()] + logits_padfree = res_padfree.logits[0] + + torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), atol=0, rtol=0) + # acceptable numerical instability + tol = torch.finfo(torch.float16).eps + torch.testing.assert_close(logits_padded, logits_padfree, atol=tol, rtol=tol) + @is_pt_tf_cross_test def test_tf_from_pt_safetensors(self): for model_class in self.all_model_classes: @@ -4633,6 +5193,54 @@ def test_flash_attn_2_from_config(self): self.assertFalse(fa2_correctly_converted) + @require_flash_attn_3 + @require_torch_gpu + @mark.flash_attn_3_test + @slow + def test_flash_attn_3_from_config(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn_3: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 3") + + config, _ = self.model_tester.prepare_config_and_inputs_for_common() + # TODO: to change it in the future with other relevant auto classes + fa3_model = AutoModelForCausalLM.from_config( + config, attn_implementation="flash_attention_3", torch_dtype=torch.bfloat16 + ).to(torch_device) + + dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) + dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device) + + fa3_correctly_converted = False + + for _, module in fa3_model.named_modules(): + if "FlashAttention" in module.__class__.__name__: + fa3_correctly_converted = True + break + + self.assertTrue(fa3_correctly_converted) + + _ = fa3_model(input_ids=dummy_input, attention_mask=dummy_attention_mask) + + with tempfile.TemporaryDirectory() as tmpdirname: + fa3_model.save_pretrained(tmpdirname) + + model_from_pretrained = AutoModelForCausalLM.from_pretrained(tmpdirname) + + self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_3") + + fa3_correctly_converted = False + + for _, module in model_from_pretrained.named_modules(): + if "FlashAttention" in module.__class__.__name__: + fa3_correctly_converted = 
True + break + + self.assertFalse(fa3_correctly_converted) + def _get_custom_4d_mask_test_data(self): # Sequence in which all but the last token is the same input_ids = torch.tensor( diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py index 2130ed4b7c887f..1ab76923d4e756 100644 --- a/tests/utils/test_modeling_utils.py +++ b/tests/utils/test_modeling_utils.py @@ -67,6 +67,7 @@ ) from transformers.utils.import_utils import ( is_flash_attn_2_available, + is_flash_attn_3_available, is_flax_available, is_tf_available, is_torch_sdpa_available, @@ -549,10 +550,14 @@ def test_model_from_pretrained_attn_implementation(self): if is_flash_attn_2_available(): attn_implementation_available.append("flash_attention_2") + if is_flash_attn_3_available(): + attn_implementation_available.append("flash_attention_3") + mistral_attention_classes = { "eager": "MistralAttention", "sdpa": "MistralSdpaAttention", - "flash_attention_2": "MistralFlashAttention2", + "flash_attention_2": "MistralFlashAttention", + "flash_attention_3": "MistralFlashAttention", } for requested_attn_implementation in attn_implementation_available: model = AutoModelForCausalLM.from_pretrained( @@ -588,10 +593,14 @@ def test_model_from_config_attn_implementation(self): if is_flash_attn_2_available(): attn_implementation_available.append("flash_attention_2") + if is_flash_attn_3_available(): + attn_implementation_available.append("flash_attention_3") + mistral_attention_classes = { "eager": "MistralAttention", "sdpa": "MistralSdpaAttention", - "flash_attention_2": "MistralFlashAttention2", + "flash_attention_2": "MistralFlashAttention", + "flash_attention_3": "MistralFlashAttention", } for requested_attn_implementation in attn_implementation_available: config = AutoConfig.from_pretrained(TINY_MISTRAL, attn_implementation=requested_attn_implementation) @@ -2471,6 +2480,14 @@ def test_error_no_flash_available(self): self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception)) + def test_error_no_flash_3_available(self): + with self.assertRaises(ValueError) as cm: + _ = AutoModel.from_pretrained( + "hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="flash_attention_3" + ) + + self.assertTrue("does not support Flash Attention 3.0" in str(cm.exception)) + def test_error_no_flash_available_with_config(self): with self.assertRaises(ValueError) as cm: config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel") @@ -2481,6 +2498,16 @@ def test_error_no_flash_available_with_config(self): self.assertTrue("does not support Flash Attention 2.0" in str(cm.exception)) + def test_error_no_flash_3_available_with_config(self): + with self.assertRaises(ValueError) as cm: + config = AutoConfig.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel") + + _ = AutoModel.from_pretrained( + "hf-tiny-model-private/tiny-random-MCTCTModel", config=config, attn_implementation="flash_attention_3" + ) + + self.assertTrue("does not support Flash Attention 3.0" in str(cm.exception)) + def test_error_wrong_attn_implementation(self): with self.assertRaises(ValueError) as cm: _ = AutoModel.from_pretrained("hf-tiny-model-private/tiny-random-MCTCTModel", attn_implementation="foo") @@ -2497,6 +2524,16 @@ def test_not_available_flash(self): ) self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception)) + def test_not_available_flash_3(self): + if is_flash_attn_3_available(): + self.skipTest(reason="Please uninstall flash_attn_interface package to run 
test_not_available_flash_3") + + with self.assertRaises(ImportError) as cm: + _ = AutoModel.from_pretrained( + "hf-internal-testing/tiny-random-GPTBigCodeModel", attn_implementation="flash_attention_3" + ) + self.assertTrue("the package flash_attn_interface seems to be not installed" in str(cm.exception)) + def test_not_available_flash_with_config(self): if is_flash_attn_2_available(): self.skipTest(reason="Please uninstall flash-attn package to run test_not_available_flash") @@ -2512,6 +2549,23 @@ def test_not_available_flash_with_config(self): self.assertTrue("the package flash_attn seems to be not installed" in str(cm.exception)) + def test_not_available_flash_3_with_config(self): + if is_flash_attn_3_available(): + self.skipTest( + reason="Please uninstall flash_attn_interface package to run test_not_available_flash_3_with_config" + ) + + config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-GPTBigCodeModel") + + with self.assertRaises(ImportError) as cm: + _ = AutoModel.from_pretrained( + "hf-internal-testing/tiny-random-GPTBigCodeModel", + config=config, + attn_implementation="flash_attention_3", + ) + + self.assertTrue("the package flash_attn_interface seems to be not installed" in str(cm.exception)) + def test_not_available_sdpa(self): if is_torch_sdpa_available(): self.skipTest(reason="This test requires torch<=2.0")