From 52920b5dd5ad3b5e94209ef392ab5ceccbb1c869 Mon Sep 17 00:00:00 2001
From: Joao Gante
Date: Thu, 19 Sep 2024 17:42:47 +0100
Subject: [PATCH] Cache: don't throw warnings on `gemma2` when instantiating a new cache (#33595)

---
 src/transformers/cache_utils.py               | 10 ++++-
 .../models/gemma2/modeling_gemma2.py          | 41 +++++++------------
 src/transformers/models/mimi/modeling_mimi.py | 12 +++++-
 tests/models/gemma2/test_modeling_gemma2.py   |  5 +++
 4 files changed, 38 insertions(+), 30 deletions(-)

diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 0671157e447038..d42b15c14abf9b 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1660,7 +1660,15 @@ def get_max_length(self) -> Optional[int]:
         return self.max_cache_len
 
     def get_seq_length(self, layer_idx: Optional[int] = 0):
-        return None
+        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
+        # limit the check to the first batch member and head dimension.
+        # TODO: deprecate this function in favor of `cache_position`
+        if layer_idx != 0:
+            raise ValueError(
+                "`get_seq_length` on `HybridCache` may get inconsistent results depending on the layer index. "
+                "Using the `layer_idx` argument is not supported."
+            )
+        return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
 
     def reset(self):
         """Resets the cache values while preserving the objects"""
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 1909ef78501559..be964c9aed018a 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -710,20 +710,13 @@ def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False):
             config.n_positions - 1]`.
 
             [What are position IDs?](../glossary#position-ids)
-        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
+        past_key_values (`HybridCache`, *optional*):
             Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
             blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
             returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
 
-            Two formats are allowed:
-            - a [`~cache_utils.Cache`] instance, see our
-              [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
-            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
-              shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
-              cache format.
-
-            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
-            legacy cache format will be returned.
+            Gemma 2 uses a unique cache class, [`HybridCache`], and does not guarantee full compatibility with other
+            cache classes.
 
             If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
             have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
@@ -789,7 +782,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        past_key_values: Optional[HybridCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
@@ -818,19 +811,8 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if cache_position is None:
-            if past_key_values is None:
-                cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
-            else:
-                raise ValueError("When `past_key_values` is passed, `cache_position` must be too")
-
-        # Probably a forward call with caching, so we set up cache for one call only
-        if use_cache and past_key_values is None and not self.training:
-            logger.warning_once(
-                "You are calling the model with `use_cache=True` but didn't pass `past_key_values` while not training. ",
-                "If you want to compute with cache, make sure to pass an instance of `HybridCache`. An empty `HybridCache` instance "
-                "will be created for this call. See for more: (https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.HybridCache)",
-            )
+        # Instantiate an empty cache if needed.
+        if use_cache and past_key_values is None:
             batch_size, seq_len, _ = inputs_embeds.shape
             past_key_values = HybridCache(
                 self.config,
@@ -840,6 +822,11 @@ def forward(
                 dtype=inputs_embeds.dtype,
             )
 
+        if cache_position is None:
+            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+            cache_position = torch.arange(
+                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+            )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
@@ -912,7 +899,7 @@ def _update_causal_mask(
         attention_mask: torch.Tensor,
         input_tensor: torch.Tensor,
         cache_position: torch.Tensor,
-        past_key_values: Cache,
+        past_key_values: HybridCache,
         output_attentions: bool,
     ):
         # Flash Attention currently doesn't support static cache but Gemma2 work only with static cache.
@@ -981,7 +968,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        past_key_values: Optional[HybridCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -1202,7 +1189,7 @@ def forward(
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        past_key_values: Optional[HybridCache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py
index db36250b3d89df..d91b057ef28ec4 100644
--- a/src/transformers/models/mimi/modeling_mimi.py
+++ b/src/transformers/models/mimi/modeling_mimi.py
@@ -1000,8 +1000,16 @@ def forward(
             )
             use_cache = False
 
-        if use_cache and past_key_values is None and not self.training:
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+        if use_cache and not isinstance(past_key_values, Cache):
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
+                )
 
         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
diff --git a/tests/models/gemma2/test_modeling_gemma2.py b/tests/models/gemma2/test_modeling_gemma2.py
index 918ed847f83d9e..4e7b3553460f89 100644
--- a/tests/models/gemma2/test_modeling_gemma2.py
+++ b/tests/models/gemma2/test_modeling_gemma2.py
@@ -86,10 +86,15 @@ def setUp(self):
     def test_model_outputs_equivalence(self, **kwargs):
         pass
 
+    @parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
     @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
     def test_eager_matches_sdpa_inference(self):
         pass
 
+    @unittest.skip("Gemma2's eager attn/sdpa attn outputs are expected to be different")
+    def test_eager_matches_sdpa_generate(self):
+        pass
+
     @parameterized.expand([("random",), ("same",)])
    @unittest.skip("Gemma2 has HybridCache which is not compatible with assisted decoding")
     def test_assisted_decoding_matches_greedy_search(self, assistant_type):
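
A minimal usage sketch of the behaviour this patch introduces, assuming a Gemma 2 checkpoint such as `google/gemma-2-2b` (the checkpoint id and prompt are illustrative, not taken from the patch). With the change, a forward call with `use_cache=True` and no `past_key_values` silently builds an empty `HybridCache`, and `HybridCache.get_seq_length()` reports the number of occupied positions instead of returning `None`:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b"  # assumed checkpoint; any Gemma 2 checkpoint should behave the same way
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)

inputs = tokenizer("The capital of France is", return_tensors="pt")

# No `past_key_values` passed: after this patch the model instantiates an empty
# `HybridCache` on its own, without emitting the old `use_cache=True` warning.
outputs = model(**inputs, use_cache=True)

cache = outputs.past_key_values  # a `HybridCache` instance
# `get_seq_length()` now returns the number of occupied cache slots
# (the prompt length here) instead of `None`.
print(cache.get_seq_length())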