Memory usage: new dynamic cache for models supporting sliding window attention #33619

Open
wants to merge 15 commits into base: main
7 changes: 7 additions & 0 deletions docs/source/en/internal/generation_utils.md
@@ -367,6 +367,13 @@ A [`Constraint`] can be used to force the generation to include specific tokens
- to_legacy_cache
- from_legacy_cache

[[autodoc]] DynamicSlidingWindowCache
- update
- get_seq_length
- reorder_cache
- to_legacy_cache
- from_legacy_cache

[[autodoc]] QuantizedCache
- update
- get_seq_length
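For reference, the methods documented above mirror `DynamicCache`, except that `from_legacy_cache` takes the sliding window as its first argument. A minimal usage sketch, assuming the `DynamicSlidingWindowCache` export added by this PR (illustration only, not part of the diff):

```python
import torch

from transformers import DynamicSlidingWindowCache

sliding_window = 4
cache = DynamicSlidingWindowCache(sliding_window)

# Fill two layers with dummy key/value states of shape [batch_size, num_heads, seq_len, head_dim].
for layer_idx in range(2):
    key = torch.randn(1, 8, 6, 64)
    value = torch.randn(1, 8, 6, 64)
    cache.update(key, value, layer_idx)

print(cache.key_cache[0].shape)  # torch.Size([1, 8, 4, 64]) -> only the last `sliding_window` positions are stored

# Round trip through the legacy tuple-of-tuples format; note the extra
# `sliding_window` argument compared to `DynamicCache.from_legacy_cache`.
legacy = cache.to_legacy_cache()
restored = DynamicSlidingWindowCache.from_legacy_cache(sliding_window, legacy)
```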
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1252,6 +1252,7 @@
"Cache",
"CacheConfig",
"DynamicCache",
"DynamicSlidingWindowCache",
"EncoderDecoderCache",
"HQQQuantizedCache",
"HybridCache",
@@ -6074,6 +6075,7 @@
Cache,
CacheConfig,
DynamicCache,
DynamicSlidingWindowCache,
EncoderDecoderCache,
HQQQuantizedCache,
HybridCache,
123 changes: 122 additions & 1 deletion src/transformers/cache_utils.py
@@ -335,7 +335,8 @@ def validate(self):

class DynamicCache(Cache):
"""
A cache that grows dynamically as more tokens are generated. This is the default for generative models.
A cache that grows dynamically as more tokens are generated. This is the default for generative models without sliding window attention
(for models with sliding window attention, see `DynamicSlidingWindowCache`).

It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`.
@@ -507,6 +508,126 @@ def batch_select_indices(self, indices: torch.Tensor):
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]


class DynamicSlidingWindowCache(DynamicCache):
"""
A cache that grows dynamically as more tokens are generated, but stops growing once the sequence length reaches the sliding window.
This is the default for generative models with sliding window attention (except for assisted decoding, where `DynamicCache` is used).

It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
`[batch_size, num_heads, seq_len, head_dim]`, capped at `[batch_size, num_heads, sliding_window, head_dim]` once `seq_len >= sliding_window`.

Example:

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicSlidingWindowCache

>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

>>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt")

>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = DynamicSlidingWindowCache(model.config.sliding_window)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
DynamicSlidingWindowCache()
```
"""

def __init__(self, sliding_window: int):
super().__init__()
self.sliding_window = sliding_window

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`, discarding previous
tokens that fall outside the sliding window when needed.

Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicSlidingWindowCache`.

Return:
A tuple containing the updated key and value states.
"""
# Update the number of seen tokens
if layer_idx == 0:
self._seen_tokens += key_states.shape[-2]

# Update the cache
if len(self.key_cache) <= layer_idx:
# Add only up to sliding window size if larger
self.key_cache.append(key_states[..., -self.sliding_window :, :])
self.value_cache.append(value_states[..., -self.sliding_window :, :])
# We should return full states during prefill even though we only save up to sliding window
return key_states, value_states
else:
new_seq_len = key_states.shape[-2]
current_seq_len = self.get_seq_length(layer_idx)
if new_seq_len + current_seq_len > self.sliding_window:
# We need to slice
self.key_cache[layer_idx] = torch.cat(
[self.key_cache[layer_idx][..., -(self.sliding_window - new_seq_len) :, :], key_states], dim=-2
)
self.value_cache[layer_idx] = torch.cat(
[self.value_cache[layer_idx][..., -(self.sliding_window - new_seq_len) :, :], value_states], dim=-2
)
else:
# Similar to DynamicCache
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

return self.key_cache[layer_idx], self.value_cache[layer_idx]

@classmethod
def from_legacy_cache(
cls, sliding_window: int, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
) -> "DynamicSlidingWindowCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicSlidingWindowCache`. Used for
backward compatibility."""
cache = cls(sliding_window)
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
key_states, value_states = past_key_values[layer_idx]
cache.update(key_states, value_states, layer_idx)
return cache

def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicSlidingWindowCache"]:
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
out = []
for i in range(0, full_batch_size, split_size):
current_split = DynamicSlidingWindowCache(self.sliding_window)
current_split._seen_tokens = self._seen_tokens
current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
out.append(current_split)
return out

@classmethod
def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
cache = cls(splits[0].sliding_window)
for idx in range(len(splits[0])):
layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
cache.update(layer_keys, layer_values, idx)
return cache


class OffloadedCache(DynamicCache):
"""
A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
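The slicing branch in `update` above keeps each layer at no more than `sliding_window` positions. A self-contained sketch of that trimming step with plain tensors (the helper name is illustrative, not part of the library):

```python
import torch


def append_with_sliding_window(cached, new, sliding_window):
    """Concatenate `new` onto `cached` along the sequence axis, keeping at most `sliding_window` positions."""
    if cached.shape[-2] + new.shape[-2] > sliding_window:
        # Drop the oldest cached positions so the concatenated result fits the window.
        cached = cached[..., -(sliding_window - new.shape[-2]):, :]
    return torch.cat([cached, new], dim=-2)


batch, heads, head_dim, window = 1, 2, 8, 4
layer_cache = torch.randn(batch, heads, 3, head_dim)  # 3 positions already cached
for _ in range(5):  # decode 5 tokens, one per step, as in regular (non-assisted) generation
    step = torch.randn(batch, heads, 1, head_dim)
    layer_cache = append_with_sliding_window(layer_cache, step, window)

print(layer_cache.shape[-2])  # 4 -> the per-layer cache never grows past the window
```

Like the `update` method above, this assumes each post-prefill step adds fewer new positions than the window, which is one reason the PR keeps assisted generation on `DynamicCache`.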
28 changes: 18 additions & 10 deletions src/transformers/generation/utils.py
@@ -28,6 +28,7 @@
from ..cache_utils import (
Cache,
DynamicCache,
DynamicSlidingWindowCache,
EncoderDecoderCache,
OffloadedCache,
QuantizedCacheConfig,
@@ -1612,11 +1613,20 @@ def _prepare_cache_for_generation(
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
else:
model_kwargs[cache_name] = (
DynamicCache()
if not requires_cross_attention_cache
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)
# If using sliding window attention, use the specialized DynamicSlidingWindowCache. Assisted generation cannot use it because
# it needs to generate more than 1 token at a time, which is not possible with the current implementation
if (
getattr(self.config, "sliding_window", None) is not None
and assistant_model is None
and not requires_cross_attention_cache
):
model_kwargs[cache_name] = DynamicSlidingWindowCache(self.config.sliding_window)
else:
model_kwargs[cache_name] = (
DynamicCache()
if not requires_cross_attention_cache
else EncoderDecoderCache(DynamicCache(), DynamicCache())
)

def _supports_num_logits_to_keep(self) -> bool:
"""
@@ -2207,7 +2217,7 @@ def typeerror():
should_convert_cache = generation_config.return_legacy_cache
is_user_defined_cache = user_defined_cache is not None
is_default_cache_type = (
type(result.past_key_values) == DynamicCache # noqa E721
type(result.past_key_values) in (DynamicCache, DynamicSlidingWindowCache) # noqa E721
or (
isinstance(result.past_key_values, EncoderDecoderCache)
and type(result.past_key_values.self_attention_cache) == DynamicCache # noqa E721
@@ -4447,10 +4457,8 @@ def _concat(data):
if isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
# New cache format
elif isinstance(data[0], DynamicCache):
return DynamicCache.from_batch_splits(data)
elif isinstance(data[0], EncoderDecoderCache):
return EncoderDecoderCache.from_batch_splits(data)
elif isinstance(data[0], (DynamicCache, EncoderDecoderCache)):
return data[0].__class__.from_batch_splits(data)
elif isinstance(data[0], tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0][0], tuple):
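The dispatch added to `_prepare_cache_for_generation` above can be summarized as a small selection function. A hedged sketch (the helper mirrors the branch for illustration and is not the library code; `DynamicSlidingWindowCache` is the class added by this PR):

```python
from transformers import DynamicCache, DynamicSlidingWindowCache, EncoderDecoderCache


def select_default_cache(config, assistant_model, requires_cross_attention_cache):
    """Illustrative mirror of the new default-cache selection."""
    sliding_window = getattr(config, "sliding_window", None)
    # Sliding-window models get the trimming cache, except with assisted generation
    # (several tokens per step) or when a cross-attention cache is required.
    if sliding_window is not None and assistant_model is None and not requires_cross_attention_cache:
        return DynamicSlidingWindowCache(sliding_window)
    if requires_cross_attention_cache:
        return EncoderDecoderCache(DynamicCache(), DynamicCache())
    return DynamicCache()
```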
7 changes: 7 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -37,6 +37,13 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class DynamicSlidingWindowCache(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


class EncoderDecoderCache(metaclass=DummyObject):
_backends = ["torch"]

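For context (not part of the diff): the dummy class above keeps `DynamicSlidingWindowCache` importable when torch is missing and defers a clear error to instantiation time. A rough sketch of the pattern, not the actual `DummyObject`/`requires_backends` implementation:

```python
import importlib.util


def requires_backends(obj, backends):
    """Raise a readable error if any required backend is not installed (simplified)."""
    missing = [name for name in backends if importlib.util.find_spec(name) is None]
    if missing:
        raise ImportError(f"{type(obj).__name__} requires the following backends: {', '.join(missing)}")


class DynamicSlidingWindowCache:  # stand-in exported when torch is unavailable
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
```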
16 changes: 14 additions & 2 deletions tests/generation/test_utils.py
@@ -62,7 +62,13 @@
SpeechEncoderDecoderModel,
T5ForConditionalGeneration,
)
from transformers.cache_utils import DynamicCache, EncoderDecoderCache, QuantoQuantizedCache, StaticCache
from transformers.cache_utils import (
DynamicCache,
DynamicSlidingWindowCache,
EncoderDecoderCache,
QuantoQuantizedCache,
StaticCache,
)
from transformers.generation import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
@@ -1868,6 +1874,9 @@ def test_new_cache_format(self, num_beams, do_sample):
if config.is_encoder_decoder:
cache_cls = EncoderDecoderCache
past_key_values = cache_cls(DynamicCache(), DynamicCache())
elif getattr(config, "sliding_window", None) is not None:
cache_cls = DynamicSlidingWindowCache
past_key_values = cache_cls(config.sliding_window)
else:
cache_cls = DynamicCache
past_key_values = cache_cls()
@@ -1898,7 +1907,10 @@
)

new_cache = new_results.past_key_values
legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
if cache_cls == DynamicSlidingWindowCache:
legacy_cache_converted = cache_cls.from_legacy_cache(config.sliding_window, legacy_results.past_key_values)
else:
legacy_cache_converted = cache_cls.from_legacy_cache(legacy_results.past_key_values)
for layer_idx in range(len(new_cache)):
for kv_idx in range(len(new_cache[layer_idx])):
self.assertTrue(
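Beyond the legacy-format comparison in the test above, a property one could check in isolation: for regular one-token decoding, the sliding-window cache should hold exactly the last `sliding_window` positions of a plain `DynamicCache`. A hedged sketch assuming the class added by this PR (not part of the test suite):

```python
import torch

from transformers import DynamicCache, DynamicSlidingWindowCache

window = 4
full, windowed = DynamicCache(), DynamicSlidingWindowCache(window)

# Prefill with 6 positions, then decode 3 single-token steps.
for step_len in (6, 1, 1, 1):
    key = torch.randn(1, 2, step_len, 8)
    value = torch.randn(1, 2, step_len, 8)
    full.update(key, value, layer_idx=0)
    windowed.update(key, value, layer_idx=0)

assert windowed.key_cache[0].shape[-2] == window
# The windowed cache matches the most recent `window` positions of the full cache.
assert torch.equal(windowed.key_cache[0], full.key_cache[0][..., -window:, :])
```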