diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 36775d8454ab8c..ef735b95c5ef99 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1332,6 +1332,7 @@ "convert_and_export_with_cache", ] + _import_structure["modeling_flash_attention_3_utils"] = [] _import_structure["modeling_flash_attention_utils"] = [] _import_structure["modeling_outputs"] = [] _import_structure["modeling_rope_utils"] = ["ROPE_INIT_FUNCTIONS"] diff --git a/src/transformers/modeling_flash_attention_3_utils.py b/src/transformers/modeling_flash_attention_3_utils.py new file mode 100644 index 00000000000000..0a3faf784a06af --- /dev/null +++ b/src/transformers/modeling_flash_attention_3_utils.py @@ -0,0 +1,295 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Optional, Tuple + +import torch +import torch.nn.functional as F + +from .utils import is_flash_attn_3_available + + +if is_flash_attn_3_available(): + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + from flash_attn_interface import _flash_attn_forward, flash_attn_func, flash_attn_varlen_func + + +def _get_unpad_data(attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, int]: + """ + Retrieves indexing data required to repad unpadded (ragged) tensors. + + Arguments: + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + indices (`torch.Tensor`): + The indices of non-masked tokens from the flattened input sequence. + cu_seqlens (`torch.Tensor`): + The cumulative sequence lengths, used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + max_seqlen_in_batch (`int`): + Maximum sequence length in batch. + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _upad_input( + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, +): + """ + Unpads query, key, and values tensors, using a single dimension for all tokens even though they belong to different batches. + + This function is used instead of `flash_attn.bert_padding.unpad_input` in order to avoid the recomputation of the same intermediary + tensors for query, key, value tensors. + + Arguments: + query_layer (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. 
Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + attention_mask (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + query_length (`int`): + Target length. + + Return: + query_layer (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key_layer (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + value_layer (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis(query_layer.reshape(batch_size * kv_seq_len, -1, head_dim), indices_k) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +def prepare_fa2_from_position_ids(query, key, value, position_ids): + """ + This function returns necessary arguments to call `flash_attn_varlen_func`. + All three query, key, value states will be flattened. + Cummulative lengths of each examples in the batch will be extracted from position_ids. + + NOTE: ideally cummulative lengths should be prepared at the data collator stage + + Arguments: + query (`torch.Tensor`): + Query state with padding. Shape: (batch_size, query_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + value (`torch.Tensor`): + Value state with padding. Shape: (batch_size, kv_seq_len, num_key_value_heads, head_dim). + position_ids (`torch.Tensor`): + Boolean or int tensor of shape (batch_size, sequence_length), 1 means valid and 0 means not valid. + + Return: + query (`torch.Tensor`): + Query state without padding. Shape: (total_target_length, num_heads, head_dim). + key (`torch.Tensor`): + Key state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). 
+ value (`torch.Tensor`): + Value state with padding. Shape: (total_source_length, num_key_value_heads, head_dim). + indices_q (`torch.Tensor`): + The indices of non-masked tokens from the flattened input target sequence. + (cu_seqlens_q, cu_seqlens_k) (`Tuple[int]`): + The cumulative sequence lengths for the target (query) and source (key, value), used to index into ragged (unpadded) tensors. `cu_seqlens` shape is (batch_size + 1,). + (max_seqlen_in_batch_q, max_seqlen_in_batch_k) (`Tuple[int]`): + Maximum sequence length in batch (`max_seqlen_in_batch_q` for the target sequence i.e. query, `max_seqlen_in_batch_k` for the source sequence i.e. key/value). + """ + query = query.view(-1, query.size(-2), query.size(-1)) + key = key.view(-1, key.size(-2), key.size(-1)) + value = value.view(-1, value.size(-2), value.size(-1)) + position_ids = position_ids.flatten() + indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32) + + cu_seq_lens = torch.cat( + ( + indices_q[position_ids == 0], + torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32), + ) + ) + + max_length = position_ids.max() + 1 + + return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length)) + + +def _flash_attention_3_forward( + query_states: torch.Tensor, + key_states: torch.Tensor, + value_states: torch.Tensor, + attention_mask: torch.Tensor, + query_length: int, + is_causal: bool, + position_ids: Optional[torch.Tensor] = None, + softmax_scale: Optional[float] = None, + use_top_left_mask: bool = False, + deterministic: bool = None, + descale: float = 1.0, + use_fp8: bool = False, +): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_top_left_mask (`bool`, defaults to `False`): + flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. + deterministic (`bool`, *optional*): + Determines if the deterministic option enabled. 
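+ is_causal (`bool`): + Whether to apply a causal attention mask. + position_ids (`torch.Tensor`, *optional*): + Position ids of the tokens; when they indicate packed sequences (and we are in the pre-fill/training stage), the varlen kernel is used to prevent cross-example attention. + descale (`float`, *optional*, defaults to 1.0): + Descaling factor intended for the FP8 path (currently unused by this function). + use_fp8 (`bool`, *optional*, defaults to `False`): + Whether to run the kernel in FP8; note that this argument is currently overridden by the `FLASH_ATTENTION_3_FP8` environment variable.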
+ """ + use_fp8 = os.environ.get("FLASH_ATTENTION_3_FP8", "0") == "1" + + softmax_scale = softmax_scale or query_states.shape[-1] ** (-0.5) + + if not use_top_left_mask: + causal = is_causal + else: + causal = is_causal and query_length != 1 + + flash_kwargs = {} + + if deterministic is None: + deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1" + flash_kwargs["deterministic"] = deterministic + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = _upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad, _ = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + + # If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing + # then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage. + # Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach + elif position_ids is not None and not (torch.diff(position_ids, dim=-1) >= 0).all() and query_length != 1: + batch_size = query_states.size(0) + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids( + query_states, key_states, value_states, position_ids + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output, _ = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + softmax_scale=softmax_scale, + causal=causal, + **flash_kwargs, + ) + + attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1)) + + else: + if use_fp8: + # NOTE: descale? 
+ attn_output = _flash_attn_forward( + query_states.to(torch.float8_e4m3fn), + key_states.to(torch.float8_e4m3fn), + value_states.to(torch.float8_e4m3fn), + softmax_scale=softmax_scale, + causal=causal, + )[0] + else: + attn_output, _ = flash_attn_func( + query_states, key_states, value_states, softmax_scale=softmax_scale, causal=causal, **flash_kwargs + ) + + return attn_output diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d4069766636053..314e3ee22cafdf 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -82,6 +82,7 @@ is_accelerate_available, is_bitsandbytes_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_offline_mode, is_optimum_available, is_peft_available, @@ -1367,6 +1368,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Flash Attention 2 support _supports_flash_attn_2 = False + # Flash Attention 3 support + _supports_flash_attn_3 = False + # SDPA support _supports_sdpa = False @@ -1550,10 +1554,12 @@ def _autoset_attn_implementation( ' We recommend to just use `attn_implementation="flash_attention_2"` when loading the model.' ) - if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2"]: + if config._attn_implementation not in ["eager", "sdpa", "flash_attention_2", "flash_attention_3"]: message = f'Specified `attn_implementation="{config._attn_implementation}"` is not supported. The only possible arguments are `attn_implementation="eager"` (manual attention implementation)' if cls._supports_flash_attn_2: message += ', `"attn_implementation=flash_attention_2"` (implementation using flash attention 2)' + if cls._supports_flash_attn_3: + message += ', `"attn_implementation=flash_attention_3"` (implementation using flash attention 3)' if cls._supports_sdpa: message += ', `"attn_implementation=sdpa"` (implementation using torch.nn.functional.scaled_dot_product_attention)' raise ValueError(message + ".") @@ -1575,6 +1581,14 @@ def _autoset_attn_implementation( hard_check_only=False, check_device_map=check_device_map, ) + elif config._attn_implementation == "flash_attention_3": + cls._check_and_enable_flash_attn_3( + config, + torch_dtype=torch_dtype, + device_map=device_map, + hard_check_only=False, + check_device_map=check_device_map, + ) elif requested_attn_implementation in [None, "sdpa"] and not is_torch_xla_available(): # use_flash_attention_2 takes priority over SDPA, hence SDPA treated in this elif. config = cls._check_and_enable_sdpa( @@ -1734,6 +1748,83 @@ def _check_and_enable_flash_attn_2( config._attn_implementation = "flash_attention_2" return config + @classmethod + def _check_and_enable_flash_attn_3( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + check_device_map: bool = True, + hard_check_only: bool = False, + ) -> PretrainedConfig: + """ + Checks the availability of Flash Attention 3 and compatibility with the current model. + + If all checks pass and `hard_check_only` is False, the method will set the config attribute `attn_implementation` to "flash_attention_3" so that the model can initialize the correct attention module. + """ + if not cls._supports_flash_attn_3: + raise ValueError( + f"{cls.__name__} does not support Flash Attention 3.0 yet. 
Please request to add support where" + f" the model is hosted, on its model hub page: https://huggingface.co/{config._name_or_path}/discussions/new" + " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new" + ) + + if not is_flash_attn_3_available(): + preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:" + # TODO: docs + install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-3 to install Flash Attention 3." + + if importlib.util.find_spec("flash_attn_interface") is None: + raise ImportError( + f"{preface} the package flash_attn_interface does not seem to be installed. {install_message}" + ) + + _is_bettertransformer = getattr(cls, "use_bettertransformer", False) + + if _is_bettertransformer: + raise ValueError( + "Flash Attention 3 and the BetterTransformer API are not compatible. Please make sure to disable BetterTransformer by calling model.reverse_bettertransformer()" + ) + + if torch_dtype is None: + logger.warning_once( + "You are attempting to use Flash Attention 3.0 without specifying a torch dtype. This might lead to unexpected behaviour" + ) + elif torch_dtype is not None and torch_dtype not in [torch.float16, torch.bfloat16]: + logger.warning_once( + "Flash Attention 3.0 only supports torch.float16 and torch.bfloat16 dtypes, but" + f" the current dtype in {cls.__name__} is {torch_dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='cuda'):` context manager," + ' or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`' + ) + + # The check `torch.empty(0).device.type != "cuda"` is needed as the model may be initialized after `torch.set_default_device` has been called, + # or the model may be initialized under the context manager `with torch.device("cuda"):`. + if check_device_map and device_map is None and torch.empty(0).device.type != "cuda": + if torch.cuda.is_available(): + logger.warning_once( + "You are attempting to use Flash Attention 3.0 with a model not initialized on GPU. Make sure to move the model to GPU" + " after initializing it on CPU with `model.to('cuda')`." + ) + else: + raise ValueError( + "You are attempting to use Flash Attention 3.0 with a model not initialized on GPU and with no GPU available. " + "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map " + "or initialise it on CPU and then move it to GPU." + ) + elif ( + check_device_map + and device_map is not None + and isinstance(device_map, dict) + and ("cpu" in device_map.values() or "disk" in device_map.values()) + ): + raise ValueError( + "You are attempting to use Flash Attention 3.0 with a model dispatched on CPU or disk. This is not supported. Please make sure to " + "initialise the model on a GPU by passing a device_map that contains only GPU devices as values."
+ ) + if not hard_check_only: + config._attn_implementation = "flash_attention_3" + return config + @classmethod def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> PretrainedConfig: """ diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index 5aad7b23a8a672..f6fd3e911d95d0 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -35,6 +35,7 @@ add_start_docstrings_to_model_forward, is_accelerate_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, ) @@ -56,6 +57,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -279,9 +283,92 @@ def forward( return outputs +class BarkSelfFlashAttention3(BarkSelfAttention): + """ + Bark flash attention module. This module inherits from `BarkSelfAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _split_heads(self, tensor, num_heads, attn_head_size): + """ + Splits hidden_size dim into attn_head_size and num_heads + """ + new_shape = tensor.size()[:-1] + (num_heads, attn_head_size) + tensor = tensor.view(new_shape) + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim - (batch, seq_length, head, head_features) + return tensor + + def _merge_heads(self, tensor, num_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into hidden_size + """ + # re-assemble all head outputs side by side + # (batch, seq_len, num_heads, attn_head_size) -> (batch, seq_len, num_heads*attn_head_size) + tensor = tensor.view(tensor.size()[:-2] + (num_heads * attn_head_size,)) + return tensor + + def forward( + self, + hidden_states, + attention_mask=None, + past_key_values=None, + head_mask=None, + use_cache=False, + output_attentions=False, + ): + batch_size, query_len, _ = hidden_states.size() + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + query, key, value = self.att_proj(hidden_states).split(self.embed_dim, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if past_key_values is not None: + # (batch, head, seq_length, head_features) -> (batch, seq_length, head, head_features) + past_key = past_key_values[0].transpose(1, 2) + past_value = past_key_values[1].transpose(1, 2) + # and merge on seq_length + key = torch.cat((past_key, key), dim=1) + value = torch.cat((past_value, value), dim=1) + + if use_cache is True: + # (batch, head, seq_length, head_features) + present = (key.transpose(1, 2), value.transpose(1, 2)) + else: + present = None + + attn_output = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_len, + is_causal=self.is_causal, + ) + + attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim) + attn_output = self.out_proj(attn_output) + attn_output = 
self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + attn_weights = None + outputs += (attn_weights,) + + return outputs + + BARK_ATTENTION_CLASSES = { "eager": BarkSelfAttention, "flash_attention_2": BarkSelfFlashAttention2, + "flash_attention_3": BarkSelfFlashAttention3, } @@ -376,6 +463,7 @@ class BarkPreTrainedModel(PreTrainedModel): config_class = BarkConfig supports_gradient_checkpointing = False _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module): """Initialize the weights.""" @@ -561,6 +649,7 @@ def __init__(self, config): self.layers = nn.ModuleList([BarkBlock(config, is_causal=True) for _ in range(config.num_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self.layernorm_final = BarkLayerNorm(config.hidden_size, bias=config.bias) @@ -702,7 +791,7 @@ def forward( if attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None else: attention_mask = attention_mask.view(batch_size, -1) @@ -1157,6 +1246,7 @@ def __init__(self, config): self.layers = nn.ModuleList([BarkBlock(config, is_causal=False) for _ in range(config.num_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self.layernorm_final = nn.LayerNorm(config.hidden_size) @@ -1341,7 +1431,7 @@ def forward( if attention_mask is not None: if batch_size <= 0: raise ValueError("batch_size has to be defined and > 0") - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None else: # [bsz, to_seq_length] -> [bsz, 1, 1, to_seq_length] @@ -1811,3 +1901,39 @@ def _check_and_enable_flash_attn_2( config.coarse_acoustics_config._attn_implementation = config._attn_implementation config.fine_acoustics_config._attn_implementation = config._attn_implementation return config + + @classmethod + def _check_and_enable_flash_attn_3( + cls, + config, + torch_dtype: Optional[torch.dtype] = None, + device_map: Optional[Union[str, Dict[str, int]]] = None, + hard_check_only: bool = False, + check_device_map: bool = False, + ): + """ + `_check_and_enable_flash_attn_3` originally don't expand flash attention enabling to the model + sub-configurations. We override the original method to make sure that Bark sub-models are using Flash Attention + if necessary. + + If you don't know about Flash Attention, check out the official repository of flash attention: + https://github.com/Dao-AILab/flash-attention + + For using Flash Attention 1.0 you can do it directly via the `BetterTransformer` API, have a look at this + specific section of the documentation to learn more about it: + https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#decoder-models + + The method checks if the current setup is compatible with Flash Attention as it requires the model to be in + half precision and not ran on CPU. 
+ + If all checks pass and `hard_check_only` is False, the method will set the config attribute `_attn_implementation` to "flash_attention_3" so that the model + can initialize the correct attention module + """ + config = super()._check_and_enable_flash_attn_3( + config, torch_dtype, device_map, hard_check_only=hard_check_only, check_device_map=check_device_map + ) + + config.semantic_config._attn_implementation = config._attn_implementation + config.coarse_acoustics_config._attn_implementation = config._attn_implementation + config.fine_acoustics_config._attn_implementation = config._attn_implementation + return config diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index fa928d05caa89b..2e31469f9c3c49 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -47,6 +47,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -57,6 +58,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -413,6 +416,125 @@ def forward( return attn_output, attn_weights, past_key_value +class BartFlashAttention3(BartAttention): + """ + Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
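+ + As elsewhere in this PR, this module is selected by loading the model with `attn_implementation="flash_attention_3"` and a half-precision dtype, e.g. (illustrative checkpoint) `BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn", attn_implementation="flash_attention_3", torch_dtype=torch.float16)`.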
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # BartFlashAttention3 attention does not support output_attentions + if output_attentions: + raise ValueError("BartFlashAttention3 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class BartSdpaAttention(BartAttention): def forward( self, @@ -523,6 +645,7 @@ def forward( "eager": BartAttention, "sdpa": BartSdpaAttention, "flash_attention_2": BartFlashAttention2, + "flash_attention_3": BartFlashAttention3, } @@ -748,6 +871,7 @@ class BartPreTrainedModel(PreTrainedModel): _no_split_modules = [r"BartEncoderLayer", r"BartDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): @@ -979,6 +1103,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No ) self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) @@ -1067,7 +1192,7 @@ def forward( # expand attention_mask if attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None elif self._use_sdpa and head_mask is None and not output_attentions: # output_attentions=True & head_mask can not be supported when using SDPA, fall back to @@ -1163,6 +1288,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No ) self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -1283,7 +1409,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: @@ -1303,7 +1429,7 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and 
encoder_attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index d32ab95f51c7da..bdc86c34fb6bec 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -790,7 +790,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 23334311ca9511..d1c33df2b42fd7 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -538,6 +539,112 @@ def forward( return attn_output, attn_weights, past_key_value +class ChameleonFlashAttention3(ChameleonAttention): + """ + Chameleon flash attention module. This module inherits from `ChameleonAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
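+ + Note that, as implemented below, a `StaticCache` is rejected with an error, and float32 hidden states are cast back to the model's half-precision dtype before the kernel is called.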
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.reshape(-1, self.num_heads, self.head_dim) + query_states = self.q_norm(query_states) + + key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim) + key_states = self.k_norm(key_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. + # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (ChameleonRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class ChameleonSdpaAttention(ChameleonAttention): """ Chameleon attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -634,6 +741,7 @@ def forward( CHAMELEON_ATTENTION_CLASSES = { "eager": ChameleonAttention, "flash_attention_2": ChameleonFlashAttention2, + "flash_attention_3": ChameleonFlashAttention3, "sdpa": ChameleonSdpaAttention, } @@ -1155,6 +1263,7 @@ class ChameleonPreTrainedModel(PreTrainedModel): _no_split_modules = ["ChameleonDecoderLayer", "ChameleonSwinDecoderLayer"] _skip_keys_device_placement = ["past_key_values", "causal_mask"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_quantized_cache = True _supports_cache_class = True @@ -1433,7 +1542,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 64eb027e9e220c..3d120d3fb2b119 100644 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -33,6 +33,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -43,6 +44,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -432,6 +435,82 @@ def forward( return attn_output, attn_weights +class CLIPFlashAttention3(CLIPAttention): + """ + CLIPAttention flash attention module. This module inherits from `CLIPAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
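+ + Causality is inferred from the inputs: the kernel is invoked with `is_causal=causal_attention_mask is not None`, so causal masking is applied inside the kernel rather than materialised as a tensor.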
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + causal_attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + output_attentions = False + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=causal_attention_mask is not None, + ) + + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + class CLIPSdpaAttention(CLIPAttention): """ SDPA attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from @@ -509,6 +588,7 @@ def forward( "eager": CLIPAttention, "sdpa": CLIPSdpaAttention, "flash_attention_2": CLIPFlashAttention2, + "flash_attention_3": CLIPFlashAttention3, } @@ -588,6 +668,7 @@ class CLIPPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_sdpa = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module): """Initialize the weights""" @@ -857,6 +938,7 @@ def __init__(self, config: CLIPTextConfig): # For attention mask, it differs between `flash_attention_2` and other attention implementations self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig) @@ -894,7 +976,7 @@ def forward( ) # expand attention_mask - if attention_mask is not None and not self._use_flash_attention_2: + if attention_mask is not None and not self._use_flash_attention_2 and not self._use_flash_attention_3: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 46eea43e1285f8..2b3a953601a682 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -633,7 +633,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index ae84a9ec2d1a43..c8cb6103f86f39 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -44,6 +44,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -55,6 +56,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -539,6 +543,118 @@ def forward( return attn_output, attn_weights, past_key_value +class CohereFlashAttention3(CohereAttention): + """ + Cohere flash attention module. This module inherits from `CohereAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
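+ + As in the eager implementation, query/key normalisation (when `use_qk_norm` is enabled) and rotary position embeddings are applied before the flash kernel is invoked.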
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim) + if self.use_qk_norm: + query_states = self.q_norm(query_states) + key_states = self.k_norm(key_states) + + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; position_ids needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(CohereLayerNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class CohereSdpaAttention(CohereAttention): """ Cohere attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -645,6 +761,7 @@ def forward( COHERE_ATTENTION_CLASSES = { "eager": CohereAttention, "flash_attention_2": CohereFlashAttention2, + "flash_attention_3": CohereFlashAttention3, "sdpa": CohereSdpaAttention, } @@ -751,6 +868,7 @@ class CoherePreTrainedModel(PreTrainedModel): _no_split_modules = ["CohereDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1005,7 +1123,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index dd2a676b26c27f..8b45b788a9a253 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -40,6 +40,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_peft_available, logging, @@ -50,6 +51,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -608,6 +612,125 @@ def forward( return attn_output, attn_weights, past_key_value +class Data2VecAudioFlashAttention3(Data2VecAudioAttention): + """ + Data2VecAudio flash attention module. This module inherits from `Data2VecAudioAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
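+ + This module mirrors the Bart flash attention variant: it also covers cross-attention via `key_value_states` and keeps decoder key/value caches in the `(batch, num_heads, seq_len, head_dim)` layout, transposing to the layout expected by flash attention.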
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # Data2VecAudioFlashAttention3 attention does not support output_attentions + if output_attentions: + raise ValueError("Data2VecAudioFlashAttention3 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class Data2VecAudioSdpaAttention(Data2VecAudioAttention): # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Data2VecAudio def forward( @@ -719,6 +842,7 @@ def forward( "eager": Data2VecAudioAttention, "sdpa": Data2VecAudioSdpaAttention, "flash_attention_2": Data2VecAudioFlashAttention2, + "flash_attention_3": Data2VecAudioFlashAttention3, } @@ -794,6 +918,7 @@ def __init__(self, config): self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -810,7 +935,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -931,6 +1056,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 43bac44ba1be20..215a6b0862f493 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -30,6 +30,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -40,6 +41,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "DbrxConfig" @@ -484,6 +488,117 @@ def forward( return attn_output, attn_weights, past_key_value +class DbrxFlashAttention3(DbrxAttention): + """Dbrx flash attention module. 
+ + This module inherits from `DbrxAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it + calls the public API of flash attention. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Any, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.") + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.Wqkv(hidden_states) + if self.clip_qkv is not None: + qkv_states = qkv_states.clamp(min=-self.clip_qkv, max=self.clip_qkv) + + query_states, key_states, value_states = qkv_states.split( + [ + self.hidden_size, + self.num_key_value_heads * self.head_dim, + self.num_key_value_heads * self.head_dim, + ], + dim=2, + ) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout + # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = query_states.dtype + + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be " + + "related to the fact you have upcasted embedding or layer norm layers in " + + f"float32. We will cast back the input in {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class DbrxSdpaAttention(DbrxAttention): """ Dbrx attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -582,6 +697,7 @@ def forward( DBRX_ATTENTION_CLASSES = { "eager": DbrxAttention, "flash_attention_2": DbrxFlashAttention2, + "flash_attention_3": DbrxFlashAttention3, "sdpa": DbrxSdpaAttention, } @@ -882,6 +998,7 @@ class DbrxPreTrainedModel(PreTrainedModel): _no_split_modules = ["DbrxBlock"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1164,7 +1281,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py index e80e3c41d22cb6..a91cd4f8cb6ea1 100755 --- a/src/transformers/models/distilbert/modeling_distilbert.py +++ b/src/transformers/models/distilbert/modeling_distilbert.py @@ -44,6 +44,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -54,6 +55,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "distilbert-base-uncased" @@ -329,6 +332,93 @@ def reshape(x: torch.Tensor) -> torch.Tensor: return (attn_output,) +class DistilBertFlashAttention3(MultiHeadSelfAttention): + """ + DistilBert flash attention module. This module inherits from `MultiHeadSelfAttention` as the weights of the module + stays untouched. The only required change would be on the forward pass where it needs to correctly call the public + API of flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, ...]: + """ + Parameters: + query: torch.tensor(bs, seq_length, dim) + key: torch.tensor(bs, seq_length, dim) + value: torch.tensor(bs, seq_length, dim) + mask: torch.tensor(bs, seq_length) + + Returns: + weights: torch.tensor(bs, n_heads, seq_length, seq_length) Attention weights context: torch.tensor(bs, + seq_length, dim) Contextualized layer. Optional: only if `output_attentions=True` + """ + batch_size, q_length, dim = query.size() + + dim_per_head = self.dim // self.n_heads + + def reshape(x: torch.Tensor) -> torch.Tensor: + """separate heads""" + return x.view(batch_size, -1, self.n_heads, dim_per_head) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = reshape(self.q_lin(query)) + key_states = reshape(self.k_lin(key)) + value_states = reshape(self.v_lin(value)) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query_states.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_lin.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_weights = _flash_attention_3_forward( + query_states, + key_states, + value_states, + mask, + q_length, + is_causal=self.is_causal, + ) + + attn_weights_reshaped = attn_weights.reshape(batch_size, q_length, self.n_heads * dim_per_head) + attn_output = self.out_lin(attn_weights_reshaped) + + if output_attentions: + return (attn_output, attn_weights) + else: + return (attn_output,) + + class FFN(nn.Module): def __init__(self, config: PretrainedConfig): super().__init__() @@ -353,6 +443,7 @@ def ff_chunk(self, input: torch.Tensor) -> torch.Tensor: DISTILBERT_ATTENTION_CLASSES = { "eager": MultiHeadSelfAttention, "flash_attention_2": DistilBertFlashAttention2, + "flash_attention_3": DistilBertFlashAttention3, } @@ -503,6 +594,7 @@ class DistilBertPreTrainedModel(PreTrainedModel): base_model_prefix = "distilbert" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module: nn.Module): """Initialize the weights.""" @@ -589,6 +681,7 @@ def __init__(self, config: PretrainedConfig): self.embeddings = Embeddings(config) # Embeddings self.transformer = Transformer(config) # Encoder self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" # Initialize weights and apply final processing self.post_init() @@ -694,7 +787,7 @@ def forward( embeddings = self.embeddings(input_ids, inputs_embeds) # (bs, seq_length, dim) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: if attention_mask is None: diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index 73e8806352ebbd..ac511ea145a58f 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -43,6 +43,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, ) @@ -55,6 +56,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -688,6 +692,110 @@ def forward( return attn_output, layer_past, attn_weights +class FalconFlashAttention3(FalconAttention): + """ + Falcon flash attention module. This module inherits from `FalconAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + alibi: Optional[torch.Tensor], + attention_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor] = None, + layer_past: Optional[Cache] = None, + head_mask: Optional[torch.Tensor] = None, + use_cache: bool = False, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads + # 3 x [batch_size, seq_length, num_heads, head_dim] + (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv) + + batch_size, query_length, _, _ = query_layer.shape + + query_layer = query_layer.transpose(1, 2).reshape(batch_size, self.num_heads, query_length, self.head_dim) + key_layer = key_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) + value_layer = value_layer.transpose(1, 2).reshape(batch_size, num_kv_heads, query_length, self.head_dim) + + if alibi is None: + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_layer, position_ids) + else: + cos, sin = position_embeddings + query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin) + + if layer_past is not None: + cache_kwargs = {"cache_position": cache_position} + if alibi is None: + cache_kwargs.update({"sin": sin, "cos": cos}) + key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_layer = query_layer.transpose(1, 2) + key_layer = key_layer.transpose(1, 2) + value_layer = value_layer.transpose(1, 2) + + if alibi is not None: + raise ValueError("`alibi` is not supported when `use_flash_attn` is True") + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_layer.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.query_key_value.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_layer = query_layer.to(target_dtype) + key_layer = key_layer.to(target_dtype) + value_layer = value_layer.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_layer, + key_layer, + value_layer, + attention_mask, + query_length, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + attn_output = self.dense(attn_weights) + + if not output_attentions: + attn_weights = None + + return attn_output, layer_past, attn_weights + + class FalconMLP(nn.Module): def __init__(self, config: FalconConfig): super().__init__() @@ -708,6 +816,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: "eager": FalconAttention, "sdpa": FalconAttention, # FalconAttention originally implemented both a forward with & without SDPA "flash_attention_2": FalconFlashAttention2, + "flash_attention_3": FalconFlashAttention3, } @@ -909,6 +1018,7 @@ class FalconPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["FalconDecoderLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -971,6 +1081,7 @@ def __init__(self, config: FalconConfig): # Transformer blocks self.h = nn.ModuleList([FalconDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" # Final Layer Norm @@ -1157,7 +1268,10 @@ def _update_causal_mask( # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gemma/diff_gemma.py b/src/transformers/models/gemma/diff_gemma.py index fcdb0f0b3d7d8f..53b656282eb46a 100644 --- a/src/transformers/models/gemma/diff_gemma.py +++ b/src/transformers/models/gemma/diff_gemma.py @@ -34,6 +34,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import CausalLMOutputWithPast from ...pytorch_utils import ALL_LAYERNORM_LAYERS @@ -441,6 +442,103 @@ def forward( return attn_output, attn_weights, past_key_value +# TODO felix: does this inheritance really work out in the end to GemmaFlashAttention2 inheriting form GemmaAttention? +class GemmaFlashAttention3(LlamaFlashAttention2): + """ + Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (GemmaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class GemmaModel(LlamaModel): def forward( self, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index b14e0a4b3d8ca5..b327f0dd31f2da 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -467,6 +468,106 @@ def forward( return attn_output, attn_weights, past_key_value +class GemmaFlashAttention3(GemmaAttention): + """ + Gemma flash attention module. This module inherits from `GemmaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These 
transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (GemmaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class GemmaSdpaAttention(GemmaAttention): """ Gemma attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from @@ -558,6 +659,7 @@ def forward( GEMMA_ATTENTION_CLASSES = { "eager": GemmaAttention, "flash_attention_2": GemmaFlashAttention2, + "flash_attention_3": GemmaFlashAttention3, "sdpa": GemmaSdpaAttention, } @@ -665,6 +767,7 @@ class GemmaPreTrainedModel(PreTrainedModel): _no_split_modules = ["GemmaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -932,7 +1035,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gemma2/diff_gemma2.py b/src/transformers/models/gemma2/diff_gemma2.py index 30f371a1b61267..235e43c1ee0539 100644 --- a/src/transformers/models/gemma2/diff_gemma2.py +++ b/src/transformers/models/gemma2/diff_gemma2.py @@ -34,12 +34,15 @@ from ...cache_utils import Cache from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast -from ...utils import is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, logging +from ...utils import is_flash_attn_2_available, is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -177,6 +180,101 @@ def forward( return attn_output, attn_weights, past_key_value +class Gemma2FlashAttention3(Gemma2Attention): + """ + Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (Gemma2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + ########### ONLY DIFFERENCE IS WE USE SLIDING AND PASS THE SOFTMAX SCALING + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + softmax_scale=self.scaling, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class Gemma2SdpaAttention(Gemma2Attention): """ Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -439,7 +537,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 1909ef78501559..5665ce23548f10 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -39,6 +39,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, @@ -51,6 +52,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -442,6 +446,110 @@ def forward( return attn_output, attn_weights, past_key_value +class Gemma2FlashAttention3(Gemma2Attention): + """ + Gemma2 flash attention module. This module inherits from `Gemma2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "sliding_window": self.sliding_window, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if attention_mask is not None: + seq_len = attention_mask.shape[1] + key_states = key_states[:, :, :seq_len] + value_states = value_states[:, :, :seq_len] + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (Gemma2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + softmax_scale=self.scaling, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class Gemma2SdpaAttention(Gemma2Attention): """ Gemma2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -535,6 +643,7 @@ def forward( GEMMA2_ATTENTION_CLASSES = { "eager": Gemma2Attention, "flash_attention_2": Gemma2FlashAttention2, + "flash_attention_3": Gemma2FlashAttention3, "sdpa": Gemma2SdpaAttention, } @@ -568,7 +677,10 @@ def forward( ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: if self.is_sliding and attention_mask is not None: # efficient SDPA and no padding # Flash-attn is a 2D tensor - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if past_key_value is not None: # when decoding attention_mask = attention_mask[:, -self.sliding_window :] else: @@ -641,6 +753,7 @@ class Gemma2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Gemma2DecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = False @@ -919,7 +1032,10 @@ def _update_causal_mask( # So we will pass in attention mask as is in any case, not only when ther's padding. Then we'll use its shape # to cut out keys/values trailing 0 used in static cache. This workaround should be compile compatible # as it doesn't cause dynamic control issues. 
- if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): return attention_mask dtype, device = input_tensor.dtype, input_tensor.device @@ -1130,6 +1246,7 @@ def prepare_inputs_for_generation( isinstance(past_key_values, HybridCache) and attention_mask.ndim == 2 and not self.config._attn_implementation == "flash_attention_2" + and not self.config._attn_implementation == "flash_attention_3" ): if model_inputs["inputs_embeds"] is not None: batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 8dfbfb9064444d..91e75a1d6f4ede 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -45,6 +45,7 @@ add_start_docstrings_to_model_forward, get_torch_version, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -56,6 +57,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -454,6 +457,109 @@ def forward( return outputs +class GPT2FlashAttention3(GPT2Attention): + """ + GPT2 flash attention module. This module inherits from `GPT2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: Optional[Tuple[torch.FloatTensor]], + layer_past: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: + bsz, _, _ = hidden_states.size() + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn"): + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 
+ ) + + query = self.q_attn(hidden_states) + key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2) + attention_mask = encoder_attention_mask + else: + query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = torch.cat((past_key, key), dim=-2) + value = torch.cat((past_value, value), dim=-2) + + present = None + if use_cache is True: + present = (key, value) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_length, + is_causal=self.is_causal, + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + class GPT2SdpaAttention(GPT2Attention): """ GPT2 attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from @@ -578,7 +684,12 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl return hidden_states -GPT2_ATTENTION_CLASSES = {"eager": GPT2Attention, "flash_attention_2": GPT2FlashAttention2, "sdpa": GPT2SdpaAttention} +GPT2_ATTENTION_CLASSES = { + "eager": GPT2Attention, + "flash_attention_2": GPT2FlashAttention2, + "flash_attention_3": GPT2FlashAttention3, + "sdpa": GPT2SdpaAttention, +} class GPT2Block(nn.Module): @@ -674,6 +785,7 @@ class GPT2PreTrainedModel(PreTrainedModel): _no_split_modules = ["GPT2Block"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__(self, *inputs, **kwargs): @@ -1031,7 +1143,7 @@ def forward( # Attention mask. _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None - if self._attn_implementation == "flash_attention_2": + if self._attn_implementation == "flash_attention_2" or self._attn_implementation == "flash_attention_3": attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif _use_sdpa: attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -1068,7 +1180,10 @@ def forward( encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( mask=encoder_attention_mask, dtype=inputs_embeds.dtype, tgt_len=input_shape[-1] ) - elif not self._attn_implementation == "flash_attention_2": + elif ( + not self._attn_implementation == "flash_attention_2" + and not self._attn_implementation == "flash_attention_3" + ): encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask) else: encoder_attention_mask = None diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index 0f927a72469dc9..8233a2a459e4c1 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -36,6 +36,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, ) @@ -45,6 +46,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -396,6 +400,122 @@ def forward( return outputs # a, present, (attentions) +class GPTBigCodeFlashAttention3(GPTBigCodeAttention): + """ + GPTBigCode flash attention module. This module inherits from `GPTBigCodeAttention` as the weights of the module + stays untouched. The only required change would be on the forward pass where it needs to correctly call the public + API of flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + layer_past: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + ) -> Union[ + Tuple[torch.Tensor, Optional[torch.Tensor]], + Tuple[torch.Tensor, Optional[torch.Tensor], Tuple[torch.Tensor, ...]], + ]: + if encoder_hidden_states is not None: + if not hasattr(self, "q_attn") or not self.is_cross_attention: + raise ValueError( + "If class is used as cross attention, the weights `q_attn` have to be defined. " + "Please make sure to instantiate class with `GPTBigCodeAttention(..., is_cross_attention=True)`." + ) + + query = self.q_attn(hidden_states) + key_value = self.c_attn(encoder_hidden_states) + attention_mask = encoder_attention_mask + elif self.multi_query: + query, key_value = self.c_attn(hidden_states).split((self.embed_dim, 2 * self.kv_dim), dim=2) + else: + # Note: We split as (self.num_heads, 3, self.head_dim) instead of (3, self.num_heads, self.head_dim), + # i.e., the memory layout is not the same as GPT2. + # This makes the concatenation with past_key_value more efficient. + query, key_value = ( + self.c_attn(hidden_states) + .view(*hidden_states.shape[:2], self.num_heads, 3 * self.head_dim) + .transpose(1, 2) + .split((self.head_dim, 2 * self.head_dim), dim=3) + ) + + if layer_past is not None: + key_value = torch.cat((layer_past, key_value), dim=-2) + present = key_value if use_cache else None + + key, value = key_value.split((self.head_dim, self.head_dim), dim=-1) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + if self.multi_query: + batch_size, query_length, _ = query.shape + query = query.reshape(batch_size, query_length, self.num_heads, self.head_dim) + key = key.unsqueeze(2) + value = value.unsqueeze(2) + else: + query_length = query.shape[2] + batch_size, _, tgt, _ = key.shape + query = query.transpose(1, 2).reshape(batch_size, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) + value = value.transpose(1, 2).reshape(batch_size, tgt, self.num_heads, self.head_dim) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.c_attn.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_length, + is_causal=self.is_causal, + ) + + attn_weights_reshaped = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) + attn_output = self.c_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + + if output_attentions: + if self.multi_query: + # Transpose to return weights in the usual format (batch_size, num_heads, query_length, key_length) + attn_weights_reshaped = attn_weights_reshaped.transpose(1, 2) + else: + attn_weights_reshaped = None + + outputs += (attn_weights_reshaped,) + + return outputs # a, present, (attentions) + + class GPTBigCodeSdpaAttention(GPTBigCodeAttention): def _attn(self, query, key, value, attention_mask=None, head_mask=None): if head_mask is not None: @@ -561,6 +681,7 @@ def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.Fl GPTBIGCODE_ATTENTION_CLASSES = { "eager": GPTBigCodeAttention, "flash_attention_2": GPTBigCodeFlashAttention2, + "flash_attention_3": GPTBigCodeFlashAttention3, "sdpa": GPTBigCodeSdpaAttention, } @@ -666,6 +787,7 @@ class GPTBigCodePreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTBigCodeBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__(self, *inputs, **kwargs): @@ -810,6 +932,7 @@ def __init__(self, config): self._use_sdpa = config._attn_implementation == "sdpa" self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" # Initialize weights and apply final processing self.post_init() @@ -891,7 +1014,7 @@ def forward( key_length = past_length + query_length self_attention_mask = self.bias[None, key_length - query_length : key_length, :key_length] - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask.bool() if (attention_mask is not None and 0 in attention_mask) else None encoder_attention_mask = ( diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 72590862b749f0..571d15ae4f3cd1 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -41,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torch_fx_available, logging, @@ -51,6 +52,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. @@ -426,9 +429,102 @@ def forward( return outputs +class GPTNeoFlashAttention3(GPTNeoSelfAttention): + """ + GPTNeo flash attention module. This module inherits from `GPTNeoSelfAttention` as the weights of the module stays + untouched. 
The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states, + attention_mask=None, + layer_past=None, + head_mask=None, + use_cache=False, + output_attentions=False, + cache_position=None, + ): + bsz, _, _ = hidden_states.size() + + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, self.num_heads, self.head_dim) + key = self._split_heads(key, self.num_heads, self.head_dim) + value = self._split_heads(value, self.num_heads, self.head_dim) + + if layer_past is not None: + cache_kwargs = {"cache_position": cache_position} + key, value = layer_past.update(key, value, self.layer_id, cache_kwargs) + + query_length = query.shape[2] + tgt_len = key.shape[2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query = query.transpose(1, 2).view(bsz, query_length, self.num_heads, self.head_dim) + key = key.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value = value.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + attention_mask = attention_mask[:, :, :, : key.shape[-2]] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + if query.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_length, + softmax_scale=1.0, + is_causal=self.is_causal, + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, layer_past) + if output_attentions: + outputs += (attn_weights_reshaped,) + + return outputs + + GPT_NEO_ATTENTION_CLASSES = { "eager": GPTNeoSelfAttention, "flash_attention_2": GPTNeoFlashAttention2, + "flash_attention_3": GPTNeoFlashAttention3, } @@ -550,6 +646,7 @@ class GPTNeoPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTNeoBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = False # TODO: needs a HybridCache @@ -847,7 +944,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index e88302efa7bb04..d935ac65a3a722 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -43,6 +43,7 @@ from ...utils import ( get_torch_version, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, ) @@ -52,6 +53,10 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "trl-internal-testing/tiny-random-GPTNeoXForCausalLM" @@ -125,6 +130,7 @@ class GPTNeoXPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTNeoXLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True @@ -459,6 +465,100 @@ def forward( return outputs +class GPTNeoXFlashAttention3(GPTNeoXAttention): + """ + GPTNeoX flash attention module. This module inherits from `GPTNeoXAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
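+
+     A minimal, hedged usage sketch: it only shows how this attention variant is selected at load time, and it
+     assumes Flash Attention 3 kernels and a supported GPU are available (the tiny checkpoint below, taken from
+     this file's doc checkpoint, is just a stand-in):
+
+     Example:
+
+         >>> import torch
+         >>> from transformers import AutoModelForCausalLM
+
+         >>> model = AutoModelForCausalLM.from_pretrained(
+         ...     "trl-internal-testing/tiny-random-GPTNeoXForCausalLM",  # placeholder checkpoint
+         ...     attn_implementation="flash_attention_3",
+         ...     torch_dtype=torch.float16,  # FA kernels expect fp16/bf16 inputs
+         ... )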
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.FloatTensor, + attention_mask: torch.FloatTensor, + position_ids: torch.LongTensor, + head_mask: Optional[torch.FloatTensor] = None, + layer_past: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + # Apply attention-specific projections and rope + query, key, value, present = self._attn_projections_and_rope( + hidden_states=hidden_states, + position_ids=position_ids, + layer_past=layer_past, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + ) + + query_length = query.shape[-2] + + # GPT-neo-X casts query and key in fp32 to apply rotary embedding in full precision + target_dtype = value.dtype + if query.dtype != target_dtype: + query = query.to(target_dtype) + if key.dtype != target_dtype: + key = key.to(target_dtype) + + # Permute to get the expected shape for Flash Attention + query = query.permute(0, 2, 1, 3) + key = key.permute(0, 2, 1, 3) + value = value.permute(0, 2, 1, 3) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 / bfloat16 just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + input_dtype = query.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.query_key_value.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + # Compute attention + attn_weights = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_length, + softmax_scale=self.norm_factor, + is_causal=self.is_causal, + ) + + # Reshape outputs + attn_output = attn_weights.reshape( + attn_weights.shape[0], attn_weights.shape[1], self.num_attention_heads * self.head_size + ) + attn_output = self.dense(attn_output) + + outputs = (attn_output, layer_past) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + class GPTNeoXSdpaAttention(GPTNeoXAttention): """ GPTNeoX attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from @@ -727,6 +827,7 @@ def forward(self, hidden_states): GPT_NEOX_ATTENTION_CLASSES = { "eager": GPTNeoXAttention, "flash_attention_2": GPTNeoXFlashAttention2, + "flash_attention_3": GPTNeoXFlashAttention3, "sdpa": GPTNeoXSdpaAttention, } @@ -1044,7 +1145,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index bf832195b4efc3..801a6f0d778141 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -748,7 +748,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index bd7ce5696fa077..7a4f69e41d9bba 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -38,6 +38,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torch_fx_proxy, logging, @@ -50,6 +51,10 @@ from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "hf-internal-testing/tiny-random-gptj" @@ -457,9 +462,145 @@ def forward( return outputs +class GPTJFlashAttention3(GPTJAttention): + """ + GPTJ flash attention module. This module inherits from `GPTJAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
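+
+     A hedged sketch of a caller-side fallback when the FA3 kernels are not installed. It assumes
+     `is_flash_attn_3_available` is importable from `transformers.utils`, mirroring its FA2 counterpart, and
+     uses this file's tiny doc checkpoint as a placeholder:
+
+     Example:
+
+         >>> from transformers import AutoModelForCausalLM
+         >>> from transformers.utils import is_flash_attn_3_available
+
+         >>> # fall back to the eager implementation when no FA3 build is present
+         >>> attn_impl = "flash_attention_3" if is_flash_attn_3_available() else "eager"
+         >>> model = AutoModelForCausalLM.from_pretrained(
+         ...     "hf-internal-testing/tiny-random-gptj", attn_implementation=attn_impl
+         ... )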
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.FloatTensor, + layer_past: Optional[Cache] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[ + Tuple[torch.Tensor, Tuple[torch.Tensor]], + Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], + ]: + query = self.q_proj(hidden_states) + key = self.k_proj(hidden_states) + value = self.v_proj(hidden_states) + + query = self._split_heads(query, self.num_attention_heads, self.head_dim, True) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, True) + value = self._split_heads(value, self.num_attention_heads, self.head_dim, False) + + if is_torch_fx_proxy(position_ids) or torch.jit.is_tracing(): + # The logic to conditionally copy to GPU could not be traced, so we do this + # every time in the torch.fx case + embed_positions = get_embed_positions(self.embed_positions, position_ids) + else: + embed_positions = self._get_embed_positions(position_ids) + + repeated_position_ids = position_ids.unsqueeze(-1).repeat(1, 1, embed_positions.shape[-1]) + sincos = torch.gather(embed_positions, 1, repeated_position_ids) + sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1) + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + k_rot = apply_rotary_pos_emb(k_rot, sin, cos) + q_rot = apply_rotary_pos_emb(q_rot, sin, cos) + + key = torch.cat([k_rot, k_pass], dim=-1) + query = torch.cat([q_rot, q_pass], dim=-1) + else: + key = apply_rotary_pos_emb(key, sin, cos) + query = apply_rotary_pos_emb(query, sin, cos) + + # tanspose to have the desired shape + # before transpose: batch_size x seq_length x num_attention_heads x head_dim + # after transpose: batch_size x num_attention_heads x seq_length x head_dim + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + # value: batch_size x num_attention_heads x seq_length x head_dim + + if layer_past is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_dim, + "cache_position": cache_position, + } + key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs) + + # The Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we need to keep the original shape for query and key, and reshape value + # to have the correct shape. + key = key.permute(0, 2, 1, 3).contiguous() + query = query.permute(0, 2, 1, 3).contiguous() + value = value.permute(0, 2, 1, 3).contiguous() + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query = query.to(target_dtype) + key = key.to(target_dtype) + value = value.to(target_dtype) + + query_length = query.shape[1] + + # Compute attention + attn_weights = _flash_attention_3_forward( + query, + key, + value, + attention_mask, + query_length, + is_causal=self.is_causal, + ) + + # Reshape outputs + attn_output = attn_weights.reshape( + attn_weights.shape[0], attn_weights.shape[1], attn_weights.shape[2] * attn_weights.shape[3] + ) + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, layer_past) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + GPTJ_ATTENTION_CLASSES = { "eager": GPTJAttention, "flash_attention_2": GPTJFlashAttention2, + "flash_attention_3": GPTJFlashAttention3, } @@ -540,6 +681,7 @@ class GPTJPreTrainedModel(PreTrainedModel): _no_split_modules = ["GPTJBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True @@ -720,6 +862,7 @@ def __init__(self, config): self.post_init() self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" @add_start_docstrings(PARALLELIZE_DOCSTRING) def parallelize(self, device_map=None): @@ -942,7 +1085,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 876f5ed2a7c8da..d8929470a912f6 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -23,6 +23,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -453,6 +454,102 @@ def forward( return attn_output, attn_weights, past_key_value +class GraniteFlashAttention3(GraniteAttention): + """ + Granite flash attention module. This module inherits from `GraniteAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
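+
+     For illustration only: selecting this class boils down to a dictionary lookup on
+     `config._attn_implementation`, using the `GRANITE_ATTENTION_CLASSES` mapping defined further down in this
+     module (names as in this file):
+
+     Example:
+
+         >>> attn_cls = GRANITE_ATTENTION_CLASSES["flash_attention_3"]  # mapping defined below in this module
+         >>> attn_cls.__name__
+         'GraniteFlashAttention3'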
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.45 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (GraniteRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + softmax_scale=self.scaling, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class GraniteSdpaAttention(GraniteAttention): """ Granite attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -547,6 +644,7 @@ def forward( GRANITE_ATTENTION_CLASSES = { "eager": GraniteAttention, "flash_attention_2": GraniteFlashAttention2, + "flash_attention_3": GraniteFlashAttention3, "sdpa": GraniteSdpaAttention, } @@ -662,6 +760,7 @@ class GranitePreTrainedModel(PreTrainedModel): _no_split_modules = ["GraniteDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -934,7 +1033,10 @@ def _update_causal_mask( # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index da79c2894877b4..70ec00e809635d 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -32,6 +32,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -42,6 +43,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -678,6 +681,125 @@ def forward( return attn_output, attn_weights, past_key_value +class HubertFlashAttention3(HubertAttention): + """ + Hubert flash attention module. This module inherits from `HubertAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
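+
+     As with the Flash Attention 2 path, `output_attentions=True` is rejected by this class. A minimal,
+     illustrative loading sketch (assumes FA3 kernels and a supported GPU; the checkpoint id is only a common
+     public example, not a requirement):
+
+     Example:
+
+         >>> import torch
+         >>> from transformers import HubertModel
+
+         >>> model = HubertModel.from_pretrained(
+         ...     "facebook/hubert-base-ls960",  # example checkpoint
+         ...     attn_implementation="flash_attention_3",
+         ...     torch_dtype=torch.float16,
+         ... )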
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # HubertFlashAttention3 attention does not support output_attentions + if output_attentions: + raise ValueError("HubertFlashAttention3 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class HubertSdpaAttention(HubertAttention): # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Hubert def forward( @@ -789,6 +911,7 @@ def forward( "eager": HubertAttention, "sdpa": HubertSdpaAttention, "flash_attention_2": HubertFlashAttention2, + "flash_attention_3": HubertFlashAttention3, } @@ -936,6 +1059,7 @@ def __init__(self, config): self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -952,7 +1076,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1024,6 +1148,7 @@ def __init__(self, config): ) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1040,7 +1165,7 @@ def forward( # make sure padded tokens are not attended to expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1112,6 +1237,7 @@ class HubertPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 0d26c4fc65de78..714dfbc29053ac 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ 
-1475,7 +1475,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 08ada424ea77b4..b005c81c112d1f 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -32,6 +32,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -44,6 +45,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -363,9 +366,98 @@ def forward( return attn_output, attn_weights +class Idefics2VisionFlashAttention3(Idefics2VisionAttention): + """ + Idefics2Vision flash attention module. This module inherits from `Idefics2VisionAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(Idefics2VisionRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + IDEFICS_VISION_ATTENTION_CLASSES = { "eager": Idefics2VisionAttention, "flash_attention_2": Idefics2VisionFlashAttention2, + "flash_attention_3": Idefics2VisionFlashAttention3, } @@ -582,6 +674,7 @@ def __init__(self, config: Idefics2VisionConfig): self.encoder = Idefics2Encoder(config) self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def get_input_embeddings(self): return self.embeddings @@ -623,7 +716,7 @@ def forward( # avoiding passing the attention_mask, which is equivalent to attending to the full sequence if not torch.any(~patch_attention_mask): patch_attention_mask = None - elif not self._use_flash_attention_2: + elif not self._use_flash_attention_2 and not self._use_flash_attention_3: patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) encoder_outputs = self.encoder( @@ -909,9 +1002,125 @@ def forward( return attn_output, attn_weights, past_key_value +class Idefics2PerceiverFlashAttention3(Idefics2PerceiverAttention): + """ + Idefics2 flash attention module. This module inherits from `Idefics2PerceiverAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + latents: torch.Tensor, + context: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = latents.size() + kv_seq_len = q_len + context.size()[1] + + # Query, Key, Value Projections --> Note that in Flamingo, latents are *concatenated* with context prior to attn! 
+ # Note: This results in queries w/ `seq = n_latents`, and keys, values with `seq = len(context) + n_latents` + query_states = self.q_proj(latents) + key_states = self.k_proj(torch.cat([context, latents], dim=-2)) + value_states = self.v_proj(torch.cat([context, latents], dim=-2)) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + if hasattr(self.config, "sliding_window") and kv_seq_len > self.config.sliding_window: + slicing_tokens = kv_seq_len - self.config.sliding_window + + past_key = past_key_value[0] + past_value = past_key_value[1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + "past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1," + f" head_dim`), got {past_key.shape}" + ) + + past_key_value = (past_key, past_value) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) if use_cache else None + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + IDEFICS2_PERCEIVER_ATTENTION_CLASSES = { "eager": Idefics2PerceiverAttention, "flash_attention_2": Idefics2PerceiverFlashAttention2, + "flash_attention_3": Idefics2PerceiverFlashAttention3, } @@ -1010,6 +1219,7 @@ def __init__(self, config) -> None: self.norm = Idefics2RMSNorm(self.hidden_size, eps=self.rms_norm_eps) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1025,7 +1235,7 @@ def forward( attention_mask = torch.cat([attention_mask, latent_attention_mask], dim=-1) attention_mask = ( _prepare_4d_attention_mask(attention_mask, latents.dtype, tgt_len=self.n_latents) - if not self._use_flash_attention_2 + if not self._use_flash_attention_2 and not self._use_flash_attention_3 else attention_mask ) @@ -1093,6 +1303,7 @@ class Idefics2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Idefics2VisionAttention", "Idefics2MLP", "Idefics2PerceiverLayer", "Idefics2DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True def _init_weights(self, module): @@ -1227,6 +1438,7 @@ def __init__(self, config: Idefics2Config): self.image_token_id = self.config.image_token_id self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self.post_init() diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 60e1670a3c2784..0725d4f058570a 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -48,6 +48,7 @@ from ...utils.import_utils import ( is_causal_conv1d_available, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_mamba_ssm_available, is_torchdynamo_compiling, @@ -58,6 +59,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + if is_mamba_ssm_available(): from mamba_ssm.ops.selective_scan_interface import mamba_inner_fn, selective_scan_fn @@ -498,6 +502,119 @@ def forward( return attn_output, attn_weights, past_key_value +class JambaFlashAttention3(JambaAttention): + """ + Jamba flash attention module. This module inherits from `JambaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
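+
+     A hedged end-to-end sketch (assumes FA3 kernels, a GPU with enough memory for the checkpoint, and access
+     to the model id below, which is only an example):
+
+     Example:
+
+         >>> import torch
+         >>> from transformers import AutoModelForCausalLM, AutoTokenizer
+
+         >>> tokenizer = AutoTokenizer.from_pretrained("ai21labs/Jamba-v0.1")  # example checkpoint
+         >>> model = AutoModelForCausalLM.from_pretrained(
+         ...     "ai21labs/Jamba-v0.1",
+         ...     attn_implementation="flash_attention_3",
+         ...     torch_dtype=torch.bfloat16,
+         ...     device_map="auto",
+         ... )
+         >>> inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
+         >>> outputs = model.generate(**inputs, max_new_tokens=8)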
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = cache_position[-1] + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = cache_position[0] > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba class JambaSdpaAttention(JambaAttention): """ @@ -584,6 +701,7 @@ def forward( JAMBA_ATTENTION_CLASSES = { "eager": JambaAttention, "flash_attention_2": JambaFlashAttention2, + "flash_attention_3": JambaFlashAttention3, "sdpa": JambaSdpaAttention, } @@ -1121,6 +1239,7 @@ class JambaPreTrainedModel(PreTrainedModel): _no_split_modules = ["JambaAttentionDecoderLayer", "JambaMambaDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True # Note: only supports HybridMambaAttentionDynamicCache _is_stateful = True @@ -1377,7 +1496,10 @@ def forward( ) def _update_causal_mask(self, attention_mask, input_tensor, cache_position): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index c273b021d73664..ab2b46e8e6cb93 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -36,6 +36,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -47,6 +48,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "jetmoe" @@ -804,9 +808,116 @@ def forward( return attn_output, attn_weights, past_key_value, router_logits +class JetMoeFlashAttention3(JetMoeAttention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: Optional[torch.FloatTensor], + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + use_cache: Optional[bool] = False, + output_attentions: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Union[ + Tuple[torch.Tensor, Tuple[torch.Tensor]], + Optional[Tuple[torch.Tensor, Tuple[torch.Tensor], Tuple[torch.Tensor, ...]]], + ]: + """ + Forward pass of the JetMoeAttention module. + + Args: + hidden_states (Optional[torch.FloatTensor]): Input hidden states. + attention_mask (Optional[torch.FloatTensor]): Attention mask. + layer_past (Optional[Tuple[torch.Tensor]]): Past layer state. 
+ use_cache (Optional[bool]): Whether to use cached states. + output_attentions (Optional[bool]): Whether to output attention weights. + cache_position (Optional[torch.LongTensor]): Position of the cache. + + Returns: + Union[Tuple[torch.Tensor, Tuple[torch.Tensor]], Optional[Tuple[...]]]: Tuple containing outputs. + """ + output_attentions = False + bsz, q_len, hidden_size = hidden_states.size() + + # calculate query, key, values + query_states, router_logits, topo_info = self.experts.map(hidden_states) + key_states, value_states = self.kv_proj(hidden_states).chunk(2, dim=-1) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads for top-k attention experts + key_states = key_states.repeat(1, self.top_k, 1, 1) + value_states = value_states.repeat(1, self.top_k, 1, 1) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.kv_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ).to(input_dtype) + + # output projection + attn_output = attn_output.reshape(bsz, q_len, self.top_k, self.kv_projection_size) + attn_output = self.experts.reduce(attn_output, topo_info) + attn_output = attn_output.view(bsz, q_len, hidden_size) # re-assemble all head outputs side by side + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value, router_logits + + JETMOE_ATTENTION_CLASSES = { "eager": JetMoeAttention, "flash_attention_2": JetMoeFlashAttention2, + "flash_attention_3": JetMoeFlashAttention3, "sdpa": JetMoeSdpaAttention, } @@ -879,6 +990,7 @@ class JetMoePreTrainedModel(PreTrainedModel): _no_split_modules = ["JetMoeBlock"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True @@ -1048,7 +1160,11 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + if ( + attention_mask is not None + and (self._attn_implementation == "flash_attention_2" or self._attn_implementation == "flash_attention_3") + and use_cache + ): batch_size = inputs_embeds.shape[0] is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: @@ -1136,7 +1252,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index c7017832b9324c..24b02ec4ea7697 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -29,6 +29,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -569,6 +570,117 @@ def forward( return attn_output, attn_weights, past_key_value +class LlamaFlashAttention3(LlamaAttention): + """ + Llama flash attention module. This module inherits from `LlamaAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
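+
+     As the forward pass below shows, this path is not compatible with the `static` cache implementation, so
+     generation should stick to the default dynamic cache. A minimal sketch (illustrative only; the checkpoint
+     id is a placeholder for any Llama weights you can access):
+
+     Example:
+
+         >>> import torch
+         >>> from transformers import AutoModelForCausalLM
+
+         >>> model = AutoModelForCausalLM.from_pretrained(
+         ...     "meta-llama/Meta-Llama-3-8B",  # placeholder checkpoint
+         ...     attn_implementation="flash_attention_3",
+         ...     torch_dtype=torch.bfloat16,
+         ...     device_map="auto",
+         ... )
+         >>> # avoid cache_implementation="static" in `generate` with this attention implementation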
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # TODO: get `use_fp8` to here, add attention_kwargs or something + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class LlamaSdpaAttention(LlamaAttention): """ Llama attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -671,6 +783,7 @@ def forward( LLAMA_ATTENTION_CLASSES = { "eager": LlamaAttention, "flash_attention_2": LlamaFlashAttention2, + "flash_attention_3": LlamaFlashAttention3, "sdpa": LlamaSdpaAttention, } @@ -783,6 +896,7 @@ class LlamaPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlamaDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1039,7 +1153,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/llava/modeling_llava.py b/src/transformers/models/llava/modeling_llava.py index eb1c55341b0784..99682d7818fde1 100644 --- a/src/transformers/models/llava/modeling_llava.py +++ b/src/transformers/models/llava/modeling_llava.py @@ -125,6 +125,7 @@ class LlavaPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True def _init_weights(self, module): diff --git a/src/transformers/models/llava_next/modeling_llava_next.py b/src/transformers/models/llava_next/modeling_llava_next.py index c1d1ca8c276d7a..6d6bd7ae356ac3 100644 --- a/src/transformers/models/llava_next/modeling_llava_next.py +++ b/src/transformers/models/llava_next/modeling_llava_next.py @@ -228,6 +228,7 @@ class LlavaNextPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlavaNextVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True def _init_weights(self, module): diff --git a/src/transformers/models/llava_next_video/modeling_llava_next_video.py 
b/src/transformers/models/llava_next_video/modeling_llava_next_video.py index 589bf346ceeb9e..9c3f296b38b0ec 100644 --- a/src/transformers/models/llava_next_video/modeling_llava_next_video.py +++ b/src/transformers/models/llava_next_video/modeling_llava_next_video.py @@ -272,6 +272,7 @@ class LlavaNextVideoPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlavaNextVideoVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True def _init_weights(self, module): diff --git a/src/transformers/models/llava_onevision/modeling_llava_onevision.py b/src/transformers/models/llava_onevision/modeling_llava_onevision.py index 697ea84fea5040..7c7ad78c86dbc6 100644 --- a/src/transformers/models/llava_onevision/modeling_llava_onevision.py +++ b/src/transformers/models/llava_onevision/modeling_llava_onevision.py @@ -233,6 +233,7 @@ class LlavaOnevisionPreTrainedModel(PreTrainedModel): _no_split_modules = ["LlavaOnevisionVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_static_cache = False # Qwen2 doesn't but llava has no reasons to not support _supports_quantized_cache = True diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 23a855fff25672..057edcd61bcd3d 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -37,6 +37,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -47,6 +48,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -427,6 +430,92 @@ def forward( return attn_output, None, past_key_value +class M2M100FlashAttention3(M2M100Attention): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = 
past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + softmax_scale=None, + is_causal=self.is_causal, + ) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. 
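+ # Merge the attention heads back into the embedding dimension before the output projection.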
+ attn_output = attn_output.reshape(bsz, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, None, past_key_value + + # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100 class M2M100EncoderLayer(nn.Module): def __init__(self, config: M2M100Config): @@ -501,6 +590,7 @@ def forward( M2M100_ATTENTION_CLASSES = { "eager": M2M100Attention, "flash_attention_2": M2M100FlashAttention2, + "flash_attention_3": M2M100FlashAttention3, } @@ -631,6 +721,7 @@ class M2M100PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module): std = self.config.init_std @@ -804,6 +895,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = self.layers = nn.ModuleList([M2M100EncoderLayer(config) for _ in range(config.encoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -884,7 +976,7 @@ def forward( # expand attention_mask if attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -980,6 +1072,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = ) self.layers = nn.ModuleList([M2M100DecoderLayer(config) for _ in range(config.decoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1090,7 +1183,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers combined_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1101,7 +1194,7 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -1232,7 +1325,7 @@ def __init__(self, config: M2M100Config): self.encoder = M2M100Encoder(config, self.shared) self.decoder = M2M100Decoder(config, self.shared) - if config._attn_implementation == "flash_attention_2": + if config._attn_implementation == "flash_attention_2" or config._attn_implementation == "flash_attention_3": logger.warning_once( "Attention with Flash Attention 2 does not support `layer_head_mask`. If you need this feature, please use standard attention." 
) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 9455f21b2073ff..7597ea1ad693f8 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -46,6 +46,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -56,6 +57,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -410,6 +413,125 @@ def forward( return attn_output, attn_weights, past_key_value +class MBartFlashAttention3(MBartAttention): + """ + MBart flash attention module. This module inherits from `MBartAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # MBartFlashAttention3 attention does not support output_attentions + if output_attentions: + raise ValueError("MBartFlashAttention3 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of 
all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MBart class MBartSdpaAttention(MBartAttention): def forward( @@ -521,6 +643,7 @@ def forward( "eager": MBartAttention, "sdpa": MBartSdpaAttention, "flash_attention_2": MBartFlashAttention2, + "flash_attention_3": MBartFlashAttention3, } @@ -745,6 +868,7 @@ class MBartPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MBartDecoderLayer", "MBartAttention"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): @@ -1043,7 +1167,10 @@ def forward( # expand attention_mask if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): attention_mask = attention_mask if 0 in attention_mask else None elif self.config._attn_implementation == "sdpa" and head_mask is None and not output_attentions: # output_attentions=True & head_mask can not be supported when using SDPA, fall back to @@ -1260,7 +1387,10 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self.config._attn_implementation == "flash_attention_2": + if ( + 
self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self.config._attn_implementation == "sdpa" and not output_attentions and cross_attn_head_mask is None: @@ -1280,7 +1410,10 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self.config._attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index ffe16b27203301..c14438329827f5 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -41,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -52,6 +53,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "MistralConfig" @@ -402,6 +406,131 @@ def forward( return attn_output, attn_weights, past_key_value +class MistralFlashAttention3(MistralAttention): + """ + Mistral flash attention module. This module inherits from `MistralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
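+ It also applies Mistral's sliding-window handling: when the cache grows past `config.sliding_window`, the cached key/value states are truncated to the last `sliding_window - 1` positions before the new states are appended.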
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ): + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += cache_position[0] + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
+ input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Mistral # TODO(joao): add me back asap :) class MistralSdpaAttention(MistralAttention): @@ -495,6 +624,7 @@ def forward( MISTRAL_ATTENTION_CLASSES = { "eager": MistralAttention, "flash_attention_2": MistralFlashAttention2, + "flash_attention_3": MistralFlashAttention3, "sdpa": MistralSdpaAttention, } @@ -604,6 +734,7 @@ class MistralPreTrainedModel(PreTrainedModel): _no_split_modules = ["MistralDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True @@ -853,7 +984,7 @@ def _update_causal_mask( use_cache: bool, output_attentions: bool, ): - if self._attn_implementation == "flash_attention_2": + if self._attn_implementation == "flash_attention_2" or self._attn_implementation == "flash_attention_3": if attention_mask is not None and use_cache: is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0] if is_padding_right: diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index c7062e75b1085c..167054fbe61bd6 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -43,6 +43,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_torchdynamo_compiling, logging, replace_return_docstrings, @@ -54,6 +55,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. 
if is_torch_fx_available(): @@ -561,6 +565,133 @@ def forward( return attn_output, attn_weights, past_key_value +# TODO @longjie no longer copied from Mistral after static cache +class MixtralFlashAttention3(MixtralAttention): + """ + Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. 
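+ # With left padding, the largest position id can exceed kv_seq_len, so the rotary table is sized to whichever is larger.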
+ rotary_seq_len = ( + max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len + ) + + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral # TODO @longjie no longer copied from Mistral after static cache class MixtralSdpaAttention(MixtralAttention): @@ -656,6 +787,7 @@ def forward( MIXTRAL_ATTENTION_CLASSES = { "eager": MixtralAttention, "flash_attention_2": MixtralFlashAttention2, + "flash_attention_3": MixtralFlashAttention3, "sdpa": MixtralSdpaAttention, } @@ -857,6 +989,7 @@ class MixtralPreTrainedModel(PreTrainedModel): _no_split_modules = ["MixtralDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True @@ -1122,7 +1255,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index f720faac038e51..26cf599a627033 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -47,6 +47,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -59,6 +60,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + if TYPE_CHECKING: from ...generation.streamers import BaseStreamer @@ -439,6 +443,125 @@ def forward( return attn_output, attn_weights, past_key_value +class MusicgenFlashAttention3(MusicgenAttention): + """ + Musicgen flash attention module. This module inherits from `MusicgenAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
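+ Supports both self-attention and cross-attention: when `key_value_states` are provided, keys and values are projected from (or reused for) the encoder states, mirroring `MusicgenFlashAttention2`.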
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # MusicgenFlashAttention3 attention does not support output_attentions + if output_attentions: + raise ValueError("MusicgenFlashAttention3 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class MusicgenSdpaAttention(MusicgenAttention): def forward( self, @@ -566,6 +689,7 @@ def forward( "eager": MusicgenAttention, "sdpa": MusicgenSdpaAttention, "flash_attention_2": MusicgenFlashAttention2, + "flash_attention_3": MusicgenFlashAttention3, } @@ -703,6 +827,7 @@ class MusicgenPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): @@ -998,7 +1123,7 @@ def forward( if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) - if self.attn_implementation == "flash_attention_2": + if self.attn_implementation == "flash_attention_2" or self.attn_implementation == "flash_attention_3": attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on @@ -1016,7 +1141,7 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.attn_implementation == "flash_attention_2": + if self.attn_implementation == "flash_attention_2" or self.attn_implementation == "flash_attention_3": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on @@ -1664,6 +1789,7 @@ class MusicgenForConditionalGeneration(PreTrainedModel): main_input_name = "input_ids" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def __init__( diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index a8a8fe96098952..5d83a44adf99b0 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -39,6 +39,7 @@ add_start_docstrings, 
add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -51,6 +52,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + if TYPE_CHECKING: from ...generation.streamers import BaseStreamer @@ -455,6 +459,125 @@ def forward( return attn_output, attn_weights, past_key_value +class MusicgenMelodyFlashAttention3(MusicgenMelodyAttention): + """ + MusicgenMelody flash attention module. This module inherits from `MusicgenMelodyAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # MusicgenMelodyFlashAttention3 attention does not support output_attentions + if output_attentions: + raise ValueError("MusicgenMelodyFlashAttention3 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
+ # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MusicgenMelody class MusicgenMelodySdpaAttention(MusicgenMelodyAttention): def forward( @@ -566,6 +689,7 @@ def forward( "eager": MusicgenMelodyAttention, "sdpa": MusicgenMelodySdpaAttention, "flash_attention_2": MusicgenMelodyFlashAttention2, + "flash_attention_3": MusicgenMelodyFlashAttention3, } @@ -662,6 +786,7 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): @@ -945,7 +1070,7 @@ def forward( input_shape = inputs_embeds.size()[:-1] - if self.attn_implementation == "flash_attention_2": + if self.attn_implementation == "flash_attention_2" or self.attn_implementation == "flash_attention_3": attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self.attn_implementation == "sdpa" and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on @@ -1590,6 +1715,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel): main_input_name = "input_ids" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + 
_supports_flash_attn_3 = True _supports_sdpa = True def __init__( diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 4d079b4dde104d..63afdf58cc68c3 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -27,6 +27,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter +from ...modeling_flash_attention_3_utils import _flash_attention_3_forward from ...modeling_flash_attention_utils import _flash_attention_forward from ...modeling_outputs import ( BaseModelOutputWithPast, @@ -468,6 +469,108 @@ def forward( return attn_output, attn_weights, past_key_value +class NemotronFlashAttention3(NemotronAttention): + """ + Nemotron flash attention module. This module inherits from `NemotronAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: Tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_3` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is not None: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. 
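+ # After the rotary embedding and cache update the states are laid out as [batch_size, num_heads, seq_len, head_dim], so transpose back before calling the FA3 kernel.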
+ query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (NemotronRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with LLAMA->NEMOTRON,Llama->Nemotron,llama->nemotron class NemotronSdpaAttention(NemotronAttention): """ @@ -563,6 +666,7 @@ def forward( NEMOTRON_ATTENTION_CLASSES = { "eager": NemotronAttention, "flash_attention_2": NemotronFlashAttention2, + "flash_attention_3": NemotronFlashAttention3, "sdpa": NemotronSdpaAttention, } @@ -677,6 +781,7 @@ class NemotronPreTrainedModel(PreTrainedModel): _no_split_modules = ["NemotronDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -920,7 +1025,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index b4bda8e2db5251..4f1ebb3cf690c1 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -41,6 +41,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -52,6 +53,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -498,6 +502,106 @@ def forward( return 
attn_output, attn_weights, past_key_value +class OlmoFlashAttention3(OlmoAttention): + """ + OLMo flash attention module. This module inherits from `OlmoAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(OlmoRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class OlmoSdpaAttention(OlmoAttention): """ OLMo attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -594,6 +698,7 @@ def forward( OLMO_ATTENTION_CLASSES = { "eager": OlmoAttention, "flash_attention_2": OlmoFlashAttention2, + "flash_attention_3": OlmoFlashAttention3, "sdpa": OlmoSdpaAttention, } @@ -704,6 +809,7 @@ class OlmoPreTrainedModel(PreTrainedModel): _no_split_modules = ["OlmoDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -959,7 +1065,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index a33338365312db..8efab00cc86cbf 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -34,6 +34,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -44,6 +45,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -572,6 +576,105 @@ def forward( return attn_output, attn_weights, past_key_value +class OlmoeFlashAttention3(OlmoeAttention): + """ + OLMoE flash attention module. This module inherits from `OlmoeAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_norm(self.q_proj(hidden_states)) + key_states = self.k_norm(self.k_proj(hidden_states)) + value_states = self.v_proj(hidden_states) + if self.config.clip_qkv is not None: + query_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + key_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + value_states.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (OlmoeRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class OlmoeSdpaAttention(OlmoeAttention): """ OLMoE attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -670,6 +773,7 @@ def forward( OLMOE_ATTENTION_CLASSES = { "eager": OlmoeAttention, "flash_attention_2": OlmoeFlashAttention2, + "flash_attention_3": OlmoeFlashAttention3, "sdpa": OlmoeSdpaAttention, } @@ -838,6 +942,7 @@ class OlmoePreTrainedModel(PreTrainedModel): _no_split_modules = ["OlmoeDecoderLayer"] _skip_keys_device_placement = ["past_key_values"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1109,7 +1214,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 8f058171778efc..26a1499a7699e1 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -35,6 +35,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -45,6 +46,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -359,9 +363,119 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value +class OptFlashAttention3(OPTAttention): + """ + OPT flash attention module. This module inherits from `OPTAttention` as the weights of the module stays untouched. + The only required change would be on the forward pass where it needs to correctly call the public API of flash + attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, _, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) + # get key, value proj + if is_cross_attention and past_key_value is not None: + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + query_length = query_states.shape[1] + tgt_len = key_states.shape[-2] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + query_states = query_states.view(bsz, query_length, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = value_states.transpose(1, 2).view(bsz, tgt_len, self.num_heads, self.head_dim) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. 
We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + query_length, + is_causal=self.is_causal, + ) + + attn_weights_reshaped = attn_output.reshape(bsz, query_length, self.num_heads * self.head_dim) + attn_output = self.out_proj(attn_weights_reshaped) + + if not output_attentions: + attn_weights_reshaped = None + + return attn_output, attn_weights_reshaped, past_key_value + + OPT_ATTENTION_CLASSES = { "eager": OPTAttention, "flash_attention_2": OptFlashAttention2, + "flash_attention_3": OptFlashAttention3, } @@ -488,6 +602,7 @@ class OPTPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = ["OPTDecoderLayer"] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module): std = self.config.init_std @@ -604,6 +719,7 @@ def __init__(self, config: OPTConfig): self.layers = nn.ModuleList([OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -702,7 +818,7 @@ def forward( mask_seq_length = past_key_values_length + seq_length # embed positions - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers causal_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None attention_mask = ( diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index 39ee57d70bb5d8..b335e4c211f40a 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -124,6 +124,7 @@ class PaliGemmaPreTrainedModel(PreTrainedModel): _no_split_modules = ["PaliGemmaMultiModalProjector"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = False + _supports_flash_attn_3 = False _supports_cache_class = True _supports_quantized_cache = True _supports_static_cache = True diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index ccaa2c7fd29aae..02989bc9bc5471 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -783,7 +783,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 648d1653a3b503..93450596f97049 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -41,6 +41,7 @@ add_start_docstrings_to_model_forward, get_torch_version, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ 
-52,6 +53,8 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward logger = logging.get_logger(__name__) @@ -575,6 +578,136 @@ def forward( return attn_output, attn_weights, past_key_value +class PhiFlashAttention3(PhiAttention): + """ + Phi flash attention module. This module inherits from `PhiAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # PhiFlashAttention3 attention does not support output_attentions + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. + + if query_states.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + softmax_scale=None, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.dense(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class PhiSdpaAttention(PhiAttention): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -704,6 +837,7 @@ def forward( PHI_ATTENTION_CLASSES = { "eager": PhiAttention, "flash_attention_2": PhiFlashAttention2, + "flash_attention_3": PhiFlashAttention3, "sdpa": PhiSdpaAttention, } @@ -812,6 +946,7 @@ class PhiPreTrainedModel(PreTrainedModel): _no_split_modules = ["PhiDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True @@ -930,6 +1065,7 @@ def __init__(self, config: PhiConfig): self.rotary_emb = PhiRotaryEmbedding(config=config) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.gradient_checkpointing = False @@ -1074,7 +1210,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index ec395679ae6207..350070ed31e613 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -39,6 +39,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -50,6 +51,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct" @@ -632,6 +636,146 @@ def forward( return attn_output, attn_weights, past_key_value +class Phi3FlashAttention3(Phi3Attention): + """ + Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # Phi3FlashAttention3 attention does not support output_attentions + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. 
+ rotary_seq_len = (
+ max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len
+ )
+
+ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids)
+
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a `sliding_window` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
+ # therefore the input hidden states get silently cast in float32. Hence, we need
+ # to cast them back in the correct dtype just to be sure everything works as expected.
+ # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+ # in fp32.
+
+ if query_states.dtype == torch.float32:
+ if torch.is_autocast_enabled():
+ target_dtype = torch.get_autocast_gpu_dtype()
+ # Handle the case where the model is quantized
+ elif hasattr(self.config, "_pre_quantization_dtype"):
+ target_dtype = self.config._pre_quantization_dtype
+ else:
+ target_dtype = self.qkv_proj.weight.dtype
+
+ logger.warning_once(
+ f"The input hidden states seem to be silently cast in float32, this might be related to"
+ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
+ f" {target_dtype}."
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3 # TODO @Arthur no longer copied from LLama after static cache class Phi3SdpaAttention(Phi3Attention): @@ -729,6 +873,7 @@ def forward( PHI3_ATTENTION_CLASSES = { "eager": Phi3Attention, "flash_attention_2": Phi3FlashAttention2, + "flash_attention_3": Phi3FlashAttention3, "sdpa": Phi3SdpaAttention, } @@ -842,6 +987,7 @@ class Phi3PreTrainedModel(PreTrainedModel): _no_split_modules = ["Phi3DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True @@ -1095,7 +1241,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 93d91e160089e1..03a24a4676261f 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -685,6 +685,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = ) self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) @@ -773,7 +774,7 @@ def forward( # expand attention_mask if attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: attention_mask = attention_mask if 0 in attention_mask else None elif self._use_sdpa and head_mask is None and not output_attentions: # output_attentions=True & head_mask can not be supported when using SDPA, fall back to @@ -870,6 +871,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = ) self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -990,7 +992,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) - if self._use_flash_attention_2: + if 
self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: @@ -1010,7 +1012,7 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index d0ea8ef0e376e0..97dd3f12a9f6b0 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -42,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -53,6 +54,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -539,6 +543,132 @@ def forward( return attn_output, attn_weights, past_key_value +class Qwen2FlashAttention3(Qwen2Attention): + """ + Qwen2 flash attention module, following Qwen2 attention module. This module inherits from `Qwen2Attention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + kv_seq_len = key_states.shape[-2] + cache_position[0] + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class Qwen2SdpaAttention(Qwen2Attention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from @@ -638,6 +768,7 @@ def forward( QWEN2_ATTENTION_CLASSES = { "eager": Qwen2Attention, "flash_attention_2": Qwen2FlashAttention2, + "flash_attention_3": Qwen2FlashAttention3, "sdpa": Qwen2SdpaAttention, } @@ -647,7 +778,11 @@ def __init__(self, config: Qwen2Config, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - if config.sliding_window and config._attn_implementation != "flash_attention_2": + if ( + config.sliding_window + and config._attn_implementation != "flash_attention_2" + and config._attn_implementation != "flash_attention_3" + ): logger.warning_once( f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." @@ -754,6 +889,7 @@ class Qwen2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -1013,7 +1149,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 14235bf0aaf64a..7e3e4062a7c7cd 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -30,6 +30,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -41,6 +42,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -338,6 +342,121 @@ def forward( return attn_output, attn_weights, past_key_value +class Qwen2AudioFlashAttention3(Qwen2AudioAttention): + """ + Qwen2Audio flash attention module. This module inherits from `Qwen2AudioAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[EncoderDecoderCache] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_3'`. 
" + "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers" + ) + # Qwen2AudioFlashAttention3 attention does not support output_attentions + if output_attentions: + raise ValueError("Qwen2AudioFlashAttention3 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim)) + + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value.is_updated[self.layer_idx] = True + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + + # use key_value_states if cross attention + current_states = key_value_states if key_value_states is not None else hidden_states + if is_cross_attention and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self._shape(self.k_proj(current_states), -1, bsz) + value_states = self._shape(self.v_proj(current_states), -1, bsz) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim] + # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view. + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + causal_mask = attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + causal_mask, + tgt_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, tgt_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.whisper.modeling_whisper.WhisperSdpaAttention with Whisper->Qwen2Audio class Qwen2AudioSdpaAttention(Qwen2AudioAttention): def forward( @@ -440,6 +559,7 @@ def forward( QWEN2AUDIO_ATTENTION_CLASSES = { "eager": Qwen2AudioAttention, "flash_attention_2": Qwen2AudioFlashAttention2, + "flash_attention_3": Qwen2AudioFlashAttention3, "sdpa": Qwen2AudioSdpaAttention, } @@ -543,6 +663,7 @@ class Qwen2AudioPreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2AudioAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True def _init_weights(self, module): # important: this ported version of Qwen2Audio isn't meant for training from scratch - only diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 6f483e50cde065..de9aa69bc620b3 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -43,6 +43,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -54,6 +55,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) _CHECKPOINT_FOR_DOC = "Qwen/Qwen1.5-MoE-A2.7B" @@ -621,6 +625,132 @@ def forward( return attn_output, attn_weights, past_key_value +class Qwen2MoeFlashAttention3(Qwen2MoeAttention): + """ + Qwen2Moe flash attention module, following Qwen2Moe attention module. This module inherits from `Qwen2MoeAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + kv_seq_len = key_states.shape[-2] + cache_position[0] + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
+ input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.qwen2.modeling_qwen2.Qwen2SdpaAttention with Qwen2->Qwen2Moe class Qwen2MoeSdpaAttention(Qwen2MoeAttention): """ @@ -721,6 +851,7 @@ def forward( QWEN2MOE_ATTENTION_CLASSES = { "eager": Qwen2MoeAttention, "flash_attention_2": Qwen2MoeFlashAttention2, + "flash_attention_3": Qwen2MoeFlashAttention3, "sdpa": Qwen2MoeSdpaAttention, } @@ -913,6 +1044,7 @@ class Qwen2MoePreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2MoeDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True @@ -1188,7 +1320,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 4e4e04198c0bf1..ba71a944ea2492 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -44,6 +44,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -58,6 +59,12 @@ else: flash_attn_varlen_func = None +if is_flash_attn_3_available(): + from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func + + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward +else: + flash_attn_3_varlen_func = None logger = logging.get_logger(__name__) @@ -379,6 +386,29 @@ def forward( return attn_output +class VisionFlashAttention3(nn.Module): + def __init__(self, dim: int, num_heads: int = 16) -> None: + super().__init__() + self.num_heads = num_heads + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.proj = nn.Linear(dim, dim) + + def forward( + self, hidden_states: torch.Tensor, 
cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor = None + ) -> torch.Tensor: + seq_length = hidden_states.shape[0] + q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) + q = apply_rotary_pos_emb_vision(q.unsqueeze(0), rotary_pos_emb).squeeze(0) + k = apply_rotary_pos_emb_vision(k.unsqueeze(0), rotary_pos_emb).squeeze(0) + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + attn_output = flash_attn_3_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( + seq_length, -1 + ) + attn_output = self.proj(attn_output) + return attn_output + + class VisionSdpaAttention(nn.Module): def __init__(self, dim: int, num_heads: int = 16) -> None: super().__init__() @@ -410,6 +440,7 @@ def forward( QWEN2_VL_VISION_ATTENTION_CLASSES = { "eager": VisionAttention, "flash_attention_2": VisionFlashAttention2, + "flash_attention_3": VisionFlashAttention3, "sdpa": VisionSdpaAttention, } @@ -812,6 +843,144 @@ def forward( return attn_output, attn_weights, past_key_value +class Qwen2VLFlashAttention3(Qwen2VLAttention): + """ + Qwen2VL flash attention module, following Qwen2VL attention module. This module inherits from `Qwen2VLAttention` + as the weights of the module stays untouched. The only required change would be on the forward pass + where it needs to correctly call the public API of flash attention and deal with padding tokens + in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom + config.max_window_layers layers. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." 
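# --- Not part of the patch: an illustrative sketch of how `cu_seqlens` describes the packed
# --- (unpadded) patch sequences consumed by the varlen kernel in `VisionFlashAttention3` above.
# --- The lengths below are made up; in the model they come from the image grid sizes.
import torch
import torch.nn.functional as F

seqlens = torch.tensor([4, 6, 2], dtype=torch.int32)                         # tokens per image
cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
# cu_seqlens == tensor([0, 4, 10, 12]); sequence i occupies rows cu_seqlens[i]:cu_seqlens[i+1]
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()                 # == 6, as in the forward above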
+ ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + query_states, key_states = apply_multimodal_rotary_pos_emb( + query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class Qwen2VLSdpaAttention(Qwen2VLAttention): """ Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from @@ -917,6 +1086,7 @@ def forward( QWEN2_VL_ATTENTION_CLASSES = { "eager": Qwen2VLAttention, "flash_attention_2": Qwen2VLFlashAttention2, + "flash_attention_3": Qwen2VLFlashAttention3, "sdpa": Qwen2VLSdpaAttention, } @@ -926,7 +1096,11 @@ def __init__(self, config: Qwen2VLConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size - if config.use_sliding_window and config._attn_implementation != "flash_attention_2": + if ( + config.use_sliding_window + and config._attn_implementation != "flash_attention_2" + and config._attn_implementation != "flash_attention_3" + ): logger.warning_once( f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; " "unexpected results may be encountered." @@ -1033,6 +1207,7 @@ class Qwen2VLPreTrainedModel(PreTrainedModel): _no_split_modules = ["Qwen2VLDecoderLayer", "Qwen2VLVisionBlock"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_static_cache = True @@ -1274,7 +1449,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py index a8f076fad79c76..313471a599651f 100644 --- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py @@ -540,6 +540,7 @@ class RecurrentGemmaPreTrainedModel(PreTrainedModel): _no_split_modules = ["RecurrentGemmaDecoderLayer"] _skip_keys_device_placement = ["cache"] _supports_flash_attn_2 = False + _supports_flash_attn_3 = False _supports_sdpa = False # we can't compare with eager for now _supports_cache_class = True _supports_quantized_cache = True diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index c9a3494b88b486..d5eb67a7fb376e 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -33,6 +33,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, ) @@ -43,6 +44,9 @@ from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -682,6 +686,125 @@ def forward( return attn_output, attn_weights, past_key_value +class SEWFlashAttention3(SEWAttention): + """ + SEW flash attention module. This module inherits from `SEWAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
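# --- Not part of the patch: a rough sketch (hypothetical helper name) of the mask policy that
# --- `_update_causal_mask` applies above for flash_attention_2/flash_attention_3 — the kernels
# --- receive the raw 2D padding mask only when real padding is present, otherwise no mask at all.
from typing import Optional

import torch


def mask_for_flash_attention(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    if attention_mask is not None and 0.0 in attention_mask:
        return attention_mask        # keep the (batch, seq_len) padding mask
    return None                      # fully dense batch: skip masking entirely


print(mask_for_flash_attention(torch.tensor([[1.0, 1.0, 1.0]])))   # None
print(mask_for_flash_attention(torch.tensor([[1.0, 1.0, 0.0]])))   # tensor([[1., 1., 0.]])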
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # SEWFlashAttention3 attention does not support output_attentions + if output_attentions: + raise ValueError("SEWFlashAttention3 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
(LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class SEWSdpaAttention(SEWAttention): # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->SEW def forward( @@ -793,6 +916,7 @@ def forward( "eager": SEWAttention, "sdpa": SEWSdpaAttention, "flash_attention_2": SEWFlashAttention2, + "flash_attention_3": SEWFlashAttention3, } @@ -869,6 +993,7 @@ def __init__(self, config): self.upsample = SEWUpsampling(config) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -882,7 +1007,7 @@ def forward( all_self_attentions = () if output_attentions else None if attention_mask is not None: - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # make sure padded tokens output 0 hidden_states[~attention_mask] = 0.0 # 2d mask is passed through the layers @@ -979,6 +1104,7 @@ class SEWPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index 1d35d1d44cfd97..fe4c9eb243c689 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -35,6 +35,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -46,6 +47,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -522,6 +526,89 @@ def forward( return attn_output, attn_weights +class SiglipFlashAttention3(SiglipAttention): + """ + SiglipAttention flash attention module. This module inherits from `SiglipAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
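# --- Not part of the patch: a toy illustration of the padding handling the SEW encoder applies
# --- above before flash attention — padded frames are zeroed in place, and only the flat
# --- (batch, seq_len) mask is forwarded instead of an expanded 4D additive mask.
import torch

hidden_states = torch.randn(2, 5, 8)                                # (batch, frames, hidden)
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]], dtype=torch.bool)  # second sample has 2 padded frames
hidden_states[~attention_mask] = 0.0                                # padded frames output exactly 0
flat_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None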
+ """ + + is_causal = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + class SiglipSdpaAttention(SiglipAttention): """ Siglip attention module using torch.nn.functional.scaled_dot_product_attention. 
This module inherits from @@ -591,6 +678,7 @@ def forward( SIGLIP_ATTENTION_CLASSES = { "eager": SiglipAttention, "flash_attention_2": SiglipFlashAttention2, + "flash_attention_3": SiglipFlashAttention3, "sdpa": SiglipSdpaAttention, } @@ -677,6 +765,7 @@ class SiglipPreTrainedModel(PreTrainedModel): "SiglipMultiheadAttentionPoolingHead", ] _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): @@ -930,6 +1019,7 @@ def __init__(self, config: SiglipTextConfig): self.head = nn.Linear(embed_dim, embed_dim) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) @@ -962,7 +1052,7 @@ def forward( # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. # expand attention_mask - if attention_mask is not None and not self._use_flash_attention_2: + if attention_mask is not None and not self._use_flash_attention_2 and not self._use_flash_attention_3: # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index d91c0832ed33da..11348590ce0cc4 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -42,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -52,6 +53,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -672,10 +676,114 @@ def forward( return attn_output, attn_weights, past_key_value +class StableLmFlashAttention3(StableLmAttention): + """ + StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
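# --- Not part of the patch: a hedged usage sketch. Assuming the FlashAttention-3 package
# --- (`flash_attn_interface`) is installed, the new backend is selected through the same
# --- `attn_implementation` switch already used for flash_attention_2, and the per-model
# --- `*_ATTENTION_CLASSES` mappings above dispatch to the new `*FlashAttention3` modules.
# --- The checkpoint name is only an example; any model gated by `_supports_flash_attn_3` works.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "stabilityai/stablelm-2-1_6b",             # example checkpoint
    torch_dtype=torch.bfloat16,                # half precision, matching the fp32 recast logic above
    attn_implementation="flash_attention_3",
)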
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # StableLmFlashAttention3 attention does not support output_attentions + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if self.qk_layernorm: + query_states = self.q_layernorm(query_states) + key_states = self.k_layernorm(key_states) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + + # Partial rotary embedding + query_rot, query_pass = ( + query_states[..., : self.rotary_ndims], + query_states[..., self.rotary_ndims :], + ) + key_rot, key_pass = ( + key_states[..., : self.rotary_ndims], + key_states[..., self.rotary_ndims :], + ) + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + + # [batch_size, seq_length, num_heads, head_dim] + query_states = torch.cat((query_rot, query_pass), dim=-1) + key_states = torch.cat((key_rot, key_pass), dim=-1) + + if past_key_value is not None: + cache_kwargs = { + "sin": sin, + "cos": cos, + "partial_rotation_size": self.rotary_ndims, + "cache_position": cache_position, + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. 
+ query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + ATTENTION_CLASSES = { "eager": StableLmAttention, "sdpa": StableLmSdpaAttention, "flash_attention_2": StableLmFlashAttention2, + "flash_attention_3": StableLmFlashAttention3, } @@ -799,6 +907,7 @@ class StableLmPreTrainedModel(PreTrainedModel): _no_split_modules = ["StableLmDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True _supports_sdpa = True _supports_quantized_cache = True @@ -1058,7 +1167,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 0be37c4e1fb91c..bb0311477be828 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -42,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_torchdynamo_compiling, logging, @@ -53,6 +54,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -512,6 +516,131 @@ def forward( return attn_output, attn_weights, past_key_value +class Starcoder2FlashAttention3(Starcoder2Attention): + """ + Starcoder2 flash attention module. This module inherits from `Starcoder2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. 
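# --- Not part of the patch: the tensor layouts involved in the calls above, as I read them.
# --- Projections produce (batch, num_heads, seq_len, head_dim) so the rotary/cache code can work
# --- per head, and a final transpose hands (batch, seq_len, num_heads, head_dim) to
# --- `_flash_attention_3_forward`, the layout the flash attention kernels expect.
import torch

bsz, q_len, num_heads, head_dim = 2, 5, 4, 16
query_states = torch.randn(bsz, num_heads, q_len, head_dim)   # layout used for RoPE / KV cache
query_states = query_states.transpose(1, 2)                   # layout handed to flash attention
assert query_states.shape == (bsz, q_len, num_heads, head_dim)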
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 + ): + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + if position_embeddings is None: + logger.warning_once( + "The attention layers in this model are transitioning from computing the RoPE embeddings internally " + "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed " + "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be " + "removed and `position_embeddings` will be mandatory." + ) + cos, sin = self.rotary_emb(value_states, position_ids) + else: + cos, sin = position_embeddings + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + kv_seq_len = key_states.shape[-2] + cache_position[0] + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. 
+ input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + attn_output = nn.functional.dropout(attn_output, p=self.residual_dropout, training=self.training) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + # Copied from transformers.models.mixtral.modeling_mixtral.MixtralSdpaAttention with Mixtral->Starcoder2 class Starcoder2SdpaAttention(Starcoder2Attention): """ @@ -614,6 +743,7 @@ def forward( STARCODER2_ATTENTION_CLASSES = { "eager": Starcoder2Attention, "flash_attention_2": Starcoder2FlashAttention2, + "flash_attention_3": Starcoder2FlashAttention3, "sdpa": Starcoder2SdpaAttention, } @@ -728,6 +858,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel): _no_split_modules = ["Starcoder2DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True _supports_cache_class = True _supports_quantized_cache = True @@ -988,7 +1119,10 @@ def _update_causal_mask( past_key_values: Cache, output_attentions: bool, ): - if self.config._attn_implementation == "flash_attention_2": + if ( + self.config._attn_implementation == "flash_attention_2" + or self.config._attn_implementation == "flash_attention_3" + ): if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 4202f680437c53..9e71d7c4f853fe 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -35,6 +35,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -45,6 +46,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -714,6 +718,125 @@ def forward( return attn_output, attn_weights, past_key_value +class UniSpeechFlashAttention3(UniSpeechAttention): + """ + UniSpeech flash attention module. 
This module inherits from `UniSpeechAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # UniSpeechFlashAttention3 attention does not support output_attentions + if output_attentions: + raise ValueError("UniSpeechFlashAttention3 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. 
Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class UniSpeechSdpaAttention(UniSpeechAttention): # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->UniSpeech def forward( @@ -825,6 +948,7 @@ def forward( "eager": UniSpeechAttention, "sdpa": UniSpeechSdpaAttention, "flash_attention_2": UniSpeechFlashAttention2, + "flash_attention_3": UniSpeechFlashAttention3, } @@ -972,6 +1096,7 @@ def __init__(self, config): self.layers = nn.ModuleList([UniSpeechEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -988,7 +1113,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1060,6 +1185,7 @@ def __init__(self, config): ) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1076,7 +1202,7 @@ def forward( # make sure padded tokens are not attended to expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1218,6 +1344,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git 
a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index bfb2cbfa4f55da..0ec5b1c9836f9d 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -42,6 +42,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_peft_available, logging, @@ -53,6 +54,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -731,6 +735,125 @@ def forward( return attn_output, attn_weights, past_key_value +class UniSpeechSatFlashAttention3(UniSpeechSatAttention): + """ + UniSpeechSat flash attention module. This module inherits from `UniSpeechSatAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # UniSpeechSatFlashAttention3 attention does not support output_attentions + if output_attentions: + raise ValueError("UniSpeechSatFlashAttention3 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + 
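# --- Not part of the patch: a shape sketch of the legacy tuple-cache handling in the forward
# --- above. The cache stores key/value as (batch, num_heads, seq_len, head_dim), while this
# --- flash-attention path keeps (batch, seq_len, num_heads, head_dim), hence the transpose(1, 2)
# --- when reusing or concatenating cached states.
import torch

bsz, num_heads, head_dim = 1, 2, 8
cached_key = torch.randn(bsz, num_heads, 5, head_dim)       # 5 cached positions, cache layout
new_key = torch.randn(bsz, 1, num_heads, head_dim)          # 1 new position, flash-attention layout
key_states = torch.cat([cached_key.transpose(1, 2), new_key], dim=1)
assert key_states.shape == (bsz, 6, num_heads, head_dim)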
if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class UniSpeechSatSdpaAttention(UniSpeechSatAttention): # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->UniSpeechSat def forward( @@ -842,6 +965,7 @@ def forward( "eager": UniSpeechSatAttention, "sdpa": UniSpeechSatSdpaAttention, "flash_attention_2": UniSpeechSatFlashAttention2, + "flash_attention_3": UniSpeechSatFlashAttention3, } @@ -989,6 +1113,7 @@ def __init__(self, config): self.layers = nn.ModuleList([UniSpeechSatEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1005,7 +1130,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1077,6 +1202,7 @@ def __init__(self, config): ) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1093,7 +1219,7 @@ def forward( # make sure padded tokens are not attended to expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1235,6 +1361,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/video_llava/modeling_video_llava.py b/src/transformers/models/video_llava/modeling_video_llava.py index 9ae80be65ae4b6..7fbc36945d9a0a 100644 --- a/src/transformers/models/video_llava/modeling_video_llava.py +++ b/src/transformers/models/video_llava/modeling_video_llava.py @@ -126,6 +126,7 @@ class VideoLlavaPreTrainedModel(PreTrainedModel): _no_split_modules = ["VideoLlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True def _init_weights(self, module): diff --git a/src/transformers/models/vipllava/modeling_vipllava.py b/src/transformers/models/vipllava/modeling_vipllava.py index 53a3213697193e..f42e9333c39bec 100644 --- a/src/transformers/models/vipllava/modeling_vipllava.py +++ 
b/src/transformers/models/vipllava/modeling_vipllava.py @@ -132,6 +132,7 @@ class VipLlavaPreTrainedModel(PreTrainedModel): _no_split_modules = ["VipLlavaVisionAttention"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_cache_class = True def _init_weights(self, module): diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index f1d021b58ee538..1333f3ff15ff60 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -45,6 +45,7 @@ add_start_docstrings_to_model_forward, cached_file, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, is_peft_available, is_safetensors_available, @@ -64,6 +65,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -778,6 +782,125 @@ def forward( return attn_output, attn_weights, past_key_value +class Wav2Vec2FlashAttention3(Wav2Vec2Attention): + """ + Wav2Vec2 flash attention module. This module inherits from `Wav2Vec2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # Wav2Vec2FlashAttention3 attention does not support output_attentions + if output_attentions: + raise ValueError("Wav2Vec2FlashAttention3 attention does not support output_attentions") + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, q_len, _ = hidden_states.size() + + # get query proj + query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0].transpose(1, 2) + value_states = past_key_value[1].transpose(1, 2) + elif is_cross_attention: + # cross_attentions + key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) + value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + key_states = 
torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) + value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + else: + # self_attention + key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) + value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = _flash_attention_3_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class Wav2Vec2SdpaAttention(Wav2Vec2Attention): # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Wav2Vec2 def forward( @@ -889,6 +1012,7 @@ def forward( "eager": Wav2Vec2Attention, "sdpa": Wav2Vec2SdpaAttention, "flash_attention_2": Wav2Vec2FlashAttention2, + "flash_attention_3": Wav2Vec2FlashAttention3, } @@ -1006,6 +1130,7 @@ def __init__(self, config): self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1022,7 +1147,7 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1093,6 +1218,7 @@ def __init__(self, config): ) self.gradient_checkpointing = False self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3" def forward( self, @@ -1109,7 +1235,7 @@ def forward( # make sure padded tokens are not attended to expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: + if self._use_flash_attention_2 or self._use_flash_attention_3: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: @@ -1331,6 +1457,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): main_input_name = "input_values" supports_gradient_checkpointing = True _supports_flash_attn_2 = True + _supports_flash_attn_3 = True _supports_sdpa = True def _init_weights(self, module): diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index b82b978e5e6d95..b78e3ea6925e79 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -39,6 +39,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_flash_attn_2_available, + is_flash_attn_3_available, is_flash_attn_greater_or_equal_2_10, logging, replace_return_docstrings, @@ -50,6 +51,9 @@ if is_flash_attn_2_available(): from ...modeling_flash_attention_utils import _flash_attention_forward +if is_flash_attn_3_available(): + from ...modeling_flash_attention_3_utils import _flash_attention_3_forward + logger = logging.get_logger(__name__) @@ -523,6 +527,121 @@ def forward( return attn_output, attn_weights, past_key_value +class WhisperFlashAttention3(WhisperAttention): + """ + Whisper 
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index b82b978e5e6d95..b78e3ea6925e79 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -39,6 +39,7 @@
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
     is_flash_attn_2_available,
+    is_flash_attn_3_available,
     is_flash_attn_greater_or_equal_2_10,
     logging,
     replace_return_docstrings,
@@ -50,6 +51,9 @@
 if is_flash_attn_2_available():
     from ...modeling_flash_attention_utils import _flash_attention_forward
 
+if is_flash_attn_3_available():
+    from ...modeling_flash_attention_3_utils import _flash_attention_3_forward
+
 
 logger = logging.get_logger(__name__)
 
@@ -523,6 +527,121 @@ def forward(
         return attn_output, attn_weights, past_key_value
 
 
+class WhisperFlashAttention3(WhisperAttention):
+    """
+    Whisper flash attention module. This module inherits from `WhisperAttention`, as the weights of the module stay
+    untouched. The only required change is in the forward pass, where it needs to correctly call the public API of
+    flash attention and deal with padding tokens in case the input contains any of them.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        past_key_value: Optional[EncoderDecoderCache] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        layer_head_mask: Optional[torch.Tensor] = None,
+        output_attentions: bool = False,
+        cache_position: Optional[torch.LongTensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if isinstance(past_key_value, StaticCache):
+            raise ValueError(
+                "The `static` cache implementation is not compatible with `attn_implementation='flash_attention_3'`. "
+                "Use `attn_implementation='sdpa'` in the meantime, and open an issue at https://github.com/huggingface/transformers"
+            )
+        # WhisperFlashAttention3 attention does not support output_attentions
+        if output_attentions:
+            raise ValueError("WhisperFlashAttention3 attention does not support output_attentions")
+
+        # if key_value_states are provided this layer is used as a cross-attention layer
+        # for the decoder
+        is_cross_attention = key_value_states is not None
+        bsz, tgt_len, _ = hidden_states.size()
+
+        # get query proj
+        query_states = torch.reshape(self.q_proj(hidden_states), (bsz, tgt_len, self.num_heads, self.head_dim))
+
+        if past_key_value is not None:
+            is_updated = past_key_value.is_updated.get(self.layer_idx)
+            if is_cross_attention:
+                # after the first generated id, we can subsequently re-use all key/value_states from cache
+                past_key_value.is_updated[self.layer_idx] = True
+                past_key_value = past_key_value.cross_attention_cache
+            else:
+                past_key_value = past_key_value.self_attention_cache
+
+        # use key_value_states if cross attention
+        current_states = key_value_states if key_value_states is not None else hidden_states
+        if is_cross_attention and past_key_value and is_updated:
+            # reuse k,v, cross_attentions
+            key_states = past_key_value.key_cache[self.layer_idx]
+            value_states = past_key_value.value_cache[self.layer_idx]
+        else:
+            key_states = self._shape(self.k_proj(current_states), -1, bsz)
+            value_states = self._shape(self.v_proj(current_states), -1, bsz)
+            if past_key_value is not None:
+                # save all key/value_states to cache to be re-used for fast auto-regressive generation
+                cache_position = cache_position if not is_cross_attention else None
+                key_states, value_states = past_key_value.update(
+                    key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+                )
+
+        # TODO: These transposes are quite inefficient, but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim].
+        # We would need to refactor the KV cache to be able to avoid many of these transpose/reshape/view.
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        causal_mask = attention_mask
+        if attention_mask is not None:  # no matter the length, we just slice it
+            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+
+        # In PEFT, we usually cast the layer norms to float32 for training stability,
+        # so the input hidden states may get silently cast to float32. Hence, we need
+        # to cast them back to the correct dtype just to be sure everything works as expected.
+        # This might slow down training & inference, so it is recommended not to cast the LayerNorms
+        # to fp32. (LlamaRMSNorm handles it correctly)
+
+        input_dtype = query_states.dtype
+        if input_dtype == torch.float32:
+            if torch.is_autocast_enabled():
+                target_dtype = torch.get_autocast_gpu_dtype()
+            # Handle the case where the model is quantized
+            elif hasattr(self.config, "_pre_quantization_dtype"):
+                target_dtype = self.config._pre_quantization_dtype
+            else:
+                target_dtype = self.q_proj.weight.dtype
+
+            logger.warning_once(
+                f"The input hidden states seem to be silently cast to float32; this might be related to"
+                f" the fact that you have upcasted embedding or layer norm layers in float32. We will cast back the input to"
+                f" {target_dtype}."
+            )
+
+            query_states = query_states.to(target_dtype)
+            key_states = key_states.to(target_dtype)
+            value_states = value_states.to(target_dtype)
+
+        attn_output = _flash_attention_3_forward(
+            query_states,
+            key_states,
+            value_states,
+            causal_mask,
+            tgt_len,
+            is_causal=self.is_causal,
+        )
+
+        attn_output = attn_output.reshape(bsz, tgt_len, -1)
+        attn_output = self.out_proj(attn_output)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+
 class WhisperSdpaAttention(WhisperAttention):
     def forward(
         self,
@@ -624,6 +743,7 @@ def forward(
 WHISPER_ATTENTION_CLASSES = {
     "eager": WhisperAttention,
     "flash_attention_2": WhisperFlashAttention2,
+    "flash_attention_3": WhisperFlashAttention3,
     "sdpa": WhisperSdpaAttention,
 }
 
@@ -823,6 +943,7 @@ class WhisperPreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["WhisperEncoderLayer", "WhisperDecoderLayer"]
     _supports_flash_attn_2 = True
+    _supports_flash_attn_3 = True
     _supports_sdpa = True
     _supports_cache_class = True
     _supports_static_cache = True
@@ -1166,6 +1287,7 @@ def __init__(self, config: WhisperConfig):
            [WhisperDecoderLayer(config, layer_idx) for layer_idx in range(config.decoder_layers)]
        )
        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        self._use_flash_attention_3 = config._attn_implementation == "flash_attention_3"
        self._use_sdpa = config._attn_implementation == "sdpa"
 
        self.layer_norm = nn.LayerNorm(config.d_model)
@@ -1428,7 +1550,10 @@ def _update_causal_mask(
        past_key_values: Cache,
        output_attentions: bool,
    ):
-        if self.config._attn_implementation == "flash_attention_2":
+        if (
+            self.config._attn_implementation == "flash_attention_2"
+            or self.config._attn_implementation == "flash_attention_3"
+        ):
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None
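Both new attention classes resolve the downcast target in the same order as the existing FA2 path: the active autocast dtype first, then the pre-quantization dtype, then the dtype of the projection weights. A small stand-alone sketch of that order (the helper name is illustrative and not part of this diff):

    import torch

    def resolve_fa3_target_dtype(attn_module):
        # Illustrative helper mirroring the dtype fallback chain used in the forwards above.
        if torch.is_autocast_enabled():
            return torch.get_autocast_gpu_dtype()               # 1) dtype of the active autocast context
        if hasattr(attn_module.config, "_pre_quantization_dtype"):
            return attn_module.config._pre_quantization_dtype   # 2) dtype recorded before quantization
        return attn_module.q_proj.weight.dtype                  # 3) fall back to the q_proj weight dtype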
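Because `is_flash_attn_3_available` is re-exported from `transformers.utils`, callers can also select the implementation defensively before loading a model; a minimal sketch (falling back to `sdpa` here is just an example choice):

    from transformers.utils import is_flash_attn_3_available

    attn_implementation = "flash_attention_3" if is_flash_attn_3_available() else "sdpa"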
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index ad8b649aaa4e84..4f4e70b122be6f 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -62,6 +62,14 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
                 except ImportError:
                     # If the package can't be imported, it's not available
                     package_exists = False
+            elif pkg_name == "flash_attn_interface":
+                try:
+                    package = importlib.import_module(pkg_name)
+                    package_version = getattr(package, "__version__", "N/A")
+                    package_exists = True
+                except ImportError:
+                    # If the package can't be imported, it's not available
+                    package_exists = False
             else:
                 # For packages other than "torch", don't attempt the fallback and set as not available
                 package_exists = False
@@ -899,6 +907,16 @@ def is_flash_attn_greater_or_equal(library_version: str):
     return version.parse(importlib.metadata.version("flash_attn")) >= version.parse(library_version)
 
 
+def is_flash_attn_3_available():
+    if not is_flash_attn_2_available():
+        return False
+
+    if not _is_package_available("flash_attn_interface"):
+        return False
+
+    return True
+
+
 def is_torchdistx_available():
     return _torchdistx_available
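End to end, the Whisper path can be smoke-tested as below. The checkpoint, dummy feature shape, and generation settings are illustrative only, and the snippet assumes the `flash_attn_interface` package plus a compatible GPU:

    import torch
    from transformers import WhisperForConditionalGeneration
    from transformers.utils import is_flash_attn_3_available

    assert is_flash_attn_3_available(), "flash_attn_interface is not installed"

    model = WhisperForConditionalGeneration.from_pretrained(
        "openai/whisper-tiny",
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_3",
    ).to("cuda")

    # Dummy log-mel features (whisper-tiny expects 80 mel bins x 3000 frames),
    # just enough to exercise the flash_attention_3 code path.
    input_features = torch.randn(1, 80, 3000, dtype=torch.float16, device="cuda")
    generated_ids = model.generate(input_features, max_new_tokens=16)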