diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index a81d202c6634af..3d4dfef7026732 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -367,6 +367,13 @@ A [`Constraint`] can be used to force the generation to include specific tokens - to_legacy_cache - from_legacy_cache +[[autodoc]] DynamicSlidingWindowCache + - update + - get_seq_length + - reorder_cache + - to_legacy_cache + - from_legacy_cache + [[autodoc]] QuantizedCache - update - get_seq_length diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index aa13a97fe46150..87b91e228b08aa 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -1252,6 +1252,7 @@ "Cache", "CacheConfig", "DynamicCache", + "DynamicSlidingWindowCache", "EncoderDecoderCache", "HQQQuantizedCache", "HybridCache", @@ -6074,6 +6075,7 @@ Cache, CacheConfig, DynamicCache, + DynamicSlidingWindowCache, EncoderDecoderCache, HQQQuantizedCache, HybridCache, diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index d42b15c14abf9b..f661bb9edda71f 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -335,7 +335,8 @@ def validate(self): class DynamicCache(Cache): """ - A cache that grows dynamically as more tokens are generated. This is the default for generative models. + A cache that grows dynamically as more tokens are generated. This is the default for generative models without sliding window attention + (see `DynamicSlidingWindowCache` in this case). It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is `[batch_size, num_heads, seq_len, head_dim]`. @@ -507,6 +508,126 @@ def batch_select_indices(self, indices: torch.Tensor): self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...] +class DynamicSlidingWindowCache(DynamicCache): + """ + A cache that grows dynamically as more tokens are generated, but will stop growing if the sequence length is bigger than the sliding window. + This is the default for generative models with sliding window attention (except for assisted decoding where `DynamicCache` is used). + + It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is + `[batch_size, num_heads, seq_len, head_dim]` and up to `[batch_size, num_heads, sliding_window, head_dim]` if seq_len >= sliding_window. + + Example: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicSlidingWindowCache + + >>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1") + + >>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt") + + >>> # Prepare a cache class and pass it to model's forward + >>> past_key_values = DynamicSlidingWindowCache(model.config.sliding_window) + >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) + >>> outputs.past_key_values # access cache filled with key/values from generation + DynamicSlidingWindowCache() + ``` + """ + + def __init__(self, sliding_window: int): + super().__init__() + self.sliding_window = sliding_window + + def update( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Discard previous + tokens according to the sliding window if needed. + + Parameters: + key_states (`torch.Tensor`): + The new key states to cache. + value_states (`torch.Tensor`): + The new value states to cache. + layer_idx (`int`): + The index of the layer to cache the states for. + cache_kwargs (`Dict[str, Any]`, `optional`): + Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`. + + Return: + A tuple containing the updated key and value states. + """ + # Update the number of seen tokens + if layer_idx == 0: + self._seen_tokens += key_states.shape[-2] + + # Update the cache + if len(self.key_cache) <= layer_idx: + # Add only up to sliding window size if larger + self.key_cache.append(key_states[..., -self.sliding_window :, :]) + self.value_cache.append(value_states[..., -self.sliding_window :, :]) + # We should return full states during prefill even though we only save up to sliding window + return key_states, value_states + else: + new_seq_len = key_states.shape[-2] + current_seq_len = self.get_seq_length(layer_idx) + if new_seq_len + current_seq_len > self.sliding_window: + # We need to slice + self.key_cache[layer_idx] = torch.cat( + [self.key_cache[layer_idx][..., -(self.sliding_window - new_seq_len) :, :], key_states], dim=-2 + ) + self.value_cache[layer_idx] = torch.cat( + [self.value_cache[layer_idx][..., -(self.sliding_window - new_seq_len) :, :], value_states], dim=-2 + ) + else: + # Similar to DynamicCache + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + + return self.key_cache[layer_idx], self.value_cache[layer_idx] + + @classmethod + def from_legacy_cache( + cls, sliding_window: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "DynamicCache": + """Converts a cache in the legacy cache format into an equivalent `DynamicSlidingWindowCache`. Used for + backward compatibility.""" + cache = cls(sliding_window) + if past_key_values is not None: + for layer_idx in range(len(past_key_values)): + key_states, value_states = past_key_values[layer_idx] + cache.update(key_states, value_states, layer_idx) + return cache + + def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSlidingWindowCache"]: + """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by + `_split_model_inputs()` in `generation.utils`""" + out = [] + for i in range(0, full_batch_size, split_size): + current_split = DynamicSlidingWindowCache(self.sliding_window) + current_split._seen_tokens = self._seen_tokens + current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache] + current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache] + out.append(current_split) + return out + + @classmethod + def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache": + """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in + `generation.utils`""" + cache = cls(splits[0].sliding_window) + for idx in range(len(splits[0])): + layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0) + layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0) + cache.update(layer_keys, layer_values, idx) + return cache + + class OffloadedCache(DynamicCache): """ A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory. diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 2fe92d3e3ed64b..bce4fac659c259 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -28,6 +28,7 @@ from ..cache_utils import ( Cache, DynamicCache, + DynamicSlidingWindowCache, EncoderDecoderCache, OffloadedCache, QuantizedCacheConfig, @@ -1612,11 +1613,20 @@ def _prepare_cache_for_generation( # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that # keeps copying the cache thus using much more memory else: - model_kwargs[cache_name] = ( - DynamicCache() - if not requires_cross_attention_cache - else EncoderDecoderCache(DynamicCache(), DynamicCache()) - ) + # If using sliding window attention, use specialized DynamicSlidingWindowCache. Assisted generation cannot use it because + # it needs to generate more than 1 token at a time, which is not possible with current fixed size implementation + if ( + getattr(self.config, "sliding_window", None) is not None + and assistant_model is None + and not requires_cross_attention_cache + ): + model_kwargs[cache_name] = DynamicSlidingWindowCache(self.config.sliding_window) + else: + model_kwargs[cache_name] = ( + DynamicCache() + if not requires_cross_attention_cache + else EncoderDecoderCache(DynamicCache(), DynamicCache()) + ) def _supports_num_logits_to_keep(self) -> bool: """ @@ -2207,7 +2217,7 @@ def typeerror(): should_convert_cache = generation_config.return_legacy_cache is_user_defined_cache = user_defined_cache is not None is_default_cache_type = ( - type(result.past_key_values) == DynamicCache # noqa E721 + type(result.past_key_values) in (DynamicCache, DynamicSlidingWindowCache) # noqa E721 or ( isinstance(result.past_key_values, EncoderDecoderCache) and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721 @@ -4447,10 +4457,8 @@ def _concat(data): if isinstance(data[0], torch.Tensor): return torch.cat(data, dim=0) # New cache format - elif isinstance(data[0], DynamicCache): - return DynamicCache.from_batch_splits(data) - elif isinstance(data[0], EncoderDecoderCache): - return EncoderDecoderCache.from_batch_splits(data) + elif isinstance(data[0], (DynamicCache, EncoderDecoderCache)): + return data[0].__class__.from_batch_splits(data) elif isinstance(data[0], tuple): # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example) if isinstance(data[0][0], tuple): diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5f8ae6b5fbffac..56316a0eb91019 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -37,6 +37,13 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +class DynamicSlidingWindowCache(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + class EncoderDecoderCache(metaclass=DummyObject): _backends = ["torch"] diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 2f8e60c79151e9..f3bffd612a602e 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -62,7 +62,13 @@ SpeechEncoderDecoderModel, T5ForConditionalGeneration, ) - from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache + from transformers.cache_utils import ( + DynamicCache, + DynamicSlidingWindowCache, + EncoderDecoderCache, + QuantoQuantizedCache, + StaticCache, + ) from transformers.generation import ( BeamSampleDecoderOnlyOutput, BeamSampleEncoderDecoderOutput, @@ -1868,6 +1874,9 @@ def test_new_cache_format(self, num_beams, do_sample): if config.is_encoder_decoder: cache_cls = EncoderDecoderCache past_key_values = cache_cls(DynamicCache(), DynamicCache()) + elif getattr(config, "sliding_window", None) is not None: + cache_cls = DynamicSlidingWindowCache + past_key_values = cache_cls(config.sliding_window) else: cache_cls = DynamicCache past_key_values = cache_cls() @@ -1898,7 +1907,10 @@ def test_new_cache_format(self, num_beams, do_sample): ) new_cache = new_results.past_key_values - legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) + if cache_cls == DynamicSlidingWindowCache: + legacy_cache_converted = cache_cls.from_legacy_cache(config.sliding_window, legacy_results.past_key_values) + else: + legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values) for layer_idx in range(len(new_cache)): for kv_idx in range(len(new_cache[layer_idx])): self.assertTrue(