From d8944058ac3e29926c014d278643ae72d9a8ed4d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 17:19:21 +0200 Subject: [PATCH 01/48] Add new dynamic cache --- src/transformers/cache_utils.py | 79 +++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 4e4a1ee26c12d7..37cc8f24b221b6 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -545,6 +545,85 @@ def batch_select_indices(self, indices: torch.Tensor): self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] +class DynamicSlidingWindowCache(DynamicCache): + """ + A cache that grows dynamically as more tokens are generated, but will stop growing if the sequence length is bigger than the sliding window. + This is the default for generative models with sliding window attention. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]` and up to `[batch_size, num_heads, sliding_window, head_dim]` if seq_len >= sliding_window. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicSlidingWindowCache + + >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") + + >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = DynamicSlidingWindowCache(model.config.sliding_window) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + DynamicSlidingWindowCache() + ``` + """ + + def __init__(self, sliding_window: int): + super().__init__() + self.sliding_window = sliding_window + self.slicing_ + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Discard previous + tokens according to the sliding window if needed. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + # Add only up to sliding window size if larger + self.key_cache.append(key_states[:, :, -self.sliding_window:, :]) + self.value_cache.append(value_states[:, :, -self.sliding_window:, :]) + else: + new_seq_len = key_states.shape[-2] + current_seq_len = self.get_seq_length(layer_idx) + if new_seq_len + current_seq_len > self.sliding_window: + # We need to slice + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx][:, :, -(self.sliding_window-new_seq_len):, :], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx][:, :, -(self.sliding_window-new_seq_len):, :], value_states], dim=-2) + else: + # Similar to DynamicCache + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + class OffloadedCache(DynamicCache): """ A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory. From 3b0984b99e73bca779b6288b953f488479323acd Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 17:46:05 +0200 Subject: [PATCH 02/48] Add cache by default in generate for models supporting it --- src/transformers/generation/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 5da4878513eb22..e49e9d9636bbac 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -28,6 +28,7 @@ from ..cache_utils import ( Cache, DynamicCache, + DynamicSlidingWindowCache, EncoderDecoderCache, OffloadedCache, QuantizedCacheConfig, From 345e695d73b6828886d99bbb1ecfcce952e9d235 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 18:02:54 +0200 Subject: [PATCH 03/48] Add to __init__ and correct typo --- src/transformers/__init__.py | 1 + src/transformers/cache_utils.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index ab829c6894c0f9..9bae0dff2d0939 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1268,6 +1268,7 @@ "Cache", "CacheConfig", "DynamicCache", + "DynamicSlidingWindowCache", "EncoderDecoderCache", "HQQQuantizedCache", "HybridCache", diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 37cc8f24b221b6..99d2d418bd0f9c 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -574,7 +574,6 @@ class DynamicSlidingWindowCache(DynamicCache): def __init__(self, sliding_window: int): super().__init__() self.sliding_window = sliding_window - self.slicing_ def update( self, From 38e82b5432189a95f265054e9e0f7898ea290625 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 19:41:44 +0200 Subject: [PATCH 04/48] Correct output if prefill larger than sliding window + compatibility --- src/transformers/cache_utils.py | 35 ++++++++++++++++++++++++---- src/transformers/generation/utils.py | 6 ++--- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 99d2d418bd0f9c..a517d9ef90b532 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -606,15 +606,17 @@ def update( # Update the cache if len(self.key_cache) <= layer_idx: # Add only up to sliding window size if larger - self.key_cache.append(key_states[:, :, -self.sliding_window:, :]) - self.value_cache.append(value_states[:, :, -self.sliding_window:, :]) + self.key_cache.append(key_states[..., -self.sliding_window:, :]) + self.value_cache.append(value_states[..., -self.sliding_window:, :]) + # We should return full states during prefill even though we only save up to sliding window + return key_states, value_states else: new_seq_len = key_states.shape[-2] current_seq_len = self.get_seq_length(layer_idx) if new_seq_len + current_seq_len > self.sliding_window: # We need to slice - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx][:, :, -(self.sliding_window-new_seq_len):, :], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx][:, :, -(self.sliding_window-new_seq_len):, :], value_states], dim=-2) + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx][..., -(self.sliding_window-new_seq_len):, :], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx][..., -(self.sliding_window-new_seq_len):, :], value_states], dim=-2) else: # Similar to DynamicCache self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) @@ -622,6 +624,31 @@ def update( return self.key_cache[layer_idx], self.value_cache[layer_idx] + def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSlidingWindowCache"]: + """Needs to be overriden because DynamicSlidingWindowCache takes an __init__() argument.""" + out = [] + for i in range(0, full_batch_size, split_size): + current_split = DynamicSlidingWindowCache(self.sliding_window) + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] + out.append(current_split) + return out + + @classmethod + def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": + """Needs to be overriden because DynamicSlidingWindowCache takes an __init__() argument.""" + cache = cls(splits[0].sliding_window) + for idx in range(len(splits[0])): + layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) + layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) + cache.update(layer_keys, layer_values, idx) + return cache + + # Legacy format does not really make sense here even though it is a DynamicCache -> we set methods to None + from_legacy_cache = None + to_legacy_cache = None + class OffloadedCache(DynamicCache): """ diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index e49e9d9636bbac..0f61cc84db8884 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4604,10 +4604,8 @@ def _concat(data): if isinstance(data[0], torch.Tensor): return torch.cat(data, dim=0) # New cache format - elif isinstance(data[0], DynamicCache): - return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers) - elif isinstance(data[0], EncoderDecoderCache): - return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers) + elif isinstance(data[0], (DynamicCache, EncoderDecoderCache)): + return data[0].__class__.from_batch_splits(data) elif isinstance(data[0], tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): From c46a92a23b4c25fea9bf9fc9b9df9d8daba98dde Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 22:08:36 +0200 Subject: [PATCH 05/48] Add legacy format handling --- src/transformers/cache_utils.py | 21 +++++++++++++++------ src/transformers/generation/utils.py | 2 +- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index a517d9ef90b532..df34b8de6919fe 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -623,9 +623,21 @@ def update( self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) return self.key_cache[layer_idx], self.value_cache[layer_idx] + + @classmethod + def from_legacy_cache(cls, sliding_window: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicSlidingWindowCache`. Used for + backward compatibility.""" + cache = cls(sliding_window) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSlidingWindowCache"]: - """Needs to be overriden because DynamicSlidingWindowCache takes an __init__() argument.""" + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" out = [] for i in range(0, full_batch_size, split_size): current_split = DynamicSlidingWindowCache(self.sliding_window) @@ -637,17 +649,14 @@ def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSli @classmethod def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": - """Needs to be overriden because DynamicSlidingWindowCache takes an __init__() argument.""" + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" cache = cls(splits[0].sliding_window) for idx in range(len(splits[0])): layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) cache.update(layer_keys, layer_values, idx) return cache - - # Legacy format does not really make sense here even though it is a DynamicCache -> we set methods to None - from_legacy_cache = None - to_legacy_cache = None class OffloadedCache(DynamicCache): diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0f61cc84db8884..919a0bcc6b794f 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2354,7 +2354,7 @@ def typeerror(): should_convert_cache = generation_config.return_legacy_cache is_user_defined_cache = user_defined_cache is not None is_default_cache_type = ( - type(result.past_key_values) == DynamicCache # noqa E721 + type(result.past_key_values) in (DynamicCache, DynamicSlidingWindowCache) # noqa E721 or ( isinstance(result.past_key_values, EncoderDecoderCache) and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721 From 02b8506eb542545011c72c6f2977ec4e96e8d323 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 13:01:41 +0200 Subject: [PATCH 06/48] style --- src/transformers/cache_utils.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index df34b8de6919fe..08eb329cab8373 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -606,8 +606,8 @@ def update( # Update the cache if len(self.key_cache) <= layer_idx: # Add only up to sliding window size if larger - self.key_cache.append(key_states[..., -self.sliding_window:, :]) - self.value_cache.append(value_states[..., -self.sliding_window:, :]) + self.key_cache.append(key_states[..., -self.sliding_window :, :]) + self.value_cache.append(value_states[..., -self.sliding_window :, :]) # We should return full states during prefill even though we only save up to sliding window return key_states, value_states else: @@ -615,17 +615,23 @@ def update( current_seq_len = self.get_seq_length(layer_idx) if new_seq_len + current_seq_len > self.sliding_window: # We need to slice - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx][..., -(self.sliding_window-new_seq_len):, :], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx][..., -(self.sliding_window-new_seq_len):, :], value_states], dim=-2) + self.key_cache[layer_idx] = torch.cat( + [self.key_cache[layer_idx][..., -(self.sliding_window - new_seq_len) :, :], key_states], dim=-2 + ) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx][..., -(self.sliding_window - new_seq_len) :, :], value_states], dim=-2 + ) else: # Similar to DynamicCache self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) return self.key_cache[layer_idx], self.value_cache[layer_idx] - + @classmethod - def from_legacy_cache(cls, sliding_window: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + def from_legacy_cache( + cls, sliding_window: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "DynamicCache": """Converts a cache in the legacy cache format into an equivalent `DynamicSlidingWindowCache`. Used for backward compatibility.""" cache = cls(sliding_window) @@ -646,7 +652,7 @@ def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSli current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] out.append(current_split) return out - + @classmethod def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in From 7a98aac880bf130669b295694997270e00965ae7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 14:29:44 +0200 Subject: [PATCH 07/48] add docs --- docs/source/en/internal/generation_utils.md | 7 +++++++ src/transformers/cache_utils.py | 5 +++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index a81d202c6634af..3d4dfef7026732 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -367,6 +367,13 @@ A [`Constraint`] can be used to force the generation to include specific tokens - to_legacy_cache - from_legacy_cache +[[autodoc]] DynamicSlidingWindowCache + - update + - get_seq_length + - reorder_cache + - to_legacy_cache + - from_legacy_cache + [[autodoc]] QuantizedCache - update - get_seq_length diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 08eb329cab8373..256126ff9b5e88 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -350,7 +350,8 @@ def validate(self): class DynamicCache(Cache): """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. + A cache that grows dynamically as more tokens are generated. This is the default for generative models without sliding window attention + (see `DynamicSlidingWindowCache` in this case). It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]`. @@ -548,7 +549,7 @@ def batch_select_indices(self, indices: torch.Tensor): class DynamicSlidingWindowCache(DynamicCache): """ A cache that grows dynamically as more tokens are generated, but will stop growing if the sequence length is bigger than the sliding window. - This is the default for generative models with sliding window attention. + This is the default for generative models with sliding window attention (except for assisted decoding where `DynamicCache` is used). It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]` and up to `[batch_size, num_heads, sliding_window, head_dim]` if seq_len >= sliding_window. From ebe6dc91b171793c42eef2ed4d893e340b4238e9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 14:41:33 +0200 Subject: [PATCH 08/48] fix import --- src/transformers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 9bae0dff2d0939..af795e60ba4744 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -6157,6 +6157,7 @@ Cache, CacheConfig, DynamicCache, + DynamicSlidingWindowCache, EncoderDecoderCache, HQQQuantizedCache, HybridCache, From af95f2ad467b64ee8105733526b749225f1c50fb Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 14:49:50 +0200 Subject: [PATCH 09/48] Update dummy_pt_objects.py --- src/transformers/utils/dummy_pt_objects.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 048de1cc8ae77a..514277b1076687 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -37,6 +37,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class DynamicSlidingWindowCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class EncoderDecoderCache(metaclass=DummyObject): _backends = ["torch"] From 08d1a9f09b15a0da94ccd2d2223ef76ec213d034 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 16:59:30 +0200 Subject: [PATCH 10/48] Update test --- tests/generation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 1727aed1117bc6..19920dab7d0348 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,7 +62,7 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import DynamicCache, DynamicSlidingWindowCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, From b73655ae7d7a0482aec056183f750a9c2169fcae Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 17:01:44 +0200 Subject: [PATCH 11/48] style --- tests/generation/test_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 19920dab7d0348..0b2f47a9526556 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,7 +62,13 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, DynamicSlidingWindowCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import ( + DynamicCache, + DynamicSlidingWindowCache, + EncoderDecoderCache, + QuantoQuantizedCache, + StaticCache, + ) from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, From ff16af0165678c3942cc9009e9cde2240f4bd6c2 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 18:25:25 +0200 Subject: [PATCH 12/48] update cache conversion in test --- tests/generation/test_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 0b2f47a9526556..691ab457c488b1 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1816,7 +1816,10 @@ def test_new_cache_format(self, num_beams, do_sample): ) new_cache = new_results.past_key_values - legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) + if cache_cls == DynamicSlidingWindowCache: + legacy_cache_converted = cache_cls.from_legacy_cache(config.sliding_window, legacy_results.past_key_values) + else: + legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) for layer_idx in range(len(new_cache)): for kv_idx in range(len(new_cache[layer_idx])): # TODO: @raushan, please look into this for new cache format From 5e3fef01ef257e12ef80c6ac5c297da3197e8fe2 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Mon, 23 Sep 2024 16:34:41 +0200 Subject: [PATCH 13/48] style --- tests/generation/test_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 691ab457c488b1..c4177b79c35da2 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1817,7 +1817,9 @@ def test_new_cache_format(self, num_beams, do_sample): new_cache = new_results.past_key_values if cache_cls == DynamicSlidingWindowCache: - legacy_cache_converted = cache_cls.from_legacy_cache(config.sliding_window, legacy_results.past_key_values) + legacy_cache_converted = cache_cls.from_legacy_cache( + config.sliding_window, legacy_results.past_key_values + ) else: legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) for layer_idx in range(len(new_cache)): From 3d1bfd0b52f052d2918fcc863f43ac43d5538c00 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Sep 2024 10:54:45 +0200 Subject: [PATCH 14/48] Allow the cache to support new states of more than 1 token, even after prefill stage --- src/transformers/cache_utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 256126ff9b5e88..62e0ca7c2fc526 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -587,6 +587,10 @@ def update( Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Discard previous tokens according to the sliding window if needed. + Note: we always keep `sliding_window` tokens in the cache, instead of the `sliding_window - 1` tokens that + are strictly necesary. This allows to roll back one token in the past with `cache.crop(-1)` in contrastive search. + Assisted decoding would need to roll back additional tokens, and is therefore not supported with this Cache class. + Parameters: key_states (`torch.Tensor`): The new key states to cache. @@ -615,13 +619,17 @@ def update( new_seq_len = key_states.shape[-2] current_seq_len = self.get_seq_length(layer_idx) if new_seq_len + current_seq_len > self.sliding_window: - # We need to slice - self.key_cache[layer_idx] = torch.cat( - [self.key_cache[layer_idx][..., -(self.sliding_window - new_seq_len) :, :], key_states], dim=-2 + # We may need to return longer states (e.g. to continue generation with previous cache, with added tokens), but we only keep + # the last `sliding_window` states in the cache for next forward + full_key_states = torch.cat( + [self.key_cache[layer_idx][..., -(self.sliding_window - 1) :, :], key_states], dim=-2 ) - self.value_cache[layer_idx] = torch.cat( - [self.value_cache[layer_idx][..., -(self.sliding_window - new_seq_len) :, :], value_states], dim=-2 + full_value_states = torch.cat( + [self.value_cache[layer_idx][..., -(self.sliding_window - 1) :, :], value_states], dim=-2 ) + self.key_cache[layer_idx] = full_key_states[..., -self.sliding_window :, :] + self.value_cache[layer_idx] = full_value_states[..., -self.sliding_window :, :] + return full_key_states, full_value_states else: # Similar to DynamicCache self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) From 6a02bdca13cf19e0234cd027575058e28bac2a6f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Sep 2024 10:56:32 +0200 Subject: [PATCH 15/48] Update cache_utils.py --- src/transformers/cache_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 62e0ca7c2fc526..1df6f6e36ef03b 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -587,10 +587,10 @@ def update( Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Discard previous tokens according to the sliding window if needed. - Note: we always keep `sliding_window` tokens in the cache, instead of the `sliding_window - 1` tokens that + Note: we always keep `sliding_window` tokens in the cache if it is full, instead of the `sliding_window - 1` tokens that are strictly necesary. This allows to roll back one token in the past with `cache.crop(-1)` in contrastive search. Assisted decoding would need to roll back additional tokens, and is therefore not supported with this Cache class. - + Parameters: key_states (`torch.Tensor`): The new key states to cache. From 838712ded0619b5da09f05199f7dfe7a276ef5cc Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 24 Sep 2024 13:37:39 +0200 Subject: [PATCH 16/48] maybe change test --- tests/generation/test_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index c4177b79c35da2..dec7e46f6a5eeb 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1726,7 +1726,10 @@ def test_generate_continue_from_past_key_values(self): outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True) # Continue from the tokens generated above, preparing the inputs accordingly - inputs["past_key_values"] = outputs_cached.past_key_values + if getattr(config, "sliding_window", None) is not None: + inputs["past_key_values"] = DynamicSlidingWindowCache(config.sliding_window, outputs_cached.past_key_values) + else: + inputs["past_key_values"] = outputs_cached.past_key_values new_attention_len = outputs_cached.sequences.shape[-1] if config.is_encoder_decoder: inputs["decoder_input_ids"] = outputs_cached.sequences From 6afd20d3d7f06eb44547fa9ff600b35d2cfe9332 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 15:27:28 +0200 Subject: [PATCH 17/48] revert tests diffs --- tests/generation/test_utils.py | 20 +++----------------- 1 file changed, 3 insertions(+), 17 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index dec7e46f6a5eeb..1727aed1117bc6 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,13 +62,7 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import ( - DynamicCache, - DynamicSlidingWindowCache, - EncoderDecoderCache, - QuantoQuantizedCache, - StaticCache, - ) + from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -1726,10 +1720,7 @@ def test_generate_continue_from_past_key_values(self): outputs_cached = model.generate(**inputs, do_sample=False, max_new_tokens=3, return_dict_in_generate=True) # Continue from the tokens generated above, preparing the inputs accordingly - if getattr(config, "sliding_window", None) is not None: - inputs["past_key_values"] = DynamicSlidingWindowCache(config.sliding_window, outputs_cached.past_key_values) - else: - inputs["past_key_values"] = outputs_cached.past_key_values + inputs["past_key_values"] = outputs_cached.past_key_values new_attention_len = outputs_cached.sequences.shape[-1] if config.is_encoder_decoder: inputs["decoder_input_ids"] = outputs_cached.sequences @@ -1819,12 +1810,7 @@ def test_new_cache_format(self, num_beams, do_sample): ) new_cache = new_results.past_key_values - if cache_cls == DynamicSlidingWindowCache: - legacy_cache_converted = cache_cls.from_legacy_cache( - config.sliding_window, legacy_results.past_key_values - ) - else: - legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) + legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) for layer_idx in range(len(new_cache)): for kv_idx in range(len(new_cache[layer_idx])): # TODO: @raushan, please look into this for new cache format From 217e803405f6f3f20c2cdd80709aeda6aedc4f53 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 16:32:53 +0200 Subject: [PATCH 18/48] define get_seen_tokens --- src/transformers/cache_utils.py | 74 +++++++++++++++++----------- src/transformers/generation/utils.py | 2 +- 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 1df6f6e36ef03b..08653aaf54a75c 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -63,6 +63,12 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # TODO: deprecate this function in favor of `cache_position` raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") + + def get_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: + """Returns the number of already processed tokens. For all Cache classes except SlidingWindow caches, this is the same as + `get_seq_length()`. However, with sliding window we can process more tokens than the cache size. A layer index can be optionally passed. + """ + return self.get_seq_length(layer_idx) # Deprecate in favor of max-cache-shape because we want to be specifc by what we mean with "max_length" # Prev some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles @@ -350,8 +356,7 @@ def validate(self): class DynamicCache(Cache): """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models without sliding window attention - (see `DynamicSlidingWindowCache` in this case). + A cache that grows dynamically as more tokens are generated. This is the default for generative models. It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]`. @@ -549,7 +554,7 @@ def batch_select_indices(self, indices: torch.Tensor): class DynamicSlidingWindowCache(DynamicCache): """ A cache that grows dynamically as more tokens are generated, but will stop growing if the sequence length is bigger than the sliding window. - This is the default for generative models with sliding window attention (except for assisted decoding where `DynamicCache` is used). + This will be the default for generative models with sliding window attention (except for assisted decoding where `DynamicCache` is used). It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]` and up to `[batch_size, num_heads, sliding_window, head_dim]` if seq_len >= sliding_window. @@ -572,9 +577,18 @@ class DynamicSlidingWindowCache(DynamicCache): ``` """ - def __init__(self, sliding_window: int): - super().__init__() + def __init__(self, sliding_window: int, num_hidden_layers: Optional[int] = None) -> None: + super().__init__(num_hidden_layers) self.sliding_window = sliding_window + # We overwrite the field and maintain a list of size `num_hidden_layers` to accurately reflect the seen tokens at each layer during `update` + self._seen_tokens = [0]*num_hidden_layers if num_hidden_layers is not None else [] + + def get_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: + """This needs to be overriden because the number of processed tokens may be larger than the cache length.""" + if len(self._seen_tokens) <= layer_idx: + return 0 + else: + return self._seen_tokens[layer_idx] def update( self, @@ -604,18 +618,23 @@ def update( Return: A tuple containing the updated key and value states. """ - # Update the number of seen tokens - if layer_idx == 0: - self._seen_tokens += key_states.shape[-2] - - # Update the cache if len(self.key_cache) <= layer_idx: + # Update the number of seen tokens + self._seen_tokens.append(key_states.shape[-2]) # Add only up to sliding window size if larger self.key_cache.append(key_states[..., -self.sliding_window :, :]) self.value_cache.append(value_states[..., -self.sliding_window :, :]) # We should return full states during prefill even though we only save up to sliding window return key_states, value_states + # In case we initialized empty lists + elif self.key_cache[layer_idx] == []: + self._seen_tokens[layer_idx] += key_states.shape[-2] + self.key_cache[layer_idx] = key_states[..., -self.sliding_window :, :] + self.value_cache[layer_idx] = value_states[..., -self.sliding_window :, :] + # We should return full states during prefill even though we only save up to sliding window + return key_states, value_states else: + self._seen_tokens[layer_idx] += key_states.shape[-2] new_seq_len = key_states.shape[-2] current_seq_len = self.get_seq_length(layer_idx) if new_seq_len + current_seq_len > self.sliding_window: @@ -637,25 +656,12 @@ def update( return self.key_cache[layer_idx], self.value_cache[layer_idx] - @classmethod - def from_legacy_cache( - cls, sliding_window: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None - ) -> "DynamicCache": - """Converts a cache in the legacy cache format into an equivalent `DynamicSlidingWindowCache`. Used for - backward compatibility.""" - cache = cls(sliding_window) - if past_key_values is not None: - for layer_idx in range(len(past_key_values)): - key_states, value_states = past_key_values[layer_idx] - cache.update(key_states, value_states, layer_idx) - return cache - - def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSlidingWindowCache"]: + def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List["DynamicCache"]: """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by `_split_model_inputs()` in `generation.utils`""" out = [] for i in range(0, full_batch_size, split_size): - current_split = DynamicSlidingWindowCache(self.sliding_window) + current_split = DynamicSlidingWindowCache(self.sliding_window, num_hidden_layers) current_split._seen_tokens = self._seen_tokens current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] @@ -663,16 +669,24 @@ def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSli return out @classmethod - def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": + def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"], num_hidden_layers: int) -> "DynamicSlidingWindowCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in `generation.utils`""" - cache = cls(splits[0].sliding_window) + cache = cls(splits[0].sliding_window, num_hidden_layers) for idx in range(len(splits[0])): - layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) - layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) - cache.update(layer_keys, layer_values, idx) + key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] + value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] + if key_cache != []: + layer_keys = torch.cat(key_cache, dim=0) + layer_values = torch.cat(value_cache, dim=0) + cache.update(layer_keys, layer_values, idx) + + # We need this because _seen_tokens may be bigger than what will be automatically set with `update` (if cache > sliding_window) + cache._seen_tokens = splits[0]._seen_tokens return cache + from_legacy_cache = None + to_legacy_cache = None class OffloadedCache(DynamicCache): """ diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 919a0bcc6b794f..0f61cc84db8884 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2354,7 +2354,7 @@ def typeerror(): should_convert_cache = generation_config.return_legacy_cache is_user_defined_cache = user_defined_cache is not None is_default_cache_type = ( - type(result.past_key_values) in (DynamicCache, DynamicSlidingWindowCache) # noqa E721 + type(result.past_key_values) == DynamicCache # noqa E721 or ( isinstance(result.past_key_values, EncoderDecoderCache) and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721 From 582301cb03aef93b2a337d3cc03a6483111ea2ae Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 17:09:08 +0200 Subject: [PATCH 19/48] Modify all current .get_seq_length names --- docs/source/en/internal/generation_utils.md | 16 ++++++++-------- examples/modular-transformers/modeling_dummy.py | 4 ++-- .../modeling_my_new_model2.py | 4 ++-- examples/modular-transformers/modeling_super.py | 2 +- src/transformers/cache_utils.py | 4 ++-- src/transformers/generation/utils.py | 8 ++++---- src/transformers/models/bloom/modeling_bloom.py | 4 ++-- .../models/chameleon/modeling_chameleon.py | 4 ++-- .../models/codegen/modeling_codegen.py | 4 ++-- .../models/cohere/modeling_cohere.py | 4 ++-- src/transformers/models/dbrx/modeling_dbrx.py | 4 ++-- .../models/falcon/modeling_falcon.py | 4 ++-- src/transformers/models/gemma/modeling_gemma.py | 4 ++-- src/transformers/models/gemma/modular_gemma.py | 2 +- .../models/gemma2/modeling_gemma2.py | 2 +- src/transformers/models/gemma2/modular_gemma2.py | 2 +- src/transformers/models/git/modeling_git.py | 4 ++-- .../models/gpt_neo/modeling_gpt_neo.py | 4 ++-- .../models/gpt_neox/modeling_gpt_neox.py | 4 ++-- .../modeling_gpt_neox_japanese.py | 4 ++-- src/transformers/models/gptj/modeling_gptj.py | 4 ++-- .../models/granite/modeling_granite.py | 4 ++-- .../models/granitemoe/modeling_granitemoe.py | 4 ++-- .../models/idefics/modeling_idefics.py | 4 ++-- .../models/idefics2/modeling_idefics2.py | 7 ++++++- .../models/idefics3/modeling_idefics3.py | 7 ++++++- .../models/jetmoe/modeling_jetmoe.py | 4 ++-- src/transformers/models/llama/modeling_llama.py | 4 ++-- src/transformers/models/mimi/modeling_mimi.py | 4 ++-- .../models/mistral/modeling_mistral.py | 4 ++-- .../models/mixtral/modeling_mixtral.py | 4 ++-- .../models/mllama/modeling_mllama.py | 4 ++-- .../models/nemotron/modeling_nemotron.py | 2 +- src/transformers/models/olmo/modeling_olmo.py | 4 ++-- src/transformers/models/olmoe/modeling_olmoe.py | 4 ++-- .../models/paligemma/modeling_paligemma.py | 2 +- .../models/persimmon/modeling_persimmon.py | 4 ++-- src/transformers/models/phi/modeling_phi.py | 4 ++-- src/transformers/models/phi3/modeling_phi3.py | 4 ++-- src/transformers/models/qwen2/modeling_qwen2.py | 4 ++-- .../models/qwen2_audio/modeling_qwen2_audio.py | 2 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 4 ++-- .../models/qwen2_vl/modeling_qwen2_vl.py | 4 ++-- .../models/stablelm/modeling_stablelm.py | 4 ++-- .../models/starcoder2/modeling_starcoder2.py | 4 ++-- .../models/whisper/modeling_whisper.py | 8 ++++---- 46 files changed, 103 insertions(+), 93 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 3d4dfef7026732..3004e89300fcb3 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -362,21 +362,21 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] DynamicCache - update - - get_seq_length + - get_past_seen_tokens - reorder_cache - to_legacy_cache - from_legacy_cache [[autodoc]] DynamicSlidingWindowCache - update - - get_seq_length + - get_past_seen_tokens - reorder_cache - to_legacy_cache - from_legacy_cache [[autodoc]] QuantizedCache - update - - get_seq_length + - get_past_seen_tokens [[autodoc]] QuantoQuantizedCache @@ -384,7 +384,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] SinkCache - update - - get_seq_length + - get_past_seen_tokens - reorder_cache [[autodoc]] OffloadedCache @@ -394,17 +394,17 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] StaticCache - update - - get_seq_length + - get_past_seen_tokens - reset [[autodoc]] OffloadedStaticCache - update - - get_seq_length + - get_past_seen_tokens - reset [[autodoc]] HybridCache - update - - get_seq_length + - get_past_seen_tokens - reset [[autodoc]] SlidingWindowCache @@ -412,7 +412,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens - reset [[autodoc]] EncoderDecoderCache - - get_seq_length + - get_past_seen_tokens - to_legacy_cache - from_legacy_cache - reset diff --git a/examples/modular-transformers/modeling_dummy.py b/examples/modular-transformers/modeling_dummy.py index b5b1fc6aec85e6..420fe6d6d2c144 100644 --- a/examples/modular-transformers/modeling_dummy.py +++ b/examples/modular-transformers/modeling_dummy.py @@ -906,7 +906,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -997,7 +997,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index 49cdd274162092..8b20b43c20b3e3 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -784,7 +784,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -879,7 +879,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index d91bdb1820c2a3..71b14bb8051a97 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -902,7 +902,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 08653aaf54a75c..a7b1c1f4a29acb 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -64,7 +64,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: # TODO: deprecate this function in favor of `cache_position` raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") - def get_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: + def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: """Returns the number of already processed tokens. For all Cache classes except SlidingWindow caches, this is the same as `get_seq_length()`. However, with sliding window we can process more tokens than the cache size. A layer index can be optionally passed. """ @@ -583,7 +583,7 @@ def __init__(self, sliding_window: int, num_hidden_layers: Optional[int] = None) # We overwrite the field and maintain a list of size `num_hidden_layers` to accurately reflect the seen tokens at each layer during `update` self._seen_tokens = [0]*num_hidden_layers if num_hidden_layers is not None else [] - def get_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: + def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: """This needs to be overriden because the number of processed tokens may be larger than the cache length.""" if len(self._seen_tokens) <= layer_idx: return 0 diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 0f61cc84db8884..9a1e7db6373b04 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1533,8 +1533,8 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): past_length = 0 if not isinstance(cache, Cache): past_length = cache[0][0].shape[2] - elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None: - past_length = cache.get_seq_length() + elif hasattr(cache, "get_past_seen_tokens") and cache.get_past_seen_tokens() is not None: + past_length = cache.get_past_seen_tokens() # TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty, # end-to-end compilation will yield bad results because `cache_position` will be incorrect. @@ -2765,7 +2765,7 @@ def _contrastive_search( # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step if model_kwargs.get("past_key_values") is None or ( isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache)) - and model_kwargs["past_key_values"].get_seq_length() == 0 + and model_kwargs["past_key_values"].get_past_seen_tokens() == 0 ): # prepare inputs model_kwargs["use_cache"] = True @@ -4167,7 +4167,7 @@ def _assisted_decoding( isinstance(past_key_values, EncoderDecoderCache) and isinstance(past_key_values.self_attention_cache, DynamicCache) ): - if past_key_values.get_seq_length() == 0: + if past_key_values.get_past_seen_tokens() == 0: start_from_empty_dynamic_cache = True this_peer_finished = False diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index 75f8e5830f44bd..b0bf53901f4713 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -646,7 +646,7 @@ def forward( ) batch_size, seq_length, _ = inputs_embeds.shape - past_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 seq_length_with_past = seq_length + past_length if cache_position is None: cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device) @@ -747,7 +747,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index fd76c0b1152267..f9cf103e6f5069 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -1295,7 +1295,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1386,7 +1386,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 478745b2c59ea4..2b51bad8ab7b22 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -487,7 +487,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) if position_ids is None: @@ -590,7 +590,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index a5d3721f5bdb03..69505b58147072 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -868,7 +868,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -959,7 +959,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index ef81e43d0294f0..d07bc864839dd1 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -1019,7 +1019,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1120,7 +1120,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index f48accab44bfc2..94c3930f2de7d8 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -992,7 +992,7 @@ def forward( # Compute alibi tensor: check build_alibi_tensor documentation alibi = None - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 batch_size, seq_length, _ = inputs_embeds.shape if self.use_alibi: mask = ( @@ -1114,7 +1114,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index ff206a470bc3fa..c7cb92fde0ff95 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -780,7 +780,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -875,7 +875,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index 7130a30dc9be58..ce39025cfbdf41 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -860,7 +860,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 0b99aa59c65b41..be4e5b8e0c970b 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -790,7 +790,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index c0f76dbe5bfcbd..7ea5b3bdd23ee4 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -627,7 +627,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index c7f9ceafe19452..168935b467abc7 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -1283,7 +1283,7 @@ def forward( past_key_values_length = ( past_key_values[0][0].shape[2] if not isinstance(past_key_values, Cache) - else past_key_values.get_seq_length() + else past_key_values.get_past_seen_tokens() ) # Prepare head mask if needed @@ -1611,7 +1611,7 @@ def prepare_inputs_for_generation( ): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - past_length = past_key_values.get_seq_length() + past_length = past_key_values.get_past_seen_tokens() # Some generation methods already pass only the last input ID if input_ids.shape[1] > past_length: diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 7bba7608e6c187..0099341d5262c9 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -702,7 +702,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) if position_ids is None: @@ -804,7 +804,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index f4636db0a97b44..f0adac7642a7f4 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -904,7 +904,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) if position_ids is None: @@ -1001,7 +1001,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index b618f531e52f66..7215e8e05076fd 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -624,7 +624,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device) if position_ids is None: @@ -705,7 +705,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index 5c80485823c10b..8145d9b250ab61 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -774,7 +774,7 @@ def forward( seq_length = inputs_embeds.shape[1] if cache_position is None: - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) @@ -899,7 +899,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index 0eb27d452f08d2..420d971b2ac4fa 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -794,7 +794,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -892,7 +892,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index ebdea826fa0450..7ac0829509a4e0 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -1020,7 +1020,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1125,7 +1125,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index 02de8d61ae204c..e757a043e68d69 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -1153,7 +1153,7 @@ def forward( ) batch_size, seq_length, _ = inputs_embeds.shape - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 seq_length_with_past = seq_length + past_key_values_length if cache_position is None: @@ -1384,7 +1384,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index b53d0722587d5a..9ca404d3cb45ed 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1359,7 +1359,7 @@ def forward( "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" ) - past_seen_tokens = past_key_values.get_seq_length() + past_seen_tokens = past_key_values.get_past_seen_tokens() if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") @@ -1664,8 +1664,13 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore +<<<<<<< HEAD past_length = past_key_values.get_seq_length() max_cache_length = past_key_values.get_max_cache_shape() +======= + past_length = past_key_values.get_past_seen_tokens() + max_cache_length = past_key_values.get_max_length() +>>>>>>> 0c098e35c (Modify all current .get_seq_length names) # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 757391175ea671..f30c83e7ade3a4 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -953,7 +953,7 @@ def forward( past_seen_tokens = 0 if use_cache: - past_seen_tokens = past_key_values.get_seq_length() + past_seen_tokens = past_key_values.get_past_seen_tokens() if inputs_embeds is not None and input_ids is None and past_seen_tokens == 0: raise ValueError("When first calling the model, if input_embeds are passed, input_ids should not be None.") @@ -1252,8 +1252,13 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore +<<<<<<< HEAD past_length = past_key_values.get_seq_length() max_cache_length = past_key_values.get_max_cache_shape() +======= + past_length = past_key_values.get_past_seen_tokens() + max_cache_length = past_key_values.get_max_length() +>>>>>>> 0c098e35c (Modify all current .get_seq_length names) # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index bbc70b26d1f8a9..ca2651871cbec0 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -998,7 +998,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1101,7 +1101,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index dde017bbb92797..bbcbe5ef96bb74 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -904,7 +904,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -995,7 +995,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 514f9de706ec63..49d83cf800e0da 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -958,7 +958,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=hidden_states.device ) @@ -1044,7 +1044,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index b0ffe3e56e5972..21053e36131063 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -775,7 +775,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -872,7 +872,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 9c7fadbb8f885c..2619d5792f0c90 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -986,7 +986,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1085,7 +1085,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 0bc77eaeec3324..89b9813aa4a88b 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -1620,7 +1620,7 @@ def forward( hidden_states = inputs_embeds if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1726,7 +1726,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 7d0390adc3c06f..54e5ff3698ab35 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -872,7 +872,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 7ab54146c9740b..3b1de2709f9a99 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -825,7 +825,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -914,7 +914,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 8c29f89ff3e7ea..47167a62677b5e 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -973,7 +973,7 @@ def forward( ) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1073,7 +1073,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/paligemma/modeling_paligemma.py b/src/transformers/models/paligemma/modeling_paligemma.py index d75a05bda0e1ec..e19deac906eab9 100644 --- a/src/transformers/models/paligemma/modeling_paligemma.py +++ b/src/transformers/models/paligemma/modeling_paligemma.py @@ -472,7 +472,7 @@ def forward( inputs_embeds = self.get_input_embeddings()(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 7ae3469a4c9399..88c2941ed50373 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -648,7 +648,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -741,7 +741,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 3f770c9ec00b9b..e68fb3d07de320 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -938,7 +938,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1032,7 +1032,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 0380c6cd49d6ea..225e590c58a1d4 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -965,7 +965,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1052,7 +1052,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 50f273ba766ca9..6f4b4fd61524ae 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -878,7 +878,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -971,7 +971,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 6422baac5feb5e..928433c58d531a 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -1266,7 +1266,7 @@ def prepare_inputs_for_generation( if past_key_values is not None: if isinstance(past_key_values, Cache): cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens + past_length = past_key_values.get_past_seen_tokens() else: cache_length = past_length = past_key_values[0][0].shape[2] diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 2ab13b7227ada6..4123f833090c91 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1048,7 +1048,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1152,7 +1152,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 283e38d3a7d508..52379787ab30a3 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1139,7 +1139,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1234,7 +1234,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index fe3ad6498172a9..779034f51b828f 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -923,7 +923,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -1016,7 +1016,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index e0fdbef1a3baf5..9688acc2749786 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -851,7 +851,7 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -945,7 +945,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 079965fc174a63..e860d099f9c2fe 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1248,7 +1248,7 @@ def forward( if cache_position is not None: past_key_values_length = cache_position[0] elif past_key_values is not None: - past_key_values_length = past_key_values.get_seq_length() + past_key_values_length = past_key_values.get_past_seen_tokens() if cache_position is None: cache_position = torch.arange( @@ -1383,7 +1383,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward @@ -1824,7 +1824,7 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, EncoderDecoderCache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_past_seen_tokens() else: past_length = past_key_values[0][0].shape[2] @@ -2105,7 +2105,7 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, (Cache, EncoderDecoderCache)): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_past_seen_tokens() else: past_length = past_key_values[0][0].shape[2] From b239a5782fc824bebb6ba8c333d8ed1951ee6ba9 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 17:13:38 +0200 Subject: [PATCH 20/48] style --- src/transformers/cache_utils.py | 9 ++++++--- src/transformers/generation/utils.py | 3 +-- src/transformers/models/whisper/modeling_whisper.py | 8 ++++++-- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index a7b1c1f4a29acb..c6ba9d1d55e0df 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -63,7 +63,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" # TODO: deprecate this function in favor of `cache_position` raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") - + def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: """Returns the number of already processed tokens. For all Cache classes except SlidingWindow caches, this is the same as `get_seq_length()`. However, with sliding window we can process more tokens than the cache size. A layer index can be optionally passed. @@ -581,7 +581,7 @@ def __init__(self, sliding_window: int, num_hidden_layers: Optional[int] = None) super().__init__(num_hidden_layers) self.sliding_window = sliding_window # We overwrite the field and maintain a list of size `num_hidden_layers` to accurately reflect the seen tokens at each layer during `update` - self._seen_tokens = [0]*num_hidden_layers if num_hidden_layers is not None else [] + self._seen_tokens = [0] * num_hidden_layers if num_hidden_layers is not None else [] def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: """This needs to be overriden because the number of processed tokens may be larger than the cache length.""" @@ -669,7 +669,9 @@ def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: return out @classmethod - def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"], num_hidden_layers: int) -> "DynamicSlidingWindowCache": + def from_batch_splits( + cls, splits: List["DynamicSlidingWindowCache"], num_hidden_layers: int + ) -> "DynamicSlidingWindowCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in `generation.utils`""" cache = cls(splits[0].sliding_window, num_hidden_layers) @@ -688,6 +690,7 @@ def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"], num_hidden from_legacy_cache = None to_legacy_cache = None + class OffloadedCache(DynamicCache): """ A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory. diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9a1e7db6373b04..3bb8615cef4c5d 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -28,7 +28,6 @@ from ..cache_utils import ( Cache, DynamicCache, - DynamicSlidingWindowCache, EncoderDecoderCache, OffloadedCache, QuantizedCacheConfig, @@ -4605,7 +4604,7 @@ def _concat(data): return torch.cat(data, dim=0) # New cache format elif isinstance(data[0], (DynamicCache, EncoderDecoderCache)): - return data[0].__class__.from_batch_splits(data) + return data[0].__class__.from_batch_splits(data, num_hidden_layers=num_hidden_layers) elif isinstance(data[0], tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index e860d099f9c2fe..408ff54f5c5539 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1824,7 +1824,9 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, EncoderDecoderCache): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_past_seen_tokens() + past_length = ( + cache_position[0] if cache_position is not None else past_key_values.get_past_seen_tokens() + ) else: past_length = past_key_values[0][0].shape[2] @@ -2105,7 +2107,9 @@ def prepare_inputs_for_generation( past_length = 0 if past_key_values is not None: if isinstance(past_key_values, (Cache, EncoderDecoderCache)): - past_length = cache_position[0] if cache_position is not None else past_key_values.get_past_seen_tokens() + past_length = ( + cache_position[0] if cache_position is not None else past_key_values.get_past_seen_tokens() + ) else: past_length = past_key_values[0][0].shape[2] From ee30eb91db21360dea7c66b693f23cf73c420941 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 17:44:25 +0200 Subject: [PATCH 21/48] trigger CIs From f3af18023df43ff4d407aec58e0c7b66dc66ff9b Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 18:41:14 +0200 Subject: [PATCH 22/48] Add tests --- tests/generation/test_utils.py | 78 +++++++++++++++++++++++++++++++++- 1 file changed, 77 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 1727aed1117bc6..ce3a46dae3a1e9 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,7 +62,7 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import DynamicCache, DynamicSlidingWindowCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -2024,6 +2024,82 @@ def test_inherits_generation_mixin(self): for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) + @parameterized.expand([{"do_sample": False}, {'do_sample': False, 'top_k': 2, 'penalty_alpha': 0.5}]) + @pytest.mark.generate + def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dict): + """ + Tests if DynamicSlidingWindowCache works the same as DynamicCache for models that support it. The first expand + is for greedy, and the other is for contrasting search, as contrastive search needs to correctly roll back 1 token + of the cache even with DynamicSlidingWindowCache. + """ + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + if getattr(config, "sliding_window", None) is None: + self.skipTest(reason="This model does not support sliding window.") + + # Make sure we will go beyond the sliding window + config.sliding_window = 3 + model = model_class(config).to(torch_device).eval() + all_generation_kwargs = { + **generation_kwargs, + "max_new_tokens": 20, + "min_new_tokens": 20, + "use_cache": True, + } + + dynamic_cache = DynamicCache() + dynamic_sliding_cache = DynamicSlidingWindowCache(config.sliding_window) + + results_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) + results_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) + + self.assertListEqual(results_dynamic, results_sliding_dynamic) + + + @parameterized.expand([False, True]) + @pytest.mark.generate + def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_tokens_than_window: bool): + """ + Tests if we can correctly continue generation with DynamicSlidingWindowCache, even after the cache is "full" (bigger than sliding + window), and we provide more than 1 new token to add to the cache. + """ + for model_class in self.all_generative_model_classes: + config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + if getattr(config, "sliding_window", None) is None: + self.skipTest(reason="This model does not support sliding window.") + + # Make sure we will go beyond the sliding window + config.sliding_window = 3 + model = model_class(config).to(torch_device).eval() + all_generation_kwargs = { + "do_sample": False, + "max_new_tokens": 5, + "min_new_tokens": 5, + "use_cache": True, + "return_dict_in_generate": True, + } + + dynamic_cache = DynamicCache() + dynamic_sliding_cache = DynamicSlidingWindowCache(config.sliding_window) + + out_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) + out_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) + + results_dynamic, dynamic_cache = out_dynamic.sequences, out_dynamic.past_key_values + results_sliding_dynamic, dynamic_sliding_cache = out_sliding_dynamic.sequences, out_sliding_dynamic.past_key_values + + self.assertListEqual(results_dynamic, results_sliding_dynamic) + + bs = results_dynamic.shape[0] + num_added_tokens = 2 if not add_more_tokens_than_window else 4 + added_tokens = ids_tensor((bs, num_added_tokens), vocab_size=config.vocab_size) + input_ids = torch.cat([results_dynamic, added_tokens], dim=-1) + + out_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) + out_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) + + self.assertListEqual(out_dynamic.sequences, out_sliding_dynamic.sequences) + def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1): batch_size = main_input.shape[0] seq_length = main_input.shape[-1] From 25cd9c071afd6956c5c75feba817dc5231830e5d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 18:43:47 +0200 Subject: [PATCH 23/48] Update test_utils.py --- tests/generation/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index ce3a46dae3a1e9..e728483739b0ff 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2053,7 +2053,7 @@ def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dic results_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) results_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) - self.assertListEqual(results_dynamic, results_sliding_dynamic) + self.assertTrue((results_dynamic ==results_sliding_dynamic).all()) @parameterized.expand([False, True]) @@ -2088,7 +2088,7 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke results_dynamic, dynamic_cache = out_dynamic.sequences, out_dynamic.past_key_values results_sliding_dynamic, dynamic_sliding_cache = out_sliding_dynamic.sequences, out_sliding_dynamic.past_key_values - self.assertListEqual(results_dynamic, results_sliding_dynamic) + self.assertTrue((results_dynamic ==results_sliding_dynamic).all()) bs = results_dynamic.shape[0] num_added_tokens = 2 if not add_more_tokens_than_window else 4 @@ -2098,7 +2098,7 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke out_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) out_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) - self.assertListEqual(out_dynamic.sequences, out_sliding_dynamic.sequences) + self.assertTrue((out_dynamic.sequences == out_sliding_dynamic.sequences).all()) def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1): batch_size = main_input.shape[0] From b2f7dee6996f652044f3870837d5f6587951d75f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 18:45:23 +0200 Subject: [PATCH 24/48] Update test_utils.py --- tests/generation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index e728483739b0ff..6d3d3b63ffc253 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2056,7 +2056,7 @@ def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dic self.assertTrue((results_dynamic ==results_sliding_dynamic).all()) - @parameterized.expand([False, True]) + @parameterized.expand([(False,), (True,)]) @pytest.mark.generate def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_tokens_than_window: bool): """ From b5492900832785c96da390c2699eaab7119b776f Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 2 Oct 2024 18:49:30 +0200 Subject: [PATCH 25/48] Update test_utils.py --- tests/generation/test_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6d3d3b63ffc253..e10a741a0363dc 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2024,7 +2024,8 @@ def test_inherits_generation_mixin(self): for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) - @parameterized.expand([{"do_sample": False}, {'do_sample': False, 'top_k': 2, 'penalty_alpha': 0.5}]) + + @parameterized.expand([({"do_sample": False},), ({'do_sample': False, 'top_k': 2, 'penalty_alpha': 0.5},)]) @pytest.mark.generate def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dict): """ @@ -2053,7 +2054,7 @@ def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dic results_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) results_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) - self.assertTrue((results_dynamic ==results_sliding_dynamic).all()) + self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) @parameterized.expand([(False,), (True,)]) @@ -2088,7 +2089,7 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke results_dynamic, dynamic_cache = out_dynamic.sequences, out_dynamic.past_key_values results_sliding_dynamic, dynamic_sliding_cache = out_sliding_dynamic.sequences, out_sliding_dynamic.past_key_values - self.assertTrue((results_dynamic ==results_sliding_dynamic).all()) + self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) bs = results_dynamic.shape[0] num_added_tokens = 2 if not add_more_tokens_than_window else 4 @@ -2098,7 +2099,7 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke out_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) out_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) - self.assertTrue((out_dynamic.sequences == out_sliding_dynamic.sequences).all()) + self.assertListEqual(out_dynamic.sequences.tolist(), out_sliding_dynamic.sequences.tolist()) def _check_outputs(self, output, main_input, config, use_cache=False, num_return_sequences=1): batch_size = main_input.shape[0] From f052bede920398f947caec34f743c810c267f2e1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 3 Oct 2024 15:55:31 +0200 Subject: [PATCH 26/48] Update causal mask generation in case of DynamicSlidingCache (only Mistral) --- .../models/mistral/modeling_mistral.py | 11 ++++- tests/generation/test_utils.py | 49 ++++++++++++++----- 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 21053e36131063..5d596eaaf725f5 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -918,6 +918,15 @@ def _update_causal_mask( past_key_values=past_key_values, ) + if isinstance(past_key_values, DynamicSlidingWindowCache): + current_cache_length = past_key_values.get_seq_length() + if sequence_length + current_cache_length > self.config.sliding_window: + target_length = sequence_length + self.config.sliding_window - 1 + else: + target_length = current_cache_length + sequence_length + # Slice the causal mask to get only relevant part of the same shape as the keys/values + causal_mask = causal_mask[:, :, :, -target_length:] + if ( self.config._attn_implementation == "sdpa" and attention_mask is not None diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index e10a741a0363dc..53ac08e4a1be8b 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,7 +62,13 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, DynamicSlidingWindowCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import ( + DynamicCache, + DynamicSlidingWindowCache, + EncoderDecoderCache, + QuantoQuantizedCache, + StaticCache, + ) from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -2024,8 +2030,7 @@ def test_inherits_generation_mixin(self): for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) - - @parameterized.expand([({"do_sample": False},), ({'do_sample': False, 'top_k': 2, 'penalty_alpha': 0.5},)]) + @parameterized.expand([({"do_sample": False},), ({"do_sample": False, "top_k": 2, "penalty_alpha": 0.5},)]) @pytest.mark.generate def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dict): """ @@ -2050,12 +2055,18 @@ def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dic dynamic_cache = DynamicCache() dynamic_sliding_cache = DynamicSlidingWindowCache(config.sliding_window) - - results_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) - results_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) - self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) + results_dynamic = model.generate( + input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache + ) + results_sliding_dynamic = model.generate( + input_ids, + attention_mask=attention_mask, + **all_generation_kwargs, + past_key_values=dynamic_sliding_cache, + ) + self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) @parameterized.expand([(False,), (True,)]) @pytest.mark.generate @@ -2082,12 +2093,22 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke dynamic_cache = DynamicCache() dynamic_sliding_cache = DynamicSlidingWindowCache(config.sliding_window) - - out_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) - out_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) + + out_dynamic = model.generate( + input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache + ) + out_sliding_dynamic = model.generate( + input_ids, + attention_mask=attention_mask, + **all_generation_kwargs, + past_key_values=dynamic_sliding_cache, + ) results_dynamic, dynamic_cache = out_dynamic.sequences, out_dynamic.past_key_values - results_sliding_dynamic, dynamic_sliding_cache = out_sliding_dynamic.sequences, out_sliding_dynamic.past_key_values + results_sliding_dynamic, dynamic_sliding_cache = ( + out_sliding_dynamic.sequences, + out_sliding_dynamic.past_key_values, + ) self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) @@ -2096,8 +2117,10 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke added_tokens = ids_tensor((bs, num_added_tokens), vocab_size=config.vocab_size) input_ids = torch.cat([results_dynamic, added_tokens], dim=-1) - out_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache) - out_sliding_dynamic = model.generate(input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_sliding_cache) + out_dynamic = model.generate(input_ids, **all_generation_kwargs, past_key_values=dynamic_cache) + out_sliding_dynamic = model.generate( + input_ids, **all_generation_kwargs, past_key_values=dynamic_sliding_cache + ) self.assertListEqual(out_dynamic.sequences.tolist(), out_sliding_dynamic.sequences.tolist()) From e091f4de58306d9d88d20b1a1e13976f47ef1ca0 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 3 Oct 2024 16:46:00 +0200 Subject: [PATCH 27/48] Improve tests --- src/transformers/cache_utils.py | 1 + tests/generation/test_utils.py | 33 +++++++++++++++++---------------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index c6ba9d1d55e0df..2f0becf33cb935 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -551,6 +551,7 @@ def batch_select_indices(self, indices: torch.Tensor): self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] +# TODO: (cyril) Make this the default for models with sliding window once `generate` no longer returns Cache as tuples class DynamicSlidingWindowCache(DynamicCache): """ A cache that grows dynamically as more tokens are generated, but will stop growing if the sequence length is bigger than the sliding window. diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 53ac08e4a1be8b..5810b0d29c44d6 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2030,7 +2030,9 @@ def test_inherits_generation_mixin(self): for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) - @parameterized.expand([({"do_sample": False},), ({"do_sample": False, "top_k": 2, "penalty_alpha": 0.5},)]) + @parameterized.expand( + [({"do_sample": False},), ({"do_sample": False, "top_k": 2, "penalty_alpha": 0.5, "low_memory": True},)] + ) @pytest.mark.generate def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dict): """ @@ -2068,20 +2070,25 @@ def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dic self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) - @parameterized.expand([(False,), (True,)]) + @parameterized.expand([(3, 1), (3, 4), (14, 5)]) @pytest.mark.generate - def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_tokens_than_window: bool): + def test_generate_continue_from_dynamic_sliding_window_cache(self, sliding_window: int, additional_tokens: int): """ - Tests if we can correctly continue generation with DynamicSlidingWindowCache, even after the cache is "full" (bigger than sliding - window), and we provide more than 1 new token to add to the cache. + Tests if we can correctly continue generation with DynamicSlidingWindowCache. + - First case tests that we can continue if the cache is already full, and we add less tokens than the sliding window + - Second case tests that we can continue if the cache is already full, and we add more tokens that the sliding window + - Third case tests that we can continue if the cache is not full, and we add tokens so that the new input is bigger than the sliding window """ for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config, _, _, _ = self._get_input_ids_and_config() if getattr(config, "sliding_window", None) is None: self.skipTest(reason="This model does not support sliding window.") + # We need to be sure to always have shape (2, 7) for the different test assumptions to hold + input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) + # Make sure we will go beyond the sliding window - config.sliding_window = 3 + config.sliding_window = sliding_window model = model_class(config).to(torch_device).eval() all_generation_kwargs = { "do_sample": False, @@ -2094,14 +2101,9 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke dynamic_cache = DynamicCache() dynamic_sliding_cache = DynamicSlidingWindowCache(config.sliding_window) - out_dynamic = model.generate( - input_ids, attention_mask=attention_mask, **all_generation_kwargs, past_key_values=dynamic_cache - ) + out_dynamic = model.generate(input_ids, **all_generation_kwargs, past_key_values=dynamic_cache) out_sliding_dynamic = model.generate( - input_ids, - attention_mask=attention_mask, - **all_generation_kwargs, - past_key_values=dynamic_sliding_cache, + input_ids, **all_generation_kwargs, past_key_values=dynamic_sliding_cache ) results_dynamic, dynamic_cache = out_dynamic.sequences, out_dynamic.past_key_values @@ -2113,8 +2115,7 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, add_more_toke self.assertListEqual(results_dynamic.tolist(), results_sliding_dynamic.tolist()) bs = results_dynamic.shape[0] - num_added_tokens = 2 if not add_more_tokens_than_window else 4 - added_tokens = ids_tensor((bs, num_added_tokens), vocab_size=config.vocab_size) + added_tokens = ids_tensor((bs, additional_tokens), vocab_size=config.vocab_size) input_ids = torch.cat([results_dynamic, added_tokens], dim=-1) out_dynamic = model.generate(input_ids, **all_generation_kwargs, past_key_values=dynamic_cache) From 9a30ad414ad3c13a0a9ac06b56fd1d671f4aaaa7 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 8 Oct 2024 18:31:44 +0200 Subject: [PATCH 28/48] improve cache --- src/transformers/cache_utils.py | 74 +++++++++++++-------------------- tests/generation/test_utils.py | 11 ++--- 2 files changed, 33 insertions(+), 52 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 2f0becf33cb935..bc6f1e5ad8b664 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -558,7 +558,11 @@ class DynamicSlidingWindowCache(DynamicCache): This will be the default for generative models with sliding window attention (except for assisted decoding where `DynamicCache` is used). It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is - `[batch_size, num_heads, seq_len, head_dim]` and up to `[batch_size, num_heads, sliding_window, head_dim]` if seq_len >= sliding_window. + `[batch_size, num_heads, seq_len, head_dim]` and up to `[batch_size, num_heads, sliding_window-1, head_dim]` if seq_len >= sliding_window-1. + + Note: Since we only keep maximum `sliding_window-1` tokens in the cache, once this value is reached the cache can no + longer be roll-backed to previous states without losing information. For this reason, it should not be used with assisted decoding + (or contrastive search when using `low_memory=True`). Example: @@ -578,11 +582,11 @@ class DynamicSlidingWindowCache(DynamicCache): ``` """ - def __init__(self, sliding_window: int, num_hidden_layers: Optional[int] = None) -> None: - super().__init__(num_hidden_layers) + def __init__(self, sliding_window: int) -> None: + super().__init__() self.sliding_window = sliding_window # We overwrite the field and maintain a list of size `num_hidden_layers` to accurately reflect the seen tokens at each layer during `update` - self._seen_tokens = [0] * num_hidden_layers if num_hidden_layers is not None else [] + self._seen_tokens = [] def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int: """This needs to be overriden because the number of processed tokens may be larger than the cache length.""" @@ -602,10 +606,6 @@ def update( Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Discard previous tokens according to the sliding window if needed. - Note: we always keep `sliding_window` tokens in the cache if it is full, instead of the `sliding_window - 1` tokens that - are strictly necesary. This allows to roll back one token in the past with `cache.crop(-1)` in contrastive search. - Assisted decoding would need to roll back additional tokens, and is therefore not supported with this Cache class. - Parameters: key_states (`torch.Tensor`): The new key states to cache. @@ -623,46 +623,26 @@ def update( # Update the number of seen tokens self._seen_tokens.append(key_states.shape[-2]) # Add only up to sliding window size if larger - self.key_cache.append(key_states[..., -self.sliding_window :, :]) - self.value_cache.append(value_states[..., -self.sliding_window :, :]) - # We should return full states during prefill even though we only save up to sliding window - return key_states, value_states - # In case we initialized empty lists - elif self.key_cache[layer_idx] == []: - self._seen_tokens[layer_idx] += key_states.shape[-2] - self.key_cache[layer_idx] = key_states[..., -self.sliding_window :, :] - self.value_cache[layer_idx] = value_states[..., -self.sliding_window :, :] - # We should return full states during prefill even though we only save up to sliding window + self.key_cache.append(key_states[..., -self.sliding_window+1 :, :]) + self.value_cache.append(value_states[..., -self.sliding_window+1 :, :]) + # We should return full states during prefill even though we only save up to sliding window-1 return key_states, value_states else: self._seen_tokens[layer_idx] += key_states.shape[-2] - new_seq_len = key_states.shape[-2] - current_seq_len = self.get_seq_length(layer_idx) - if new_seq_len + current_seq_len > self.sliding_window: - # We may need to return longer states (e.g. to continue generation with previous cache, with added tokens), but we only keep - # the last `sliding_window` states in the cache for next forward - full_key_states = torch.cat( - [self.key_cache[layer_idx][..., -(self.sliding_window - 1) :, :], key_states], dim=-2 - ) - full_value_states = torch.cat( - [self.value_cache[layer_idx][..., -(self.sliding_window - 1) :, :], value_states], dim=-2 - ) - self.key_cache[layer_idx] = full_key_states[..., -self.sliding_window :, :] - self.value_cache[layer_idx] = full_value_states[..., -self.sliding_window :, :] - return full_key_states, full_value_states - else: - # Similar to DynamicCache - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - - return self.key_cache[layer_idx], self.value_cache[layer_idx] - - def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: int) -> List["DynamicCache"]: + # We may need to return longer states (e.g. to continue generation with previous cache, with added tokens), but we only keep + # the last `sliding_window-1` states in the cache for next forward + full_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + full_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + self.key_cache[layer_idx] = full_key_states[..., -self.sliding_window+1 :, :] + self.value_cache[layer_idx] = full_value_states[..., -self.sliding_window+1 :, :] + return full_key_states, full_value_states + + def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]: """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by `_split_model_inputs()` in `generation.utils`""" out = [] for i in range(0, full_batch_size, split_size): - current_split = DynamicSlidingWindowCache(self.sliding_window, num_hidden_layers) + current_split = DynamicSlidingWindowCache(self.sliding_window) current_split._seen_tokens = self._seen_tokens current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] @@ -671,11 +651,10 @@ def batch_split(self, full_batch_size: int, split_size: int, num_hidden_layers: @classmethod def from_batch_splits( - cls, splits: List["DynamicSlidingWindowCache"], num_hidden_layers: int - ) -> "DynamicSlidingWindowCache": + cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in `generation.utils`""" - cache = cls(splits[0].sliding_window, num_hidden_layers) + cache = cls(splits[0].sliding_window) for idx in range(len(splits[0])): key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] value_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []] @@ -687,6 +666,13 @@ def from_batch_splits( # We need this because _seen_tokens may be bigger than what will be automatically set with `update` (if cache > sliding_window) cache._seen_tokens = splits[0]._seen_tokens return cache + + def crop(self, max_length: int): + + if self.get_past_seen_tokens() >= self.sliding_window - 1: + raise RuntimeError(f"The current DynamicSlidingWindowCache is full. It cannot be cropped as this would mean losing past states.") + else: + super().crop(max_length) from_legacy_cache = None to_legacy_cache = None diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 5810b0d29c44d6..41e6f914a95e4a 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2030,15 +2030,10 @@ def test_inherits_generation_mixin(self): for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) - @parameterized.expand( - [({"do_sample": False},), ({"do_sample": False, "top_k": 2, "penalty_alpha": 0.5, "low_memory": True},)] - ) @pytest.mark.generate - def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dict): + def test_generate_with_dynamic_sliding_window_cache(self): """ - Tests if DynamicSlidingWindowCache works the same as DynamicCache for models that support it. The first expand - is for greedy, and the other is for contrasting search, as contrastive search needs to correctly roll back 1 token - of the cache even with DynamicSlidingWindowCache. + Tests if DynamicSlidingWindowCache works the same as DynamicCache for models that support it. """ for model_class in self.all_generative_model_classes: config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() @@ -2049,7 +2044,7 @@ def test_generate_with_dynamic_sliding_window_cache(self, generation_kwargs: dic config.sliding_window = 3 model = model_class(config).to(torch_device).eval() all_generation_kwargs = { - **generation_kwargs, + "do_sample": False, "max_new_tokens": 20, "min_new_tokens": 20, "use_cache": True, From 8202a19feb68b68809e4446d192af73cee682bdb Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 8 Oct 2024 18:43:24 +0200 Subject: [PATCH 29/48] add exceptions --- src/transformers/generation/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3bb8615cef4c5d..75f56bebdacabf 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -28,6 +28,7 @@ from ..cache_utils import ( Cache, DynamicCache, + DynamicSlidingWindowCache, EncoderDecoderCache, OffloadedCache, QuantizedCacheConfig, @@ -2130,6 +2131,8 @@ def generate( raise ValueError( f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}" ) + if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache): + raise ValueError("DynamicSlidingWindowCache cannot be used in assisted generation.") # 11. Get the candidate generator, given the parameterization candidate_generator = self._get_candidate_generator( @@ -2179,6 +2182,8 @@ def generate( raise ValueError( f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}" ) + if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache) and getattr(model_kwargs, "low_memory", False): + raise ValueError("DynamicSlidingWindowCache cannot be used in contrastive generation with `low_memory=True`.") result = self._contrastive_search( input_ids, From 55a39a6546b75a63b353382ba61854db065b41f6 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 8 Oct 2024 18:47:47 +0200 Subject: [PATCH 30/48] Update utils.py --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 75f56bebdacabf..c4ba43fa892978 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2182,7 +2182,7 @@ def generate( raise ValueError( f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}" ) - if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache) and getattr(model_kwargs, "low_memory", False): + if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache) and getattr(generation_config, "low_memory", False): raise ValueError("DynamicSlidingWindowCache cannot be used in contrastive generation with `low_memory=True`.") result = self._contrastive_search( From 9caf947f7e8531e22a3839a7975359cbecc8ee22 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 8 Oct 2024 18:54:24 +0200 Subject: [PATCH 31/48] Update test_utils.py --- tests/generation/test_utils.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 41e6f914a95e4a..4070cce7d65dc1 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2030,16 +2030,26 @@ def test_inherits_generation_mixin(self): for model_class in self.all_generative_model_classes: self.assertTrue("GenerationMixin" in str(model_class.__bases__)) + @parameterized.expand([(False,), (True,)]) @pytest.mark.generate - def test_generate_with_dynamic_sliding_window_cache(self): + def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): """ Tests if DynamicSlidingWindowCache works the same as DynamicCache for models that support it. """ for model_class in self.all_generative_model_classes: - config, input_ids, attention_mask, inputs_dict = self._get_input_ids_and_config() + config, _, _, inputs_dict = self._get_input_ids_and_config() if getattr(config, "sliding_window", None) is None: self.skipTest(reason="This model does not support sliding window.") + input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) + if left_padding: + attention_mask = torch.tensor([ + [0,0,0,0,1,1,1], + [1,1,1,1,1,1,1], + ], device=input_ids.device, dtype=int) + else: + attention_mask = torch.ones_like(input_ids) + # Make sure we will go beyond the sliding window config.sliding_window = 3 model = model_class(config).to(torch_device).eval() From 1404cece8b15d232e7836b27c99e4e4ac597f622 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 8 Oct 2024 18:59:13 +0200 Subject: [PATCH 32/48] Update test_utils.py --- tests/generation/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 4070cce7d65dc1..d1c7f66b5ebd44 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2037,7 +2037,7 @@ def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): Tests if DynamicSlidingWindowCache works the same as DynamicCache for models that support it. """ for model_class in self.all_generative_model_classes: - config, _, _, inputs_dict = self._get_input_ids_and_config() + config, _ = self.prepare_config_and_inputs_for_generate() if getattr(config, "sliding_window", None) is None: self.skipTest(reason="This model does not support sliding window.") @@ -2085,7 +2085,7 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, sliding_windo - Third case tests that we can continue if the cache is not full, and we add tokens so that the new input is bigger than the sliding window """ for model_class in self.all_generative_model_classes: - config, _, _, _ = self._get_input_ids_and_config() + config, _ = self.prepare_config_and_inputs_for_generate() if getattr(config, "sliding_window", None) is None: self.skipTest(reason="This model does not support sliding window.") From 4f3ba863d8904852763b58f024e4d3b24dd7e9dc Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 8 Oct 2024 19:00:27 +0200 Subject: [PATCH 33/48] Update test_utils.py --- tests/generation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index d1c7f66b5ebd44..6bae77d175bed2 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2044,7 +2044,7 @@ def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) if left_padding: attention_mask = torch.tensor([ - [0,0,0,0,1,1,1], + [0,0,0,0,0,1,1], [1,1,1,1,1,1,1], ], device=input_ids.device, dtype=int) else: From 44331f107f8c3587c069b868df43a511ffb6b754 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Tue, 8 Oct 2024 19:03:04 +0200 Subject: [PATCH 34/48] Update test_utils.py --- tests/generation/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6bae77d175bed2..756429070a8236 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2044,7 +2044,7 @@ def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) if left_padding: attention_mask = torch.tensor([ - [0,0,0,0,0,1,1], + [0,0,0,1,1,1,1], [1,1,1,1,1,1,1], ], device=input_ids.device, dtype=int) else: From b5ebae2aadd06ae85ef733061bae8470538476dc Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Wed, 9 Oct 2024 11:31:29 +0200 Subject: [PATCH 35/48] Update test_utils.py --- tests/generation/test_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 756429070a8236..fa196b128242a2 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2040,6 +2040,8 @@ def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): config, _ = self.prepare_config_and_inputs_for_generate() if getattr(config, "sliding_window", None) is None: self.skipTest(reason="This model does not support sliding window.") + if "qwen2" in str(model_class).lower(): + self.skipTest(reason="Sliding window attention is not implemented for sdpa in Qwen2 models.") input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) if left_padding: @@ -2088,6 +2090,8 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, sliding_windo config, _ = self.prepare_config_and_inputs_for_generate() if getattr(config, "sliding_window", None) is None: self.skipTest(reason="This model does not support sliding window.") + if "qwen2" in str(model_class).lower(): + self.skipTest(reason="Sliding window attention is not implemented for sdpa in Qwen2 models.") # We need to be sure to always have shape (2, 7) for the different test assumptions to hold input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) From 7e78258d5fc1fbe59001de7e33e63ad9c8dbd8f6 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 10 Oct 2024 15:41:24 +0200 Subject: [PATCH 36/48] Update 4d mask creation in Mistral --- src/transformers/cache_utils.py | 18 +++++----- src/transformers/generation/utils.py | 8 +++-- .../models/mistral/modeling_mistral.py | 34 ++++++++++++------- tests/generation/test_utils.py | 12 ++++--- 4 files changed, 44 insertions(+), 28 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index bc6f1e5ad8b664..ddc9db0d02d082 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -623,8 +623,8 @@ def update( # Update the number of seen tokens self._seen_tokens.append(key_states.shape[-2]) # Add only up to sliding window size if larger - self.key_cache.append(key_states[..., -self.sliding_window+1 :, :]) - self.value_cache.append(value_states[..., -self.sliding_window+1 :, :]) + self.key_cache.append(key_states[..., -self.sliding_window + 1 :, :]) + self.value_cache.append(value_states[..., -self.sliding_window + 1 :, :]) # We should return full states during prefill even though we only save up to sliding window-1 return key_states, value_states else: @@ -633,8 +633,8 @@ def update( # the last `sliding_window-1` states in the cache for next forward full_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) full_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) - self.key_cache[layer_idx] = full_key_states[..., -self.sliding_window+1 :, :] - self.value_cache[layer_idx] = full_value_states[..., -self.sliding_window+1 :, :] + self.key_cache[layer_idx] = full_key_states[..., -self.sliding_window + 1 :, :] + self.value_cache[layer_idx] = full_value_states[..., -self.sliding_window + 1 :, :] return full_key_states, full_value_states def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]: @@ -650,8 +650,7 @@ def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCac return out @classmethod - def from_batch_splits( - cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": + def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in `generation.utils`""" cache = cls(splits[0].sliding_window) @@ -666,11 +665,12 @@ def from_batch_splits( # We need this because _seen_tokens may be bigger than what will be automatically set with `update` (if cache > sliding_window) cache._seen_tokens = splits[0]._seen_tokens return cache - - def crop(self, max_length: int): + def crop(self, max_length: int): if self.get_past_seen_tokens() >= self.sliding_window - 1: - raise RuntimeError(f"The current DynamicSlidingWindowCache is full. It cannot be cropped as this would mean losing past states.") + raise RuntimeError( + "The current DynamicSlidingWindowCache is full. It cannot be cropped as this would mean losing past states." + ) else: super().crop(max_length) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index c4ba43fa892978..83a21e590df969 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2182,8 +2182,12 @@ def generate( raise ValueError( f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}" ) - if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache) and getattr(generation_config, "low_memory", False): - raise ValueError("DynamicSlidingWindowCache cannot be used in contrastive generation with `low_memory=True`.") + if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache) and getattr( + generation_config, "low_memory", False + ): + raise ValueError( + "DynamicSlidingWindowCache cannot be used in contrastive generation with `low_memory=True`." + ) result = self._contrastive_search( input_ids, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 5d596eaaf725f5..4cffd6baabb50a 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -904,12 +904,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -918,15 +924,6 @@ def _update_causal_mask( past_key_values=past_key_values, ) - if isinstance(past_key_values, DynamicSlidingWindowCache): - current_cache_length = past_key_values.get_seq_length() - if sequence_length + current_cache_length > self.config.sliding_window: - target_length = sequence_length + self.config.sliding_window - 1 - else: - target_length = current_cache_length + sequence_length - # Slice the causal mask to get only relevant part of the same shape as the keys/values - causal_mask = causal_mask[:, :, :, -target_length:] - if ( self.config._attn_implementation == "sdpa" and attention_mask is not None @@ -945,6 +942,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -963,6 +961,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -982,14 +983,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1000,7 +1006,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index fa196b128242a2..704b4a44b8f0f7 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2045,10 +2045,14 @@ def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) if left_padding: - attention_mask = torch.tensor([ - [0,0,0,1,1,1,1], - [1,1,1,1,1,1,1], - ], device=input_ids.device, dtype=int) + attention_mask = torch.tensor( + [ + [0, 0, 0, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1, 1], + ], + device=input_ids.device, + dtype=int, + ) else: attention_mask = torch.ones_like(input_ids) From 301f7f2d7c7ab64dc47d9b3b7b73a19e824b591d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 10 Oct 2024 15:48:47 +0200 Subject: [PATCH 37/48] fix missed conflict --- src/transformers/models/idefics2/modeling_idefics2.py | 7 +------ src/transformers/models/idefics3/modeling_idefics3.py | 7 +------ 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 9ca404d3cb45ed..6f3c2ffbbb5a4f 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -1664,13 +1664,8 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore -<<<<<<< HEAD - past_length = past_key_values.get_seq_length() - max_cache_length = past_key_values.get_max_cache_shape() -======= past_length = past_key_values.get_past_seen_tokens() - max_cache_length = past_key_values.get_max_length() ->>>>>>> 0c098e35c (Modify all current .get_seq_length names) + max_cache_length = past_key_values.get_max_cache_shape() # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index f30c83e7ade3a4..9c65b2e01b7640 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -1252,13 +1252,8 @@ def prepare_inputs_for_generation( # Omit tokens covered by past_key_values if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore -<<<<<<< HEAD - past_length = past_key_values.get_seq_length() - max_cache_length = past_key_values.get_max_cache_shape() -======= past_length = past_key_values.get_past_seen_tokens() - max_cache_length = past_key_values.get_max_length() ->>>>>>> 0c098e35c (Modify all current .get_seq_length names) + max_cache_length = past_key_values.get_max_cache_shape() # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where From be18801446b5d70083c9a9c2069c3463e3c25603 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 10 Oct 2024 15:57:55 +0200 Subject: [PATCH 38/48] Apply to other models --- src/transformers/models/mimi/modeling_mimi.py | 27 +++++++++++++---- .../models/mixtral/modeling_mixtral.py | 27 +++++++++++++---- src/transformers/models/phi3/modeling_phi3.py | 27 +++++++++++++---- .../models/phimoe/modeling_phimoe.py | 29 +++++++++++++++---- .../models/qwen2/modeling_qwen2.py | 27 +++++++++++++---- .../models/qwen2_moe/modeling_qwen2_moe.py | 27 +++++++++++++---- .../models/qwen2_vl/modeling_qwen2_vl.py | 27 +++++++++++++---- .../models/starcoder2/modeling_starcoder2.py | 27 +++++++++++++---- 8 files changed, 177 insertions(+), 41 deletions(-) diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 49d83cf800e0da..985ca1fe275a81 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -23,7 +23,7 @@ from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import PreTrainedModel @@ -1076,12 +1076,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1109,6 +1115,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1127,6 +1134,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1146,14 +1156,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1164,7 +1179,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 2619d5792f0c90..8af557153dcecd 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -1117,12 +1117,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1150,6 +1156,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1168,6 +1175,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1187,14 +1197,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1205,7 +1220,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 225e590c58a1d4..aa7aff4c31f105 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -25,7 +25,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -1084,12 +1084,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1117,6 +1123,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1135,6 +1142,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1154,14 +1164,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1172,7 +1187,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index d1705f04ddb7bb..fca68092b0af3c 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -24,7 +24,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( @@ -1220,7 +1220,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1252,12 +1252,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1285,6 +1291,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1303,6 +1310,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1322,14 +1332,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1340,7 +1355,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 6f4b4fd61524ae..78dda718d8cf21 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -1003,12 +1003,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1036,6 +1042,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1054,6 +1061,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1073,14 +1083,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1091,7 +1106,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 4123f833090c91..e512f2beeab55b 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -1184,12 +1184,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1217,6 +1223,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1235,6 +1242,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1254,14 +1264,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1272,7 +1287,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 52379787ab30a3..242eaa01f3f1fd 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -30,7 +30,7 @@ from torch.nn import CrossEntropyLoss, LayerNorm from ...activations import ACT2FN -from ...cache_utils import Cache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( AttentionMaskConverter, @@ -1266,12 +1266,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1299,6 +1305,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1317,6 +1324,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1336,14 +1346,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1354,7 +1369,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 9688acc2749786..42821a3a0ea987 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -28,7 +28,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...cache_utils import Cache, DynamicCache, DynamicSlidingWindowCache, SlidingWindowCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( @@ -977,12 +977,18 @@ def _update_causal_mask( if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length + 1 ) + initial_mask_position = ( + max(0, past_seen_tokens - self.config.sliding_window + 1) + if isinstance(past_key_values, DynamicSlidingWindowCache) + else 0 + ) # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( attention_mask, sequence_length=sequence_length, target_length=target_length, + initial_mask_position=initial_mask_position, dtype=dtype, device=device, cache_position=cache_position, @@ -1010,6 +1016,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( attention_mask: torch.Tensor, sequence_length: int, target_length: int, + initial_mask_position: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, @@ -1028,6 +1035,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( The sequence length being processed. target_length (`int`): The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + initial_mask_position (`int`): + The initial mask position to use when creating the 4d mask. If the Cache does not keep all states in memory (e.g. only `sliding_window` states), + this is needed to know where to start from to create the new mask (because the new mask). dtype (`torch.dtype`): The dtype to use for the 4D attention mask. device (`torch.device`): @@ -1047,14 +1057,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( else: min_dtype = torch.finfo(dtype).min causal_mask = torch.full( - (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + (sequence_length, target_length - initial_mask_position), + fill_value=min_dtype, + dtype=dtype, + device=device, ) - diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + diagonal_attend_mask = torch.arange( + initial_mask_position, target_length, device=device + ) > cache_position.reshape(-1, 1) if config.sliding_window is not None: # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also # the check is needed to verify is current checkpoint was trained with sliding window or not if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: - sliding_attend_mask = torch.arange(target_length, device=device) <= ( + sliding_attend_mask = torch.arange(initial_mask_position, target_length, device=device) <= ( cache_position.reshape(-1, 1) - config.sliding_window ) diagonal_attend_mask |= sliding_attend_mask @@ -1065,7 +1080,9 @@ def _prepare_4d_causal_attention_mask_with_cache_position( if attention_mask.shape[-1] > target_length: attention_mask = attention_mask[:, :target_length] mask_length = attention_mask.shape[-1] - padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = ( + causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, initial_mask_position:] + ) padding_mask = padding_mask == 0 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( padding_mask, min_dtype From 734e3fe75eb0072496a93ec2e336ae07e0045b4d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 10 Oct 2024 16:09:06 +0200 Subject: [PATCH 39/48] Add required arg in prepare_inoput --- src/transformers/models/mistral/modeling_mistral.py | 1 + src/transformers/models/mixtral/modeling_mixtral.py | 1 + src/transformers/models/phi3/modeling_phi3.py | 1 + src/transformers/models/phimoe/modeling_phimoe.py | 1 + src/transformers/models/qwen2/modeling_qwen2.py | 1 + src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 1 + src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 1 + src/transformers/models/starcoder2/modeling_starcoder2.py | 1 + 8 files changed, 8 insertions(+) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 4cffd6baabb50a..5f1cadca271607 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1195,6 +1195,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 8af557153dcecd..a78286ea570c66 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1430,6 +1430,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index aa7aff4c31f105..140fd90031e0c3 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1406,6 +1406,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index fca68092b0af3c..33dd603fdb5e7d 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -1596,6 +1596,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 78dda718d8cf21..fe8cd476965ff1 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -1296,6 +1296,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index e512f2beeab55b..d361c421fed2b1 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1500,6 +1500,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 242eaa01f3f1fd..dc0d3ae9e91e9b 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1855,6 +1855,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 42821a3a0ea987..180f560856fe15 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1272,6 +1272,7 @@ def prepare_inputs_for_generation( attention_mask, sequence_length=sequence_length, target_length=past_key_values.get_max_cache_shape(), + initial_mask_position=0, dtype=self.lm_head.weight.dtype, device=device, cache_position=cache_position, From 106c4100b151300097f131288082cb824bf83e65 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 10 Oct 2024 16:30:04 +0200 Subject: [PATCH 40/48] Update test_utils.py --- tests/generation/test_utils.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 704b4a44b8f0f7..ddf40acfe8fd47 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2038,10 +2038,10 @@ def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): """ for model_class in self.all_generative_model_classes: config, _ = self.prepare_config_and_inputs_for_generate() - if getattr(config, "sliding_window", None) is None: + if not hasattr(config, "sliding_window"): self.skipTest(reason="This model does not support sliding window.") - if "qwen2" in str(model_class).lower(): - self.skipTest(reason="Sliding window attention is not implemented for sdpa in Qwen2 models.") + if hasattr(config, "cache_implementation"): + self.skipTest(reason="This model uses a specific cache format.") input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) if left_padding: @@ -2092,10 +2092,10 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, sliding_windo """ for model_class in self.all_generative_model_classes: config, _ = self.prepare_config_and_inputs_for_generate() - if getattr(config, "sliding_window", None) is None: + if not hasattr(config, "sliding_window"): self.skipTest(reason="This model does not support sliding window.") - if "qwen2" in str(model_class).lower(): - self.skipTest(reason="Sliding window attention is not implemented for sdpa in Qwen2 models.") + if hasattr(config, "cache_implementation"): + self.skipTest(reason="This model uses a specific cache format.") # We need to be sure to always have shape (2, 7) for the different test assumptions to hold input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) From 0d8e9acf4505fa6bcf20838fd2737939fa9d8dce Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 10 Oct 2024 16:53:39 +0200 Subject: [PATCH 41/48] Update test_utils.py --- tests/generation/test_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index ddf40acfe8fd47..2aa8d154cd0e4f 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2042,6 +2042,8 @@ def test_generate_with_dynamic_sliding_window_cache(self, left_padding: bool): self.skipTest(reason="This model does not support sliding window.") if hasattr(config, "cache_implementation"): self.skipTest(reason="This model uses a specific cache format.") + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support Cache classes") input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) if left_padding: @@ -2096,6 +2098,8 @@ def test_generate_continue_from_dynamic_sliding_window_cache(self, sliding_windo self.skipTest(reason="This model does not support sliding window.") if hasattr(config, "cache_implementation"): self.skipTest(reason="This model uses a specific cache format.") + if model_class._is_stateful: + self.skipTest(reason="Stateful models don't support Cache classes") # We need to be sure to always have shape (2, 7) for the different test assumptions to hold input_ids = ids_tensor((2, 7), vocab_size=config.vocab_size) From 85090534a6d5f70745fdd8b11c356f8f2110b4a0 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 10 Oct 2024 22:47:57 +0200 Subject: [PATCH 42/48] Fix kv_seq_length and rotary_seq_length --- .../models/mistral/modeling_mistral.py | 2 +- .../models/mixtral/modeling_mixtral.py | 29 ++++++++++++------- src/transformers/models/phi3/modeling_phi3.py | 26 ++--------------- .../models/qwen2/modeling_qwen2.py | 3 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 3 +- .../models/qwen2_vl/modeling_qwen2_vl.py | 7 ----- .../models/starcoder2/modeling_starcoder2.py | 3 +- 7 files changed, 29 insertions(+), 44 deletions(-) diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 5f1cadca271607..5a3dd65809f8bd 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -313,7 +313,7 @@ def forward( kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += cache_position[0] + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index a78286ea570c66..e9e795566f1fbc 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -329,6 +329,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + rotary_seq_length = kv_seq_len if past_key_value is not None: if self.layer_idx is None: raise ValueError( @@ -337,7 +338,12 @@ def forward( "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_length) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -412,6 +418,7 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + rotary_seq_length = kv_seq_len if past_key_value is not None: if self.layer_idx is None: raise ValueError( @@ -420,13 +427,12 @@ def forward( "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = ( - max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len - ) - - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_length) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -558,10 +564,13 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] + rotary_seq_length = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_length) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 140fd90031e0c3..58f4aa46f76fd2 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -377,17 +377,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) - + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -482,13 +472,7 @@ def forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - # Because the input can be padded, the absolute sequence length depends on the max position id. - rotary_seq_len = ( - max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len - ) - - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids) - + cos, sin = self.rotary_emb(value_states, position_ids=position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -625,11 +609,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) - + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index fe8cd476965ff1..8c2ee4ec8828e4 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -395,7 +395,8 @@ def forward( if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - kv_seq_len = key_states.shape[-2] + cache_position[0] + kv_seq_len = key_states.shape[-2] + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if ( getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index d361c421fed2b1..a4a4d3491331d7 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -483,7 +483,8 @@ def forward( if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - kv_seq_len = key_states.shape[-2] + cache_position[0] + kv_seq_len = key_states.shape[-2] + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if ( getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index dc0d3ae9e91e9b..2346ede2c4c848 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -549,10 +549,6 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += cache_position[0] + 1 - if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " @@ -809,9 +805,6 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if position_embeddings is None: logger.warning_once( "The attention layers in this model are transitioning from computing the RoPE embeddings internally " diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 180f560856fe15..b7d548029bb936 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -375,7 +375,8 @@ def forward( if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 - kv_seq_len = key_states.shape[-2] + cache_position[0] + kv_seq_len = key_states.shape[-2] + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if ( getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window From 2ae645fba95861449d7162419602c35ad144b405 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 11 Oct 2024 00:44:33 +0200 Subject: [PATCH 43/48] up --- .../models/mixtral/modeling_mixtral.py | 22 +++---------------- src/transformers/models/phi3/modeling_phi3.py | 6 ++--- 2 files changed, 6 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index e9e795566f1fbc..54ecec1a78ab0b 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -329,7 +329,6 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] - rotary_seq_length = kv_seq_len if past_key_value is not None: if self.layer_idx is None: raise ValueError( @@ -338,12 +337,8 @@ def forward( "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() - else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_length) + cos, sin = self.rotary_emb(value_states, seq_len=cache_position.max() + 1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -418,7 +413,6 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] - rotary_seq_length = kv_seq_len if past_key_value is not None: if self.layer_idx is None: raise ValueError( @@ -427,12 +421,8 @@ def forward( "with a layer index." ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() - else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_length) + cos, sin = self.rotary_emb(value_states, seq_len=cache_position.max() + 1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -564,13 +554,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - rotary_seq_length = key_states.shape[-2] - if past_key_value is not None: - if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() - else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_length) + cos, sin = self.rotary_emb(value_states, seq_len=cache_position.max() + 1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 58f4aa46f76fd2..707bd1e8621b4c 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -377,7 +377,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=cache_position.max() + 1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -472,7 +472,7 @@ def forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids=position_ids) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=cache_position.max() + 1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -609,7 +609,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=cache_position.max() + 1) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: From 8d539e6a87eff3804c475c44cf28d78b53e0de1d Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 11 Oct 2024 01:17:55 +0200 Subject: [PATCH 44/48] up --- .../models/mixtral/modeling_mixtral.py | 15 ++++++++++++--- src/transformers/models/phi3/modeling_phi3.py | 15 ++++++++++++--- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 54ecec1a78ab0b..415fdcac850529 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -338,7 +338,10 @@ def forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=cache_position.max() + 1) + rotary_seq_len = key_states.shape[-2] + if past_key_value is not None: + rotary_seq_len += cache_position[0] + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -422,7 +425,10 @@ def forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=cache_position.max() + 1) + rotary_seq_len = key_states.shape[-2] + if past_key_value is not None: + rotary_seq_len += cache_position[0] + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -554,7 +560,10 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, seq_len=cache_position.max() + 1) + rotary_seq_len = key_states.shape[-2] + if past_key_value is not None: + rotary_seq_len += cache_position[0] + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 707bd1e8621b4c..7746f418f8be96 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -377,7 +377,10 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=cache_position.max() + 1) + rotary_seq_len = key_states.shape[-2] + if past_key_value is not None: + rotary_seq_len += cache_position[0] + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -472,7 +475,10 @@ def forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=cache_position.max() + 1) + rotary_seq_len = key_states.shape[-2] + if past_key_value is not None: + rotary_seq_len += cache_position[0] + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -609,7 +615,10 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=cache_position.max() + 1) + rotary_seq_len = key_states.shape[-2] + if past_key_value is not None: + rotary_seq_len += cache_position[0] + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: From e808fa53e2ba2fd17e5b3fe47b75357815cddd46 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 11 Oct 2024 01:31:31 +0200 Subject: [PATCH 45/48] up --- .../models/mixtral/modeling_mixtral.py | 36 ++++++++----------- src/transformers/models/phi3/modeling_phi3.py | 26 ++++++++------ 2 files changed, 30 insertions(+), 32 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 415fdcac850529..1b76aca214ed28 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -329,18 +329,14 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - rotary_seq_len = key_states.shape[-2] - if past_key_value is not None: - rotary_seq_len += cache_position[0] cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -416,20 +412,15 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - rotary_seq_len = key_states.shape[-2] - if past_key_value is not None: - rotary_seq_len += cache_position[0] cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -562,9 +553,12 @@ def forward( rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - rotary_seq_len += cache_position[0] - cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 7746f418f8be96..be6192cd8aae8c 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -379,8 +379,12 @@ def forward( rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - rotary_seq_len += cache_position[0] + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -466,18 +470,14 @@ def forward( value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] + rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - rotary_seq_len = key_states.shape[-2] - if past_key_value is not None: - rotary_seq_len += cache_position[0] cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -617,8 +617,12 @@ def forward( rotary_seq_len = key_states.shape[-2] if past_key_value is not None: - rotary_seq_len += cache_position[0] + if past_key_value.get_max_cache_shape() is not None: + rotary_seq_length = past_key_value.get_max_cache_shape() + else: + rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: From 8499f942158dc6cdb600b37217571b88bc5018c0 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 11 Oct 2024 01:35:08 +0200 Subject: [PATCH 46/48] up --- .../models/mixtral/modeling_mixtral.py | 12 ++++++------ src/transformers/models/phi3/modeling_phi3.py | 14 +++++++------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 1b76aca214ed28..5a1112a602144c 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -333,9 +333,9 @@ def forward( if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() + rotary_seq_len = past_key_value.get_max_cache_shape() else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -416,9 +416,9 @@ def forward( if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() + rotary_seq_len = past_key_value.get_max_cache_shape() else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -554,9 +554,9 @@ def forward( rotary_seq_len = key_states.shape[-2] if past_key_value is not None: if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() + rotary_seq_len = past_key_value.get_max_cache_shape() else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index be6192cd8aae8c..e79e91b059eef7 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -380,9 +380,9 @@ def forward( rotary_seq_len = key_states.shape[-2] if past_key_value is not None: if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() + rotary_seq_len = past_key_value.get_max_cache_shape() else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -474,9 +474,9 @@ def forward( if past_key_value is not None: kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() + rotary_seq_len = past_key_value.get_max_cache_shape() else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) @@ -618,11 +618,11 @@ def forward( rotary_seq_len = key_states.shape[-2] if past_key_value is not None: if past_key_value.get_max_cache_shape() is not None: - rotary_seq_length = past_key_value.get_max_cache_shape() + rotary_seq_len = past_key_value.get_max_cache_shape() else: - rotary_seq_length += past_key_value.get_past_seen_tokens(self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) + rotary_seq_len += past_key_value.get_past_seen_tokens(self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: From 687986696a258551add56deab329b0d89201fcef Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 11 Oct 2024 10:32:04 +0200 Subject: [PATCH 47/48] CIs From fe8a625a1c0545772fd19b193e6042d2414b9444 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 11 Oct 2024 11:51:45 +0200 Subject: [PATCH 48/48] improve sdpa is_causal escape --- src/transformers/modeling_attn_mask_utils.py | 6 +++--- src/transformers/models/mimi/modeling_mimi.py | 3 ++- src/transformers/models/mistral/modeling_mistral.py | 3 ++- src/transformers/models/mixtral/modeling_mixtral.py | 3 ++- src/transformers/models/phi3/modeling_phi3.py | 3 ++- src/transformers/models/phimoe/modeling_phimoe.py | 3 ++- src/transformers/models/qwen2/modeling_qwen2.py | 3 ++- src/transformers/models/qwen2_moe/modeling_qwen2_moe.py | 3 ++- src/transformers/models/qwen2_vl/modeling_qwen2_vl.py | 3 ++- src/transformers/models/starcoder2/modeling_starcoder2.py | 3 ++- 10 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/transformers/modeling_attn_mask_utils.py b/src/transformers/modeling_attn_mask_utils.py index 4319c021cb2bc3..64b64e88996216 100755 --- a/src/transformers/modeling_attn_mask_utils.py +++ b/src/transformers/modeling_attn_mask_utils.py @@ -275,13 +275,13 @@ def _ignore_causal_mask_sdpa( if ( (is_training or not is_tracing) and (query_length == 1 or key_value_length == query_length) - and (sliding_window is None or key_value_length < sliding_window) + and (sliding_window is None or key_value_length <= sliding_window) ): ignore_causal_mask = True - elif sliding_window is None or key_value_length < sliding_window: + elif sliding_window is None or key_value_length <= sliding_window: if len(attention_mask.shape) == 4: return False - elif not is_tracing and torch.all(attention_mask == 1): + elif not is_tracing and torch.all(attention_mask[:, -key_value_length:] == 1): if query_length == 1 or key_value_length == query_length: # For query_length == 1, causal attention and bi-directional attention are the same. ignore_causal_mask = True diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index 985ca1fe275a81..6120e2576d6866 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -1044,6 +1044,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1057,7 +1058,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 5a3dd65809f8bd..464172172ae86c 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -872,6 +872,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -885,7 +886,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 5a1112a602144c..05db01c83f3b92 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1081,6 +1081,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1094,7 +1095,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index e79e91b059eef7..3dd799f59177ff 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1045,6 +1045,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1058,7 +1059,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index 33dd603fdb5e7d..298777ecbcca7a 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -1220,6 +1220,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1233,7 +1234,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 8c2ee4ec8828e4..505a35b8b2496e 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -972,6 +972,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -985,7 +986,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index a4a4d3491331d7..2b5aff2423752d 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1153,6 +1153,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1166,7 +1167,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 2346ede2c4c848..86e6bfbb5472b9 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -1227,6 +1227,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -1240,7 +1241,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ): diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index b7d548029bb936..3e6fba51861d33 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -946,6 +946,7 @@ def _update_causal_mask( # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail # to infer the attention mask. + current_cache_length = past_key_values.get_seq_length() if past_key_values is not None else 0 past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0 using_static_cache = isinstance(past_key_values, StaticCache) using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) @@ -959,7 +960,7 @@ def _update_causal_mask( if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, - past_key_values_length=past_seen_tokens, + past_key_values_length=current_cache_length, sliding_window=self.config.sliding_window, is_training=self.training, ):