diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 36775d8454ab8c..ef735b95c5ef99 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1332,6 +1332,7 @@ "convert_and_export_with_cache", ] + _import_structure["modeling_flash_attention_3_utils"] = [] _import_structure["modeling_flash_attention_utils"] = [] _import_structure["modeling_outputs"] = [] _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"] diff --git a/src/transformers/modeling_flash_attention_3_utils.py b/src/transformers/modeling_flash_attention_3_utils.py new file mode 100644 index 00000000000000..0a3faf784a06af --- /dev/null +++ b/src/transformers/modeling_flash_attention_3_utils.py @@ -0,0 +1,295 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + +from .utils import is_flash_attn_3_available + + +if is_flash_attn_3_available(): + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + from flash_attn_interface import _flash_attn_forward, flash_attn_func, flash_attn_varlen_func + + +def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. 
Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def prepare_fa2_from_position_ids(query, key, value, position_ids): + """ + This function returns necessary arguments to call `flash_attn_varlen_func`. + All three query, key, value states will be flattened. + Cummulative lengths of each examples in the batch will be extracted from position_ids. + + NOTE: ideally cummulative lengths should be prepared at the data collator stage + + Arguments: + query (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + position_ids (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + query (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). 
+ value (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.view(-1, key.size(-2), key.size(-1)) + value = value.view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + + cu_seq_lens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) + + max_length = position_ids.max() + 1 + + return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + + +def _flash_attention_3_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + is_causal: bool, + position_ids: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + use_top_left_mask: bool = False, + deterministic: bool = None, + descale: float = 1.0, + use_fp8: bool = False, +): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_top_left_mask (`bool`, defaults to `False`): + flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. + deterministic (`bool`, *optional*): + Determines if the deterministic option enabled. 
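# A minimal, self-contained sketch (made-up shapes and mask) of the varlen metadata the two
# ragged branches of this forward rely on. The padded branch reproduces what _get_unpad_data
# computes from a 2D padding mask; the packed branch mirrors how prepare_fa2_from_position_ids
# turns position_ids into cumulative sequence lengths.
import torch
import torch.nn.functional as F

# Padding-mask case: batch of 2, right padding on the first example.
attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])
seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)                      # tensor([2, 3])
indices = torch.nonzero(attention_mask.flatten()).flatten()                  # tensor([0, 1, 3, 4, 5])
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))  # tensor([0, 2, 5])
max_seqlen = seqlens.max().item()                                            # 3

# Packed case: two sequences of lengths 3 and 2 packed into a single row.
position_ids = torch.tensor([[0, 1, 2, 0, 1]]).flatten()
starts = torch.arange(position_ids.size(0), dtype=torch.int32)[position_ids == 0]
cu_seqlens_packed = torch.cat((starts, torch.tensor(position_ids.size(), dtype=torch.int32)))
# cu_seqlens_packed == tensor([0, 3, 5]); the longest packed sequence has length 3
max_seqlen_packed = int(position_ids.max()) + 1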
+ """ + use_fp8 = os.environ.get("FLASH_ATTENTION_3_FP8", "0") == "1" + + softmax_scale = softmax_scale or query_states.shape[-1] ** (-0.5) + + if not use_top_left_mask: + causal = is_causal + else: + causal = is_causal and query_length != 1 + + flash_kwargs = {} + + if deterministic is None: + deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + flash_kwargs["deterministic"] = deterministic + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad, _ = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + + # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing + # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. + # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach + elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1: + batch_size = query_states.size(0) + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output, _ = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + + else: + if use_fp8: + # NOTE: descale? 
+ attn_output = _flash_attn_forward( + query_states.to(torch.float8_e4m3fn), + key_states.to(torch.float8_e4m3fn), + value_states.to(torch.float8_e4m3fn), + softmax_scale=softmax_scale, + causal=causal, + )[0] + else: + attn_output, _ = flash_attn_func( + query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs + ) + + return attn_output diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d4069766636053..314e3ee22cafdf 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -82,6 +82,7 @@ is_accelerate_available, is_bitsandbytes_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_offline_mode, is_optimum_available, is_peft_available, @@ -1367,6 +1368,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Flash Attention 2 support _supports_flash_attn_2 = False + # Flash Attention 3 support + _supports_flash_attn_3 = False + # SDPA support _supports_sdpa = False @@ -1550,10 +1554,12 @@ def _autoset_attn_implementation( ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' ) - if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]: + if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2", "flash_attention_3"]: message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' if cls._supports_flash_attn_2: message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' + if cls._supports_flash_attn_3: + message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)' if cls._supports_sdpa: message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' raise ValueError(message + ".") @@ -1575,6 +1581,14 @@ def _autoset_attn_implementation( hard_check_only=False, check_device_map=check_device_map, ) + elif config._attn_implementation == "flash_attention_3": + cls._check_and_enable_flash_attn_3( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available(): # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. config = cls._check_and_enable_sdpa( @@ -1734,6 +1748,83 @@ def _check_and_enable_flash_attn_2( config._attn_implementation = "flash_attention_2" return config + @classmethod + def _check_and_enable_flash_attn_3( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + hard_check_only: bool = False, + ) -> PretrainedConfig: + """ + Checks the availability of Flash Attention 3 and compatibility with the current model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module. + """ + if not cls._supports_flash_attn_3: + raise ValueError( + f"{cls.__name__} does not support Flash Attention 3.0 yet. 
Please request to add support where" + f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new" + " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" + ) + + if not is_flash_attn_3_available(): + preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:" + # TODO: docs + install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-3 to install Flash Attention 3." + + if importlib.util.find_spec("flash_attn_interface") is None: + raise ImportError( + f"{preface} the package flash_attn_interface seems to be not installed. {install_message}" + ) + + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + + if _is_bettertransformer: + raise ValueError( + "Flash Attention 3 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()" + ) + + if torch_dtype is None: + logger.warning_once( + "You are attempting to use Flash Attention 3.0 without specifying a torch dtype. This might lead to unexpected behaviour" + ) + elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]: + logger.warning_once( + "Flash Attention 3.0 only supports torch.float16 and torch.bfloat16 dtypes, but" + f" the current dype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator," + ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`' + ) + + # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, + # or the model may be initialized under the context manager `with torch.device("cuda"):`. + if check_device_map and device_map is None and torch.empty(0).device.type != "cuda": + if torch.cuda.is_available(): + logger.warning_once( + "You are attempting to use Flash Attention 3.0 with a model not initialized on GPU. Make sure to move the model to GPU" + " after initializing it on CPU with `model.to('cuda')`." + ) + else: + raise ValueError( + "You are attempting to use Flash Attention 3.0 with a model not initialized on GPU and with no GPU available. " + "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " + "or initialising the model on CPU and then moving it to GPU." + ) + elif ( + check_device_map + and device_map is not None + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + raise ValueError( + "You are attempting to use Flash Attention 3.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to " + "initialise the model on a GPU by passing a device_map that contains only GPU devices as keys." 
+ ) + if not hard_check_only: + config._attn_implementation = "flash_attention_3" + return config + @classmethod def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: """ diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index d32ab95f51c7da..bdc86c34fb6bec 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -790,7 +790,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 23334311ca9511..3597f1b00d5dac 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -538,6 +539,112 @@ def forward( return attn_output, attn_weights, past_key_value +class ChameleonFlashAttention3(ChameleonAttention): + """ + Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
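# End-to-end, the opt-in mirrors the existing flash_attention_2 flow: pass
# attn_implementation="flash_attention_3" together with a half-precision dtype, and the checks
# added to modeling_utils above route the model to attention classes like the one defined here.
# A minimal sketch with a placeholder checkpoint name:
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-causal-lm",               # placeholder; any FA3-enabled checkpoint
    attn_implementation="flash_attention_3",
    torch_dtype=torch.bfloat16,
).to("cuda")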
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + query_states = self.q_norm(query_states) + + key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) + key_states = self.k_norm(key_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. + # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (ChameleonRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class ChameleonSdpaAttention(ChameleonAttention): """ Chameleon attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -634,6 +741,7 @@ def forward( CHAMELEON_ATTENTION_CLASSES = { "eager": ChameleonAttention, "flash_attention_2": ChameleonFlashAttention2, + "flash_attention_3": ChameleonFlashAttention3, "sdpa": ChameleonSdpaAttention, } @@ -1155,6 +1263,7 @@ class ChameleonPreTrainedModel(PreTrainedModel): _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"] _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_quantized_cache = True _supports_cache_class = True @@ -1433,7 +1542,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 46eea43e1285f8..2b3a953601a682 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -633,7 +633,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index ae84a9ec2d1a43..01781cf78caa76 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -44,6 +44,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -55,6 +56,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -539,6 +543,118 @@ def forward( return attn_output, attn_weights, past_key_value +class CohereFlashAttention3(CohereAttention): + """ + Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + if self.use_qk_norm: + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(CohereLayerNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class CohereSdpaAttention(CohereAttention): """ Cohere attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -645,6 +761,7 @@ def forward( COHERE_ATTENTION_CLASSES = { "eager": CohereAttention, "flash_attention_2": CohereFlashAttention2, + "flash_attention_3": CohereFlashAttention3, "sdpa": CohereSdpaAttention, } @@ -751,6 +868,7 @@ class CoherePreTrainedModel(PreTrainedModel): _no_split_modules = ["CohereDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1005,7 +1123,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 43bac44ba1be20..4bffc919bf3644 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -30,6 +30,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -40,6 +41,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DbrxConfig" @@ -484,6 +488,117 @@ def forward( return attn_output, attn_weights, past_key_value +class DbrxFlashAttention3(DbrxAttention): + """Dbrx flash attention module. + + This module inherits from `DbrxAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it + calls the public API of flash attention. 
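# The repeated transpose pairs in these forward passes exist because the KV cache keeps tensors
# as (batch, num_heads, seq_len, head_dim) while the flash attention kernels expect
# (batch, seq_len, num_heads, head_dim). A tiny shape check of that conversion, with made-up sizes:
import torch

bsz, n_heads, seq_len, head_dim = 2, 8, 16, 64
cache_layout = torch.randn(bsz, n_heads, seq_len, head_dim)   # layout used for RoPE and the cache
flash_layout = cache_layout.transpose(1, 2)                   # layout handed to flash attention
assert flash_layout.shape == (bsz, seq_len, n_heads, head_dim)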
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Any, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.") + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.Wqkv(hidden_states) + if self.clip_qkv is not None: + qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv) + + query_states, key_states, value_states = qkv_states.split( + [ + self.hidden_size, + self.num_key_value_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ], + dim=2, + ) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = query_states.dtype + + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be " + + "related to the fact you have upcasted embedding or layer norm layers in " + + f"float32. 
We will cast back the input in {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class DbrxSdpaAttention(DbrxAttention): """ Dbrx attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -582,6 +697,7 @@ def forward( DBRX_ATTENTION_CLASSES = { "eager": DbrxAttention, "flash_attention_2": DbrxFlashAttention2, + "flash_attention_3": DbrxFlashAttention3, "sdpa": DbrxSdpaAttention, } @@ -882,6 +998,7 @@ class DbrxPreTrainedModel(PreTrainedModel): _no_split_modules = ["DbrxBlock"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1164,7 +1281,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 73e8806352ebbd..174bcde1086f65 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -43,6 +43,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, ) @@ -55,6 +56,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -688,6 +692,110 @@ def forward( return attn_output, layer_past, attn_weights +class FalconFlashAttention3(FalconAttention): + """ + Falcon flash attention module. This module inherits from `FalconAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
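# The dtype guard repeated in each of these forward passes only recovers a half-precision
# compute dtype when layer norms were upcast to float32 (e.g. for PEFT training); a condensed
# sketch of the selection order it follows, with the fallback weight dtype as an argument:
import torch

def pick_target_dtype(config, fallback_weight_dtype):
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    if hasattr(config, "_pre_quantization_dtype"):   # quantized models remember their original dtype
        return config._pre_quantization_dtype
    return fallback_weight_dtype                     # e.g. the dtype of a projection weight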
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Cache] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) + value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) + + if alibi is None: + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_layer, position_ids) + else: + cos, sin = position_embeddings + query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin) + + if layer_past is not None: + cache_kwargs = {"cache_position": cache_position} + if alibi is None: + cache_kwargs.update({"sin": sin, "cos": cos}) + key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_layer = query_layer.transpose(1, 2) + key_layer = key_layer.transpose(1, 2) + value_layer = value_layer.transpose(1, 2) + + if alibi is not None: + raise ValueError("`alibi` is not supported when `use_flash_attn` is True") + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_layer.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.query_key_value.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_layer = query_layer.to(target_dtype) + key_layer = key_layer.to(target_dtype) + value_layer = value_layer.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_layer, + key_layer, + value_layer, + attention_mask, + query_length, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + attn_output = self.dense(attn_weights) + + if not output_attentions: + attn_weights = None + + return attn_output, layer_past, attn_weights + + class FalconMLP(nn.Module): def __init__(self, config: FalconConfig): super().__init__() @@ -708,6 +816,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: "eager": FalconAttention, "sdpa": FalconAttention, # FalconAttention originally implemented both a forward with & without SDPA "flash_attention_2": FalconFlashAttention2, + "flash_attention_3": FalconFlashAttention3, } @@ -909,6 +1018,7 @@ class FalconPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["FalconDecoderLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1157,7 +1267,10 @@ def _update_causal_mask( # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gemma/diff_gemma.py b/src/transformers/models/gemma/diff_gemma.py index fcdb0f0b3d7d8f..53b656282eb46a 100644 --- a/src/transformers/models/gemma/diff_gemma.py +++ b/src/transformers/models/gemma/diff_gemma.py @@ -34,6 +34,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import CausalLMOutputWithPast from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -441,6 +442,103 @@ def forward( return attn_output, attn_weights, past_key_value +# TODO felix: does this inheritance really work out in the end to GemmaFlashAttention2 inheriting form GemmaAttention? +class GemmaFlashAttention3(LlamaFlashAttention2): + """ + Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
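# The wiring each model repeats in this diff, compressed into one hypothetical example (all
# names below are placeholders, not a real transformers model): subclass the eager attention,
# register it under "flash_attention_3", and advertise support on the pretrained base class.
import torch.nn as nn

class MyAttention(nn.Module):             # stand-in for a model's eager attention
    pass

class MyFlashAttention3(MyAttention):     # same weights; only forward() would change
    pass

MY_ATTENTION_CLASSES = {
    "eager": MyAttention,
    "flash_attention_3": MyFlashAttention3,
}

# ...and on the corresponding PreTrainedModel subclass:
#     _supports_flash_attn_3 = True   # lets _check_and_enable_flash_attn_3 accept the request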
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (GemmaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class GemmaModel(LlamaModel): def forward( self, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index b14e0a4b3d8ca5..e13b9cfdd9325b 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -467,6 +468,106 @@ def forward( return attn_output, attn_weights, past_key_value +class GemmaFlashAttention3(GemmaAttention): + """ + Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These 
transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (GemmaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class GemmaSdpaAttention(GemmaAttention): """ Gemma attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from @@ -558,6 +659,7 @@ def forward( GEMMA_ATTENTION_CLASSES = { "eager": GemmaAttention, "flash_attention_2": GemmaFlashAttention2, + "flash_attention_3": GemmaFlashAttention3, "sdpa": GemmaSdpaAttention, } @@ -665,6 +767,7 @@ class GemmaPreTrainedModel(PreTrainedModel): _no_split_modules = ["GemmaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -932,7 +1035,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gemma2/diff_gemma2.py b/src/transformers/models/gemma2/diff_gemma2.py index 30f371a1b61267..14ac971f6be21e 100644 --- a/src/transformers/models/gemma2/diff_gemma2.py +++ b/src/transformers/models/gemma2/diff_gemma2.py @@ -439,7 +439,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 1909ef78501559..e5a0c8a5853d0d 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -568,7 +568,10 @@ def forward( ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding # Flash-attn is a 2D tensor - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if past_key_value is not None: # when decoding attention_mask = attention_mask[:, -self.sliding_window :] else: @@ -919,7 +922,10 @@ def _update_causal_mask( # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible # as it doesn't cause dynamic control issues. 
- if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): return attention_mask dtype, device = input_tensor.dtype, input_tensor.device diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 72590862b749f0..571d15ae4f3cd1 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -41,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torch_fx_available, logging, @@ -51,6 +52,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. @@ -426,9 +429,102 @@ def forward( return outputs +class GPTNeoFlashAttention3(GPTNeoSelfAttention): + """ + GPTNeo flash attention module. This module inherits from `GPTNeoSelfAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states, + attention_mask=None, + layer_past=None, + head_mask=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + bsz, _, _ = hidden_states.size() + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + cache_kwargs = {"cache_position": cache_position} + key, value = layer_past.update(key, value, self.layer_id, cache_kwargs) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + attention_mask = attention_mask[:, :, :, : key.shape[-2]] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_length, + softmax_scale=1.0, + is_causal=self.is_causal, + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, layer_past) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + GPT_NEO_ATTENTION_CLASSES = { "eager": GPTNeoSelfAttention, "flash_attention_2": GPTNeoFlashAttention2, + "flash_attention_3": GPTNeoFlashAttention3, } @@ -550,6 +646,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTNeoBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = False # TODO: needs a HybridCache @@ -847,7 +944,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index e88302efa7bb04..d935ac65a3a722 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -43,6 +43,7 @@ from ...utils import ( get_torch_version, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, ) @@ -52,6 +53,10 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM" @@ -125,6 +130,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTNeoXLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True @@ -459,6 +465,100 @@ def forward( return outputs +class GPTNeoXFlashAttention3(GPTNeoXAttention): + """ + GPTNeoX flash attention module. This module inherits from `GPTNeoXAttention` as the weights of the module stays + untouched. 
The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + position_ids: torch.LongTensor, + head_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + # Apply attention-specific projections and rope + query, key, value, present = self._attn_projections_and_rope( + hidden_states=hidden_states, + position_ids=position_ids, + layer_past=layer_past, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + query_length = query.shape[-2] + + # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision + target_dtype = value.dtype + if query.dtype != target_dtype: + query = query.to(target_dtype) + if key.dtype != target_dtype: + key = key.to(target_dtype) + + # Permute to get the expected shape for Flash Attention + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 / bfloat16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + input_dtype = query.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.query_key_value.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + # Compute attention + attn_weights = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_length, + softmax_scale=self.norm_factor, + is_causal=self.is_causal, + ) + + # Reshape outputs + attn_output = attn_weights.reshape( + attn_weights.shape[0], attn_weights.shape[1], self.num_attention_heads * self.head_size + ) + attn_output = self.dense(attn_output) + + outputs = (attn_output, layer_past) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + class GPTNeoXSdpaAttention(GPTNeoXAttention): """ GPTNeoX attention module using torch.nn.functional.scaled_dot_product_attention. 
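
The FA3 forward passes above repeatedly move tensors from the KV-cache layout `[batch, num_heads, seq, head_dim]` to the `[batch, seq, num_heads, head_dim]` layout the flash kernels expect, then merge heads back before the output projection. A minimal, dependency-free sketch of that round trip (sizes are illustrative, not from any model config):

```python
import torch

# Illustrative sizes only; not taken from any model config.
bsz, num_heads, seq_len, head_dim = 2, 8, 16, 64

# KV-cache layout used by the attention modules in this diff.
states_bhsd = torch.randn(bsz, num_heads, seq_len, head_dim)

# Flash-attention kernels expect [batch, seq, heads, head_dim].
states_bshd = states_bhsd.transpose(1, 2)
assert states_bshd.shape == (bsz, seq_len, num_heads, head_dim)

# Heads are merged again after attention, as in `attn_output.reshape(bsz, q_len, -1)`.
merged = states_bshd.reshape(bsz, seq_len, num_heads * head_dim)
assert merged.shape == (bsz, seq_len, num_heads * head_dim)
```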
This module inherits from @@ -727,6 +827,7 @@ def forward(self, hidden_states): GPT_NEOX_ATTENTION_CLASSES = { "eager": GPTNeoXAttention, "flash_attention_2": GPTNeoXFlashAttention2, + "flash_attention_3": GPTNeoXFlashAttention3, "sdpa": GPTNeoXSdpaAttention, } @@ -1044,7 +1145,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index bf832195b4efc3..801a6f0d778141 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -748,7 +748,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index bd7ce5696fa077..1a2f71029675f2 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -38,6 +38,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torch_fx_proxy, logging, @@ -50,6 +51,10 @@ from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "hf-internal-testing/tiny-random-gptj" @@ -457,9 +462,145 @@ def forward( return outputs +class GPTJFlashAttention3(GPTJAttention): + """ + GPTJ flash attention module. This module inherits from `GPTJAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
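
The new classes forward different `softmax_scale` values to the kernel (`1.0` for GPT-Neo, which does not scale attention, a model-specific factor elsewhere). As a hedged reference only: the sketch below shows what that scale means in an eager computation; flash-attn style kernels typically fall back to `1/sqrt(head_dim)` when the argument is omitted. The helper name is hypothetical.

```python
import math
import torch

def eager_attention(q, k, v, softmax_scale=None):
    # Hypothetical reference helper: `softmax_scale` multiplies q @ k^T before
    # the softmax, mirroring the argument forwarded to the flash kernels above.
    # When unset, flash-attn style kernels typically default to 1/sqrt(head_dim).
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
    return torch.matmul(scores.softmax(dim=-1), v)

q = k = v = torch.randn(1, 4, 8, 32)  # [batch, heads, seq, head_dim], illustrative
out_unscaled = eager_attention(q, k, v, softmax_scale=1.0)  # GPT-Neo style
out_default = eager_attention(q, k, v)                      # 1/sqrt(head_dim)
```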
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.FloatTensor, + layer_past: Optional[Cache] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[ + Tuple[torch.Tensor, Tuple[torch.Tensor]], + Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], + ]: + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, self.num_attention_heads, self.head_dim, True) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, True) + value = self._split_heads(value, self.num_attention_heads, self.head_dim, False) + + if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing(): + # The logic to conditionally copy to GPU could not be traced, so we do this + # every time in the torch.fx case + embed_positions = get_embed_positions(self.embed_positions, position_ids) + else: + embed_positions = self._get_embed_positions(position_ids) + + repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) + sincos = torch.gather(embed_positions, 1, repeated_position_ids) + sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + k_rot = apply_rotary_pos_emb(k_rot, sin, cos) + q_rot = apply_rotary_pos_emb(q_rot, sin, cos) + + key = torch.cat([k_rot, k_pass], dim=-1) + query = torch.cat([q_rot, q_pass], dim=-1) + else: + key = apply_rotary_pos_emb(key, sin, cos) + query = apply_rotary_pos_emb(query, sin, cos) + + # tanspose to have the desired shape + # before transpose: batch_size x seq_length x num_attention_heads x head_dim + # after transpose: batch_size x num_attention_heads x seq_length x head_dim + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + # value: batch_size x num_attention_heads x seq_length x head_dim + + if layer_past is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_dim, + "cache_position": cache_position, + } + key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs) + + # The Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we need to keep the original shape for query and key, and reshape value + # to have the correct shape. + key = key.permute(0, 2, 1, 3).contiguous() + query = query.permute(0, 2, 1, 3).contiguous() + value = value.permute(0, 2, 1, 3).contiguous() + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + query_length = query.shape[1] + + # Compute attention + attn_weights = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_length, + is_causal=self.is_causal, + ) + + # Reshape outputs + attn_output = attn_weights.reshape( + attn_weights.shape[0], attn_weights.shape[1], attn_weights.shape[2] * attn_weights.shape[3] + ) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, layer_past) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + GPTJ_ATTENTION_CLASSES = { "eager": GPTJAttention, "flash_attention_2": GPTJFlashAttention2, + "flash_attention_3": GPTJFlashAttention3, } @@ -540,6 +681,7 @@ class GPTJPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTJBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True @@ -942,7 +1084,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 876f5ed2a7c8da..d8929470a912f6 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -23,6 +23,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -453,6 +454,102 @@ def forward( return attn_output, attn_weights, past_key_value +class GraniteFlashAttention3(GraniteAttention): + """ + Granite flash attention module. This module inherits from `GraniteAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
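
With the registrations above (`*_ATTENTION_CLASSES["flash_attention_3"]` and `_supports_flash_attn_3 = True`), the new backend would be selected the same way `flash_attention_2` is, through `attn_implementation`. A hedged usage sketch; the checkpoint name is illustrative and it assumes an environment where `is_flash_attn_3_available()` returns `True`:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint; assumes the Flash Attention 3 kernels are installed
# so that `is_flash_attn_3_available()` returns True on this machine.
model_id = "meta-llama/Llama-3.1-8B"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,  # fp32 inputs would trigger the cast-back warning above
    attn_implementation="flash_attention_3",
    device_map="auto",
)

inputs = tokenizer("Flash Attention 3 test:", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))
```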
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (GraniteRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + softmax_scale=self.scaling, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class GraniteSdpaAttention(GraniteAttention): """ Granite attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -547,6 +644,7 @@ def forward( GRANITE_ATTENTION_CLASSES = { "eager": GraniteAttention, "flash_attention_2": GraniteFlashAttention2, + "flash_attention_3": GraniteFlashAttention3, "sdpa": GraniteSdpaAttention, } @@ -662,6 +760,7 @@ class GranitePreTrainedModel(PreTrainedModel): _no_split_modules = ["GraniteDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -934,7 +1033,10 @@ def _update_causal_mask( # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 0d26c4fc65de78..714dfbc29053ac 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1475,7 +1475,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 60e1670a3c2784..0725d4f058570a 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -48,6 +48,7 @@ from ...utils.import_utils import ( is_causal_conv1d_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_mamba_ssm_available, is_torchdynamo_compiling, @@ -58,6 +59,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + if is_mamba_ssm_available(): from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn @@ -498,6 +502,119 @@ def forward( return attn_output, attn_weights, past_key_value +class JambaFlashAttention3(JambaAttention): + """ + Jamba flash attention module. 
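
Every FA3 forward repeats the same guard: if the (possibly PEFT-upcast) hidden states arrive in float32, they are cast back to a kernel-friendly dtype resolved from autocast, a pre-quantization hint, or a projection weight. A standalone sketch of that resolution order, with a stand-in config object:

```python
import torch

def resolve_target_dtype(module_weight: torch.Tensor, config) -> torch.dtype:
    # Mirrors the cast-back logic in the forward passes above:
    # 1) autocast dtype if autocast is active,
    # 2) the dtype recorded before quantization, if present,
    # 3) otherwise the dtype of a projection weight.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()
    if hasattr(config, "_pre_quantization_dtype"):
        return config._pre_quantization_dtype
    return module_weight.dtype

class DummyConfig:  # illustrative stand-in for a model config
    pass

weight = torch.empty(8, 8, dtype=torch.bfloat16)
hidden = torch.randn(2, 4, 8, dtype=torch.float32)  # e.g. upcast by a PEFT LayerNorm

target = resolve_target_dtype(weight, DummyConfig())
hidden = hidden.to(target) if hidden.dtype == torch.float32 else hidden
print(hidden.dtype)  # torch.bfloat16
```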
This module inherits from `JambaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = cache_position[-1] + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = cache_position[0] > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba class JambaSdpaAttention(JambaAttention): """ @@ -584,6 +701,7 @@ def forward( JAMBA_ATTENTION_CLASSES = { "eager": JambaAttention, "flash_attention_2": JambaFlashAttention2, + "flash_attention_3": JambaFlashAttention3, "sdpa": JambaSdpaAttention, } @@ -1121,6 +1239,7 @@ class JambaPreTrainedModel(PreTrainedModel): _no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True @@ -1377,7 +1496,10 @@ def forward( ) def _update_causal_mask(self, attention_mask, input_tensor, cache_position): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index c273b021d73664..6793d7ec664161 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -36,6 +36,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -47,6 +48,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "jetmoe" @@ -804,9 +808,116 @@ def forward( return attn_output, attn_weights, past_key_value, router_logits +class JetMoeFlashAttention3(JetMoeAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[ + Tuple[torch.Tensor, Tuple[torch.Tensor]], + Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], + ]: + """ + Forward pass of the JetMoeAttention module. + + Args: + hidden_states (Optional[torch.FloatTensor]): Input hidden states. + attention_mask (Optional[torch.FloatTensor]): Attention mask. + layer_past (Optional[Tuple[torch.Tensor]]): Past layer state. 
+ use_cache (Optional[bool]): Whether to use cached states. + output_attentions (Optional[bool]): Whether to output attention weights. + cache_position (Optional[torch.LongTensor]): Position of the cache. + + Returns: + Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[...]]]: Tuple containing outputs. + """ + output_attentions = False + bsz, q_len, hidden_size = hidden_states.size() + + # calculate query, key, values + query_states, router_logits, topo_info = self.experts.map(hidden_states) + key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads for top-k attention experts + key_states = key_states.repeat(1, self.top_k, 1, 1) + value_states = value_states.repeat(1, self.top_k, 1, 1) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.kv_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ).to(input_dtype) + + # output projection + attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size) + attn_output = self.experts.reduce(attn_output, topo_info) + attn_output = attn_output.view(bsz, q_len, hidden_size) # re-assemble all head outputs side by side + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value, router_logits + + JETMOE_ATTENTION_CLASSES = { "eager": JetMoeAttention, "flash_attention_2": JetMoeFlashAttention2, + "flash_attention_3": JetMoeFlashAttention3, "sdpa": JetMoeSdpaAttention, } @@ -879,6 +990,7 @@ class JetMoePreTrainedModel(PreTrainedModel): _no_split_modules = ["JetMoeBlock"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True @@ -1136,7 +1248,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c7017832b9324c..24b02ec4ea7697 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -569,6 +570,117 @@ def forward( return attn_output, attn_weights, past_key_value +class LlamaFlashAttention3(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
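
The recurring `_update_causal_mask` change applies the same rule to both flash-attention backends: keep the 2D padding mask only when it actually contains padding, otherwise return `None` and let the kernel handle causality. A minimal sketch of that rule:

```python
from typing import Optional

import torch

def flash_attention_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    # Same rule as the `_update_causal_mask` hunks above: forward the 2D mask
    # only if it contains at least one 0 (i.e. real padding), otherwise return
    # None so the flash kernel applies causal masking internally.
    if attention_mask is not None and 0.0 in attention_mask:
        return attention_mask
    return None

fully_valid = torch.ones(2, 5)
with_padding = torch.tensor([[1.0, 1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0, 1.0]])

print(flash_attention_mask(fully_valid))   # None -> no mask is forwarded
print(flash_attention_mask(with_padding))  # the 2D mask is forwarded as-is
```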
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # TODO: get `use_fp8` to here, add attention_kwargs or something + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class LlamaSdpaAttention(LlamaAttention): """ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -671,6 +783,7 @@ def forward( LLAMA_ATTENTION_CLASSES = { "eager": LlamaAttention, "flash_attention_2": LlamaFlashAttention2, + "flash_attention_3": LlamaFlashAttention3, "sdpa": LlamaSdpaAttention, } @@ -783,6 +896,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1039,7 +1153,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 9455f21b2073ff..7597ea1ad693f8 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -46,6 +46,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -56,6 +57,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -410,6 +413,125 @@ def forward( return attn_output, attn_weights, past_key_value +class MBartFlashAttention3(MBartAttention): + """ + MBart flash attention module. This module inherits from `MBartAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
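
Like the FA2 path, `LlamaFlashAttention3.forward` rejects a `StaticCache`, so generation should stay on the default dynamic cache (or fall back to `sdpa` when a static cache is required). A small sketch of that guard, assuming the public `transformers.cache_utils` classes:

```python
from transformers.cache_utils import DynamicCache, StaticCache

def check_cache_compat(past_key_value) -> None:
    # Same guard as in LlamaFlashAttention3.forward above: static caches are
    # rejected, so use the default dynamic cache or the sdpa implementation
    # when a static cache is needed (e.g. for torch.compile).
    if isinstance(past_key_value, StaticCache):
        raise ValueError(
            "`static` cache implementation is not compatible with flash_attention_3"
        )

check_cache_compat(DynamicCache())  # the default cache passes the check
```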
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # MBartFlashAttention3 attention does not support output_attentions + if output_attentions: + raise ValueError("MBartFlashAttention3 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MBart class MBartSdpaAttention(MBartAttention): def forward( @@ -521,6 +643,7 @@ def forward( "eager": MBartAttention, "sdpa": MBartSdpaAttention, "flash_attention_2": MBartFlashAttention2, + "flash_attention_3": MBartFlashAttention3, } @@ -745,6 +868,7 @@ class MBartPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MBartDecoderLayer", "MBartAttention"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): @@ -1043,7 +1167,10 @@ def forward( # expand attention_mask if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): attention_mask = attention_mask if 0 in attention_mask else None elif self.config._attn_implementation == "sdpa" and head_mask is None and not output_attentions: # output_attentions=True & head_mask can not be supported when using SDPA, fall back to @@ -1260,7 +1387,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self.config._attn_implementation == "sdpa" and not output_attentions and cross_attn_head_mask is None: @@ -1280,7 +1410,10 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self.config._attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on diff --git a/src/transformers/models/mixtral/modeling_mixtral.py 
b/src/transformers/models/mixtral/modeling_mixtral.py index c7062e75b1085c..167054fbe61bd6 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -43,6 +43,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_torchdynamo_compiling, logging, replace_return_docstrings, @@ -54,6 +55,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. if is_torch_fx_available(): @@ -561,6 +565,133 @@ def forward( return attn_output, attn_weights, past_key_value +# TODO @longjie no longer copied from Mistral after static cache +class MixtralFlashAttention3(MixtralAttention): + """ + Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. 
+ rotary_seq_len = ( + max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len + ) + + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral # TODO @longjie no longer copied from Mistral after static cache class MixtralSdpaAttention(MixtralAttention): @@ -656,6 +787,7 @@ def forward( MIXTRAL_ATTENTION_CLASSES = { "eager": MixtralAttention, "flash_attention_2": MixtralFlashAttention2, + "flash_attention_3": MixtralFlashAttention3, "sdpa": MixtralSdpaAttention, } @@ -857,6 +989,7 @@ class MixtralPreTrainedModel(PreTrainedModel): _no_split_modules = ["MixtralDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True @@ -1122,7 +1255,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 4d079b4dde104d..c0b9fe68e64cda 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -468,6 +469,108 @@ def forward( return attn_output, attn_weights, past_key_value +class NemotronFlashAttention3(NemotronAttention): + """ + Nemotron flash attention module. This module inherits from `NemotronAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if isinstance(past_key_value, StaticCache):
+            raise ValueError(
+                "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3`; "
+                "make sure to use `sdpa` in the meantime, and open an issue at https://github.com/huggingface/transformers"
+            )
+
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # Flash attention requires the input to have the shape
+        # batch_size x seq_length x num_heads x head_dim
+        # therefore we just need to keep the original shape
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        if position_embeddings is not None:
+            cos, sin = position_embeddings
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
+        # to be able to avoid many of these transpose/reshape/view.
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        # In PEFT, we usually cast the layer norms to float32 for training stability,
+        # which silently upcasts the input hidden states to float32 as well. Hence, we need to
+        # cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+        # to fp32. (NemotronRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to be silently cast to float32; this might be related to"
+                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast the input back to"
+                f" {target_dtype}."
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron class NemotronSdpaAttention(NemotronAttention): """ @@ -563,6 +666,7 @@ def forward( NEMOTRON_ATTENTION_CLASSES = { "eager": NemotronAttention, "flash_attention_2": NemotronFlashAttention2, + "flash_attention_3": NemotronFlashAttention3, "sdpa": NemotronSdpaAttention, } @@ -677,6 +781,7 @@ class NemotronPreTrainedModel(PreTrainedModel): _no_split_modules = ["NemotronDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -920,7 +1025,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index b4bda8e2db5251..4f1ebb3cf690c1 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -41,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -52,6 +53,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -498,6 +502,106 @@ def forward( return attn_output, attn_weights, past_key_value +class OlmoFlashAttention3(OlmoAttention): + """ + OLMo flash attention module. This module inherits from `OlmoAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (OlmoRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class OlmoSdpaAttention(OlmoAttention): """ OLMo attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -594,6 +698,7 @@ def forward( OLMO_ATTENTION_CLASSES = { "eager": OlmoAttention, "flash_attention_2": OlmoFlashAttention2, + "flash_attention_3": OlmoFlashAttention3, "sdpa": OlmoSdpaAttention, } @@ -704,6 +809,7 @@ class OlmoPreTrainedModel(PreTrainedModel): _no_split_modules = ["OlmoDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -959,7 +1065,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index a33338365312db..8efab00cc86cbf 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -34,6 +34,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -44,6 +45,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -572,6 +576,105 @@ def forward( return attn_output, attn_weights, past_key_value +class OlmoeFlashAttention3(OlmoeAttention): + """ + OLMoE flash attention module. This module inherits from `OlmoeAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (OlmoeRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class OlmoeSdpaAttention(OlmoeAttention): """ OLMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -670,6 +773,7 @@ def forward( OLMOE_ATTENTION_CLASSES = { "eager": OlmoeAttention, "flash_attention_2": OlmoeFlashAttention2, + "flash_attention_3": OlmoeFlashAttention3, "sdpa": OlmoeSdpaAttention, } @@ -838,6 +942,7 @@ class OlmoePreTrainedModel(PreTrainedModel): _no_split_modules = ["OlmoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1109,7 +1214,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index ccaa2c7fd29aae..02989bc9bc5471 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -783,7 +783,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 648d1653a3b503..32270fa03414c8 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -41,6 +41,7 @@ add_start_docstrings_to_model_forward, get_torch_version, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -52,6 +53,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -575,6 +578,136 @@ def forward( return attn_output, attn_weights, past_key_value +class PhiFlashAttention3(PhiAttention): + """ + Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # PhiFlashAttention3 attention does not support output_attentions + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. 
+ # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. + + if query_states.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + softmax_scale=None, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.dense(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class PhiSdpaAttention(PhiAttention): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -704,6 +837,7 @@ def forward( PHI_ATTENTION_CLASSES = { "eager": PhiAttention, "flash_attention_2": PhiFlashAttention2, + "flash_attention_3": PhiFlashAttention3, "sdpa": PhiSdpaAttention, } @@ -812,6 +946,7 @@ class PhiPreTrainedModel(PreTrainedModel): _no_split_modules = ["PhiDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True @@ -1074,7 +1209,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index ec395679ae6207..02938c889e628b 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1095,7 +1095,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index d0ea8ef0e376e0..90fd0e2158beb1 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -42,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -53,6 +54,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils 
import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -539,6 +543,132 @@ def forward( return attn_output, attn_weights, past_key_value +class Qwen2FlashAttention3(Qwen2Attention): + """ + Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + kv_seq_len = key_states.shape[-2] + cache_position[0] + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class Qwen2SdpaAttention(Qwen2Attention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from @@ -638,6 +768,7 @@ def forward( QWEN2_ATTENTION_CLASSES = { "eager": Qwen2Attention, "flash_attention_2": Qwen2FlashAttention2, + "flash_attention_3": Qwen2FlashAttention3, "sdpa": Qwen2SdpaAttention, } @@ -754,6 +885,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1013,7 +1145,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 6f483e50cde065..de9aa69bc620b3 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -43,6 +43,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -54,6 +55,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "Qwen/Qwen1.5-MoE-A2.7B" @@ -621,6 +625,132 @@ def forward( return attn_output, attn_weights, past_key_value +class Qwen2MoeFlashAttention3(Qwen2MoeAttention): + """ + Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + kv_seq_len = key_states.shape[-2] + cache_position[0] + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
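+        # The target dtype below is resolved in priority order: the autocast dtype when autocast is active,
+        # then the pre-quantization dtype recorded on the config for quantized models, and finally the dtype
+        # of the q_proj weights as a fallback.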
+ input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe class Qwen2MoeSdpaAttention(Qwen2MoeAttention): """ @@ -721,6 +851,7 @@ def forward( QWEN2MOE_ATTENTION_CLASSES = { "eager": Qwen2MoeAttention, "flash_attention_2": Qwen2MoeFlashAttention2, + "flash_attention_3": Qwen2MoeFlashAttention3, "sdpa": Qwen2MoeSdpaAttention, } @@ -913,6 +1044,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2MoeDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True @@ -1188,7 +1320,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 4e4e04198c0bf1..22c08de4f6b3e9 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -44,6 +44,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -58,6 +59,12 @@ else: flash_attn_varlen_func = None +if is_flash_attn_3_available(): + from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func + + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward +else: + flash_attn_3_varlen_func = None logger = logging.get_logger(__name__) @@ -379,6 +386,29 @@ def forward( return attn_output +class VisionFlashAttention3(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, 
cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_3_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + class VisionSdpaAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() @@ -410,6 +440,7 @@ def forward( QWEN2_VL_VISION_ATTENTION_CLASSES = { "eager": VisionAttention, "flash_attention_2": VisionFlashAttention2, + "flash_attention_3": VisionFlashAttention3, "sdpa": VisionSdpaAttention, } @@ -812,6 +843,144 @@ def forward( return attn_output, attn_weights, past_key_value +class Qwen2VLFlashAttention3(Qwen2VLAttention): + """ + Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class Qwen2VLSdpaAttention(Qwen2VLAttention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from @@ -917,6 +1086,7 @@ def forward( QWEN2_VL_ATTENTION_CLASSES = { "eager": Qwen2VLAttention, "flash_attention_2": Qwen2VLFlashAttention2, + "flash_attention_3": Qwen2VLFlashAttention3, "sdpa": Qwen2VLSdpaAttention, } @@ -1033,6 +1203,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True @@ -1274,7 +1445,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index d91c0832ed33da..11348590ce0cc4 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -42,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -52,6 +53,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -672,10 +676,114 @@ def forward( return attn_output, attn_weights, past_key_value +class StableLmFlashAttention3(StableLmAttention): + """ + StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # StableLmFlashAttention3 attention does not support output_attentions + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. 
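+        # Concretely, the projections and the KV cache above keep tensors as [batch_size, num_heads, seq_len, head_dim],
+        # while `_flash_attention_3_forward` expects [batch_size, seq_len, num_heads, head_dim], hence the
+        # transpose(1, 2) on the query, key and value states below.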
+ query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + ATTENTION_CLASSES = { "eager": StableLmAttention, "sdpa": StableLmSdpaAttention, "flash_attention_2": StableLmFlashAttention2, + "flash_attention_3": StableLmFlashAttention3, } @@ -799,6 +907,7 @@ class StableLmPreTrainedModel(PreTrainedModel): _no_split_modules = ["StableLmDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_sdpa = True _supports_quantized_cache = True @@ -1058,7 +1167,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 0be37c4e1fb91c..bb0311477be828 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -42,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -53,6 +54,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -512,6 +516,131 @@ def forward( return attn_output, attn_weights, past_key_value +class Starcoder2FlashAttention3(Starcoder2Attention): + """ + Starcoder2 flash attention module. This module inherits from `Starcoder2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + kv_seq_len = key_states.shape[-2] + cache_position[0] + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to be silently cast to float32; this might be related to"
+                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast back the input"
+                f" to {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        # Reshape to the expected shape for Flash Attention
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        attn_output = _flash_attention_3_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len,
+            position_ids=position_ids,
+            is_causal=self.is_causal,
+        )
+
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
+        attn_output = self.o_proj(attn_output)
+        attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
 # Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Starcoder2
 class Starcoder2SdpaAttention(Starcoder2Attention):
     """
@@ -614,6 +743,7 @@ def forward(
 STARCODER2_ATTENTION_CLASSES = {
     "eager": Starcoder2Attention,
     "flash_attention_2": Starcoder2FlashAttention2,
+    "flash_attention_3": Starcoder2FlashAttention3,
     "sdpa": Starcoder2SdpaAttention,
 }
 
@@ -728,6 +858,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
     _no_split_modules = ["Starcoder2DecoderLayer"]
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
+    _supports_flash_attn_3 = True
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_quantized_cache = True
@@ -988,7 +1119,10 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        if self.config._attn_implementation == "flash_attention_2":
+        if (
+            self.config._attn_implementation == "flash_attention_2"
+            or self.config._attn_implementation == "flash_attention_3"
+        ):
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index b82b978e5e6d95..070650a7d0a60d 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -39,6 +39,7 @@
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_3_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
@@ -50,6 +51,9 @@
 if is_flash_attn_2_available():
     from ...modeling_flash_attention_utils import _flash_attention_forward
 
+if is_flash_attn_3_available():
+    from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
 
 logger = logging.get_logger(__name__)
 
@@ -523,6 +527,121 @@ def forward(
         return attn_output, attn_weights, past_key_value
+
+class WhisperFlashAttention3(WhisperAttention):
+    """
+    Whisper flash attention module. This module inherits from `WhisperAttention` as the weights of the module stay
+    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[EncoderDecoderCache] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if isinstance(past_key_value, StaticCache):
+            raise ValueError(
+                "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_3'`. "
+                "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
+            )
+        # WhisperFlashAttention3 attention does not support output_attentions
+        if output_attentions:
+            raise ValueError("WhisperFlashAttention3 attention does not support output_attentions")
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
+
+        if past_key_value is not None:
+            is_updated = past_key_value.is_updated.get(self.layer_idx)
+            if is_cross_attention:
+                # after the first generated id, we can subsequently re-use all key/value_states from cache
+                past_key_value.is_updated[self.layer_idx] = True
+                past_key_value = past_key_value.cross_attention_cache
+            else:
+                past_key_value = past_key_value.self_attention_cache
+
+        # use key_value_states if cross attention
+        current_states = key_value_states if key_value_states is not None else hidden_states
+        if is_cross_attention and past_key_value and is_updated:
+            # reuse k,v, cross_attentions
+            key_states = past_key_value.key_cache[self.layer_idx]
+            value_states = past_key_value.value_cache[self.layer_idx]
+        else:
+            key_states = self._shape(self.k_proj(current_states), -1, bsz)
+            value_states = self._shape(self.v_proj(current_states), -1, bsz)
+            if past_key_value is not None:
+                # save all key/value_states to cache to be re-used for fast auto-regressive generation
+                cache_position = cache_position if not is_cross_attention else None
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+                )
+
+        # TODO: These transposes are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]
+        # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        causal_mask = attention_mask
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+        # In PEFT, we usually cast the layer norms to float32 for training stability, so the input
+        # hidden states may get silently cast to float32. Hence, we need to
+        # cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+        # to fp32. (LlamaRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to be silently cast to float32; this might be related to"
+                f" the fact that you have upcast embedding or layer norm layers to float32. We will cast back the input"
+                f" to {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = _flash_attention_3_forward(
+            query_states,
+            key_states,
+            value_states,
+            causal_mask,
+            tgt_len,
+            is_causal=self.is_causal,
+        )
+
+        attn_output = attn_output.reshape(bsz, tgt_len, -1)
+        attn_output = self.out_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
 class WhisperSdpaAttention(WhisperAttention):
     def forward(
         self,
@@ -624,6 +743,7 @@ def forward(
 WHISPER_ATTENTION_CLASSES = {
     "eager": WhisperAttention,
     "flash_attention_2": WhisperFlashAttention2,
+    "flash_attention_3": WhisperFlashAttention3,
     "sdpa": WhisperSdpaAttention,
 }
 
@@ -823,6 +943,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
     _supports_flash_attn_2 = True
+    _supports_flash_attn_3 = True
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_static_cache = True
@@ -1428,7 +1549,10 @@ def _update_causal_mask(
         past_key_values: Cache,
         output_attentions: bool,
     ):
-        if self.config._attn_implementation == "flash_attention_2":
+        if (
+            self.config._attn_implementation == "flash_attention_2"
+            or self.config._attn_implementation == "flash_attention_3"
+        ):
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index eee350349f5565..92a9b1fb675623 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -130,6 +130,7 @@
     is_faiss_available,
     is_fbgemm_gpu_available,
     is_flash_attn_2_available,
+    is_flash_attn_3_available,
     is_flash_attn_greater_or_equal,
     is_flash_attn_greater_or_equal_2_10,
     is_flax_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index ad8b649aaa4e84..4f4e70b122be6f 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -62,6 +62,14 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
         except ImportError:
             # If the package can't be imported, it's not available
             package_exists = False
+    elif pkg_name == "flash_attn_interface":
+        try:
+            package = importlib.import_module(pkg_name)
+            package_version = getattr(package, "__version__", "N/A")
+            package_exists = True
+        except ImportError:
+            # If the package can't be imported, it's not available
+            package_exists = False
     else:
         # For packages other than "torch", don't attempt the fallback and set as not available
         package_exists = False
@@ -899,6 +907,16 @@ def is_flash_attn_greater_or_equal(library_version: str):
     return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
 
 
+def is_flash_attn_3_available():
+    if not is_flash_attn_2_available():
+        return False
+
+    if not _is_package_available("flash_attn_interface"):
+        return False
+
+    return True
+
+
 def is_torchdistx_available():
     return _torchdistx_available
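
Note (illustrative sketch, not part of the patch): with this diff applied, FlashAttention-3 is opted into the same way as FlashAttention-2, through the `attn_implementation` argument, and is gated by the new `is_flash_attn_3_available()` helper, which requires `flash_attn` plus the `flash_attn_interface` package. The checkpoint name below is a placeholder; the same pattern applies to the other architectures touched by this diff.

    import torch
    from transformers import AutoModelForCausalLM
    from transformers.utils import is_flash_attn_3_available

    # Prefer FA3 when the flash_attn_interface package is installed, otherwise fall back to FA2.
    attn_impl = "flash_attention_3" if is_flash_attn_3_available() else "flash_attention_2"

    model = AutoModelForCausalLM.from_pretrained(
        "stabilityai/stablelm-2-1_6b",  # placeholder checkpoint
        torch_dtype=torch.bfloat16,  # flash attention kernels operate on fp16/bf16 activations
        attn_implementation=attn_impl,
    ).to("cuda")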