From 6a33e28a5017fd9b094bf9f364d75732c7f66dd1 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 17:19:21 +0200 Subject: [PATCH 01/15] Add new dynamic cache --- src/transformers/cache_utils.py | 79 +++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index d42b15c14abf9b..48cd052aa53c45 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -507,6 +507,85 @@ def batch_select_indices(self, indices: torch.Tensor): self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] +class DynamicSlidingWindowCache(DynamicCache): + """ + A cache that grows dynamically as more tokens are generated, but will stop growing if the sequence length is bigger than the sliding window. + This is the default for generative models with sliding window attention. + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]` and up to `[batch_size, num_heads, sliding_window, head_dim]` if seq_len >= sliding_window. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicSlidingWindowCache + + >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") + + >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = DynamicSlidingWindowCache(model.config.sliding_window) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + DynamicSlidingWindowCache() + ``` + """ + + def __init__(self, sliding_window: int): + super().__init__() + self.sliding_window = sliding_window + self.slicing_ + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Discard previous + tokens according to the sliding window if needed. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. 
+ """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + # Add only up to sliding window size if larger + self.key_cache.append(key_states[:, :, -self.sliding_window:, :]) + self.value_cache.append(value_states[:, :, -self.sliding_window:, :]) + else: + new_seq_len = key_states.shape[-2] + current_seq_len = self.get_seq_length(layer_idx) + if new_seq_len + current_seq_len > self.sliding_window: + # We need to slice + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx][:, :, -(self.sliding_window-new_seq_len):, :], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx][:, :, -(self.sliding_window-new_seq_len):, :], value_states], dim=-2) + else: + # Similar to DynamicCache + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + class OffloadedCache(DynamicCache): """ A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory. From 257f10bd25d94fcdc25a6d85dd952a3ade6bf543 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 17:46:05 +0200 Subject: [PATCH 02/15] Add cache by default in generate for models supporting it --- src/transformers/generation/utils.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 2fe92d3e3ed64b..8a9370663ba765 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -28,6 +28,7 @@ from ..cache_utils import ( Cache, DynamicCache, + DynamicSlidingWindowCache, EncoderDecoderCache, OffloadedCache, QuantizedCacheConfig, @@ -1612,11 +1613,15 @@ def _prepare_cache_for_generation( # Use DynamicCache() instance by default. 
This will avoid back and forth from legacy format that # keeps copying the cache thus using much more memory else: - model_kwargs[cache_name] = ( - DynamicCache() - if not requires_cross_attention_cache - else EncoderDecoderCache(DynamicCache(), DynamicCache()) - ) + # If using sliding window attention, use specialized DynamicSlidingWindowCache + if getattr(self.config, "sliding_window", None) is not None and not requires_cross_attention_cache: + model_kwargs[cache_name] = DynamicSlidingWindowCache(self.config.sliding_window) + else: + model_kwargs[cache_name] = ( + DynamicCache() + if not requires_cross_attention_cache + else EncoderDecoderCache(DynamicCache(), DynamicCache()) + ) def _supports_num_logits_to_keep(self) -> bool: """ From 80c2434b978d4536bb7c6e36d061de273e7eaf0e Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 18:02:54 +0200 Subject: [PATCH 03/15] Add to __init__ and correct typo --- src/transformers/__init__.py | 1 + src/transformers/cache_utils.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index aa13a97fe46150..eeefc412c430d8 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1252,6 +1252,7 @@ "Cache", "CacheConfig", "DynamicCache", + "DynamicSlidingWindowCache", "EncoderDecoderCache", "HQQQuantizedCache", "HybridCache", diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 48cd052aa53c45..34a75f8c497d04 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -536,7 +536,6 @@ class DynamicSlidingWindowCache(DynamicCache): def __init__(self, sliding_window: int): super().__init__() self.sliding_window = sliding_window - self.slicing_ def update( self, From f8a55531d792831001dac4e3d601888036e336ee Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 19:41:44 +0200 Subject: [PATCH 04/15] Correct output if prefill larger than sliding window + compatibility --- src/transformers/cache_utils.py | 35 ++++++++++++++++++++++++---- src/transformers/generation/utils.py | 6 ++--- 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 34a75f8c497d04..22d14ba7c1976f 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -568,15 +568,17 @@ def update( # Update the cache if len(self.key_cache) <= layer_idx: # Add only up to sliding window size if larger - self.key_cache.append(key_states[:, :, -self.sliding_window:, :]) - self.value_cache.append(value_states[:, :, -self.sliding_window:, :]) + self.key_cache.append(key_states[..., -self.sliding_window:, :]) + self.value_cache.append(value_states[..., -self.sliding_window:, :]) + # We should return full states during prefill even though we only save up to sliding window + return key_states, value_states else: new_seq_len = key_states.shape[-2] current_seq_len = self.get_seq_length(layer_idx) if new_seq_len + current_seq_len > self.sliding_window: # We need to slice - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx][:, :, -(self.sliding_window-new_seq_len):, :], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx][:, :, -(self.sliding_window-new_seq_len):, :], value_states], dim=-2) + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx][..., -(self.sliding_window-new_seq_len):, :], key_states], dim=-2) + self.value_cache[layer_idx] = 
torch.cat([self.value_cache[layer_idx][..., -(self.sliding_window-new_seq_len):, :], value_states], dim=-2) else: # Similar to DynamicCache self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) @@ -584,6 +586,31 @@ def update( return self.key_cache[layer_idx], self.value_cache[layer_idx] + def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSlidingWindowCache"]: + """Needs to be overriden because DynamicSlidingWindowCache takes an __init__() argument.""" + out = [] + for i in range(0, full_batch_size, split_size): + current_split = DynamicSlidingWindowCache(self.sliding_window) + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] + out.append(current_split) + return out + + @classmethod + def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": + """Needs to be overriden because DynamicSlidingWindowCache takes an __init__() argument.""" + cache = cls(splits[0].sliding_window) + for idx in range(len(splits[0])): + layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) + layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) + cache.update(layer_keys, layer_values, idx) + return cache + + # Legacy format does not really make sense here even though it is a DynamicCache -> we set methods to None + from_legacy_cache = None + to_legacy_cache = None + class OffloadedCache(DynamicCache): """ diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8a9370663ba765..74888bd51fe6bf 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -4452,10 +4452,8 @@ def _concat(data): if isinstance(data[0], torch.Tensor): return torch.cat(data, dim=0) # New cache format - elif isinstance(data[0], DynamicCache): - return DynamicCache.from_batch_splits(data) - elif isinstance(data[0], EncoderDecoderCache): - return EncoderDecoderCache.from_batch_splits(data) + elif isinstance(data[0], (DynamicCache, EncoderDecoderCache)): + return data[0].__class__.from_batch_splits(data) elif isinstance(data[0], tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): From 693808f9ad3c84022579c666d4ea3f6d16285194 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Thu, 19 Sep 2024 22:08:36 +0200 Subject: [PATCH 05/15] Add legacy format handling --- src/transformers/cache_utils.py | 21 +++++++++++++++------ src/transformers/generation/utils.py | 2 +- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 22d14ba7c1976f..6c615326c48975 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -585,9 +585,21 @@ def update( self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) return self.key_cache[layer_idx], self.value_cache[layer_idx] + + @classmethod + def from_legacy_cache(cls, sliding_window: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicSlidingWindowCache`. 
Used for + backward compatibility.""" + cache = cls(sliding_window) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSlidingWindowCache"]: - """Needs to be overriden because DynamicSlidingWindowCache takes an __init__() argument.""" + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" out = [] for i in range(0, full_batch_size, split_size): current_split = DynamicSlidingWindowCache(self.sliding_window) @@ -599,17 +611,14 @@ def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSli @classmethod def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": - """Needs to be overriden because DynamicSlidingWindowCache takes an __init__() argument.""" + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" cache = cls(splits[0].sliding_window) for idx in range(len(splits[0])): layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) cache.update(layer_keys, layer_values, idx) return cache - - # Legacy format does not really make sense here even though it is a DynamicCache -> we set methods to None - from_legacy_cache = None - to_legacy_cache = None class OffloadedCache(DynamicCache): diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 74888bd51fe6bf..2e7d1517c337a7 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -2212,7 +2212,7 @@ def typeerror(): should_convert_cache = generation_config.return_legacy_cache is_user_defined_cache = user_defined_cache is not None is_default_cache_type = ( - type(result.past_key_values) == DynamicCache # noqa E721 + type(result.past_key_values) in (DynamicCache, DynamicSlidingWindowCache) # noqa E721 or ( isinstance(result.past_key_values, EncoderDecoderCache) and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721 From 220a54351a6eb05209449bf91b0ed5f5e685be4c Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 11:58:27 +0200 Subject: [PATCH 06/15] Update utils.py --- src/transformers/generation/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 2e7d1517c337a7..8d46c1f36b10e2 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1613,8 +1613,9 @@ def _prepare_cache_for_generation( # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that # keeps copying the cache thus using much more memory else: - # If using sliding window attention, use specialized DynamicSlidingWindowCache - if getattr(self.config, "sliding_window", None) is not None and not requires_cross_attention_cache: + # If using sliding window attention, use specialized DynamicSlidingWindowCache. 
Assisted generation cannot use it because + # it needs to roll back the cache with is not possible with a sliding window (for more than 1 generated/discarded token at a time) + if getattr(self.config, "sliding_window", None) is not None and assistant_model is None and not requires_cross_attention_cache: model_kwargs[cache_name] = DynamicSlidingWindowCache(self.config.sliding_window) else: model_kwargs[cache_name] = ( From 1f0687d7e2898372f01c84e87a5159a67723d903 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 12:50:35 +0200 Subject: [PATCH 07/15] Update utils.py --- src/transformers/generation/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 8d46c1f36b10e2..aa41b062248066 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1614,7 +1614,7 @@ def _prepare_cache_for_generation( # keeps copying the cache thus using much more memory else: # If using sliding window attention, use specialized DynamicSlidingWindowCache. Assisted generation cannot use it because - # it needs to roll back the cache with is not possible with a sliding window (for more than 1 generated/discarded token at a time) + # it needs to generate more than 1 token at a time, which is not possible with current fixed size implementation if getattr(self.config, "sliding_window", None) is not None and assistant_model is None and not requires_cross_attention_cache: model_kwargs[cache_name] = DynamicSlidingWindowCache(self.config.sliding_window) else: From 0f05dc43eb38c378881fc258b195c246eb740434 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 13:01:41 +0200 Subject: [PATCH 08/15] style --- src/transformers/cache_utils.py | 20 +++++++++++++------- src/transformers/generation/utils.py | 6 +++++- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 6c615326c48975..b7f237285f0744 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -568,8 +568,8 @@ def update( # Update the cache if len(self.key_cache) <= layer_idx: # Add only up to sliding window size if larger - self.key_cache.append(key_states[..., -self.sliding_window:, :]) - self.value_cache.append(value_states[..., -self.sliding_window:, :]) + self.key_cache.append(key_states[..., -self.sliding_window :, :]) + self.value_cache.append(value_states[..., -self.sliding_window :, :]) # We should return full states during prefill even though we only save up to sliding window return key_states, value_states else: @@ -577,17 +577,23 @@ def update( current_seq_len = self.get_seq_length(layer_idx) if new_seq_len + current_seq_len > self.sliding_window: # We need to slice - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx][..., -(self.sliding_window-new_seq_len):, :], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx][..., -(self.sliding_window-new_seq_len):, :], value_states], dim=-2) + self.key_cache[layer_idx] = torch.cat( + [self.key_cache[layer_idx][..., -(self.sliding_window - new_seq_len) :, :], key_states], dim=-2 + ) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx][..., -(self.sliding_window - new_seq_len) :, :], value_states], dim=-2 + ) else: # Similar to DynamicCache self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) self.value_cache[layer_idx] = 
torch.cat([self.value_cache[layer_idx], value_states], dim=-2) return self.key_cache[layer_idx], self.value_cache[layer_idx] - + @classmethod - def from_legacy_cache(cls, sliding_window: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache": + def from_legacy_cache( + cls, sliding_window: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "DynamicCache": """Converts a cache in the legacy cache format into an equivalent `DynamicSlidingWindowCache`. Used for backward compatibility.""" cache = cls(sliding_window) @@ -608,7 +614,7 @@ def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSli current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] out.append(current_split) return out - + @classmethod def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index aa41b062248066..bce4fac659c259 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1615,7 +1615,11 @@ def _prepare_cache_for_generation( else: # If using sliding window attention, use specialized DynamicSlidingWindowCache. Assisted generation cannot use it because # it needs to generate more than 1 token at a time, which is not possible with current fixed size implementation - if getattr(self.config, "sliding_window", None) is not None and assistant_model is None and not requires_cross_attention_cache: + if ( + getattr(self.config, "sliding_window", None) is not None + and assistant_model is None + and not requires_cross_attention_cache + ): model_kwargs[cache_name] = DynamicSlidingWindowCache(self.config.sliding_window) else: model_kwargs[cache_name] = ( From fc928877fc10f0398811b75b09e7cbc9bfc6853a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 14:29:44 +0200 Subject: [PATCH 09/15] add docs --- docs/source/en/internal/generation_utils.md | 7 +++++++ src/transformers/cache_utils.py | 5 +++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index a81d202c6634af..3d4dfef7026732 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -367,6 +367,13 @@ A [`Constraint`] can be used to force the generation to include specific tokens - to_legacy_cache - from_legacy_cache +[[autodoc]] DynamicSlidingWindowCache + - update + - get_seq_length + - reorder_cache + - to_legacy_cache + - from_legacy_cache + [[autodoc]] QuantizedCache - update - get_seq_length diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index b7f237285f0744..f661bb9edda71f 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -335,7 +335,8 @@ def validate(self): class DynamicCache(Cache): """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. + A cache that grows dynamically as more tokens are generated. This is the default for generative models without sliding window attention + (see `DynamicSlidingWindowCache` in this case). It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]`. 
@@ -510,7 +511,7 @@ def batch_select_indices(self, indices: torch.Tensor): class DynamicSlidingWindowCache(DynamicCache): """ A cache that grows dynamically as more tokens are generated, but will stop growing if the sequence length is bigger than the sliding window. - This is the default for generative models with sliding window attention. + This is the default for generative models with sliding window attention (except for assisted decoding where `DynamicCache` is used). It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]` and up to `[batch_size, num_heads, sliding_window, head_dim]` if seq_len >= sliding_window. From ce2b70d12e812b8c46631e4152b0ad9ce0b83e69 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 14:41:33 +0200 Subject: [PATCH 10/15] fix import --- src/transformers/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index eeefc412c430d8..87b91e228b08aa 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -6075,6 +6075,7 @@ Cache, CacheConfig, DynamicCache, + DynamicSlidingWindowCache, EncoderDecoderCache, HQQQuantizedCache, HybridCache, From 2ba750e018f8c63ff4c7f19b2d726d7c069e8bdc Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 14:49:50 +0200 Subject: [PATCH 11/15] Update dummy_pt_objects.py --- src/transformers/utils/dummy_pt_objects.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5f8ae6b5fbffac..56316a0eb91019 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -37,6 +37,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class DynamicSlidingWindowCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class EncoderDecoderCache(metaclass=DummyObject): _backends = ["torch"] From f4c3df9121d17991366d5afca8431666d102ddca Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 16:59:30 +0200 Subject: [PATCH 12/15] Update test --- tests/generation/test_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 2f8e60c79151e9..803942cb8c9841 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,7 +62,7 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import DynamicCache, DynamicSlidingWindowCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -1868,6 +1868,9 @@ def test_new_cache_format(self, num_beams, do_sample): if config.is_encoder_decoder: cache_cls = EncoderDecoderCache past_key_values = cache_cls(DynamicCache(), DynamicCache()) + elif getattr(self.config, "sliding_window", None) is not None: + cache_cls = DynamicSlidingWindowCache + past_key_values = cache_cls(self.config.sliding_window) else: cache_cls = DynamicCache past_key_values = cache_cls() From 89b56a474107759df83a44a48826cf4a6015aa3a Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 17:01:44 +0200 Subject: 
[PATCH 13/15] style --- tests/generation/test_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 803942cb8c9841..67e273aaa76209 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,7 +62,13 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, DynamicSlidingWindowCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import ( + DynamicCache, + DynamicSlidingWindowCache, + EncoderDecoderCache, + QuantoQuantizedCache, + StaticCache, + ) from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, From 617726c12c874d1c8acf72a245d7d8f876606a32 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 17:09:35 +0200 Subject: [PATCH 14/15] fix typo --- tests/generation/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 67e273aaa76209..6370457917fcaf 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1874,9 +1874,9 @@ def test_new_cache_format(self, num_beams, do_sample): if config.is_encoder_decoder: cache_cls = EncoderDecoderCache past_key_values = cache_cls(DynamicCache(), DynamicCache()) - elif getattr(self.config, "sliding_window", None) is not None: + elif getattr(config, "sliding_window", None) is not None: cache_cls = DynamicSlidingWindowCache - past_key_values = cache_cls(self.config.sliding_window) + past_key_values = cache_cls(config.sliding_window) else: cache_cls = DynamicCache past_key_values = cache_cls() From 4f2e6b267af34a7036476f029b5ac8c62c256a44 Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 20 Sep 2024 18:25:25 +0200 Subject: [PATCH 15/15] update cache conversion in test --- tests/generation/test_utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 6370457917fcaf..f3bffd612a602e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1907,7 +1907,10 @@ def test_new_cache_format(self, num_beams, do_sample): ) new_cache = new_results.past_key_values - legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) + if cache_cls == DynamicSlidingWindowCache: + legacy_cache_converted = cache_cls.from_legacy_cache(config.sliding_window, legacy_results.past_key_values) + else: + legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) for layer_idx in range(len(new_cache)): for kv_idx in range(len(new_cache[layer_idx])): self.assertTrue(
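
For reference, below is a minimal, standalone sketch of the per-layer update rule that `DynamicSlidingWindowCache.update()` implements after patch 08. It is illustrative only: `sliding_window_update` is a hypothetical helper written with plain `torch` (no `transformers` imports), and, like the patched code, it assumes that once the cache is filled each subsequent step adds fewer new tokens than the window size.

```python
import torch


def sliding_window_update(key_cache, value_cache, key_states, value_states, sliding_window):
    """Append new key/value states, keeping at most `sliding_window` tokens along dim -2.

    Returns (key_cache, value_cache, keys_for_attention, values_for_attention).
    """
    if key_cache is None:
        # Prefill: cache at most the last `sliding_window` tokens, but return the full
        # states so attention over the whole prompt is still computed (see patch 04).
        return (
            key_states[..., -sliding_window:, :],
            value_states[..., -sliding_window:, :],
            key_states,
            value_states,
        )

    new_len = key_states.shape[-2]
    cur_len = key_cache.shape[-2]
    if new_len + cur_len > sliding_window:
        # Drop the oldest cached tokens so that cached + new == sliding_window.
        key_cache = key_cache[..., -(sliding_window - new_len):, :]
        value_cache = value_cache[..., -(sliding_window - new_len):, :]
    key_cache = torch.cat([key_cache, key_states], dim=-2)
    value_cache = torch.cat([value_cache, value_states], dim=-2)
    return key_cache, value_cache, key_cache, value_cache


# Tiny usage example: prefill 6 tokens, then decode 3 tokens with a window of 4.
if __name__ == "__main__":
    window, bsz, heads, dim = 4, 1, 2, 8
    k, v = torch.randn(bsz, heads, 6, dim), torch.randn(bsz, heads, 6, dim)
    k_cache, v_cache, _, _ = sliding_window_update(None, None, k, v, window)
    for _ in range(3):
        k_new, v_new = torch.randn(bsz, heads, 1, dim), torch.randn(bsz, heads, 1, dim)
        k_cache, v_cache, _, _ = sliding_window_update(k_cache, v_cache, k_new, v_new, window)
    print(k_cache.shape)  # torch.Size([1, 2, 4, 8]); the cache never grows past the window
```

In the patched `_prepare_cache_for_generation()` (patches 02 and 06), this cache is selected automatically whenever the model config defines `sliding_window`, no assistant model is used, and no cross-attention cache is required, so users normally do not need to construct it by hand.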