Memory usage: new dynamic cache for models supporting sliding window attention #33619

Closed · wants to merge 48 commits

48 commits
d894405
Add new dynamic cache
Cyrilvallez Sep 19, 2024
3b0984b
Add cache by default in generate for models supporting it
Cyrilvallez Sep 19, 2024
345e695
Add to __init__ and correct typo
Cyrilvallez Sep 19, 2024
38e82b5
Correct output if prefill larger than sliding window + compatibility
Cyrilvallez Sep 19, 2024
c46a92a
Add legacy format handling
Cyrilvallez Sep 19, 2024
02b8506
style
Cyrilvallez Sep 20, 2024
7a98aac
add docs
Cyrilvallez Sep 20, 2024
ebe6dc9
fix import
Cyrilvallez Sep 20, 2024
af95f2a
Update dummy_pt_objects.py
Cyrilvallez Sep 20, 2024
08d1a9f
Update test
Cyrilvallez Sep 20, 2024
b73655a
style
Cyrilvallez Sep 20, 2024
ff16af0
update cache conversion in test
Cyrilvallez Sep 20, 2024
5e3fef0
style
Cyrilvallez Sep 23, 2024
3d1bfd0
Allow the cache to support new states of more than 1 token, even afte…
Cyrilvallez Sep 24, 2024
6a02bdc
Update cache_utils.py
Cyrilvallez Sep 24, 2024
838712d
maybe change test
Cyrilvallez Sep 24, 2024
6afd20d
revert tests diffs
Cyrilvallez Oct 2, 2024
217e803
define get_seen_tokens
Cyrilvallez Oct 2, 2024
582301c
Modify all current .get_seq_length names
Cyrilvallez Oct 2, 2024
b239a57
style
Cyrilvallez Oct 2, 2024
ee30eb9
trigger CIs
Cyrilvallez Oct 2, 2024
f3af180
Add tests
Cyrilvallez Oct 2, 2024
25cd9c0
Update test_utils.py
Cyrilvallez Oct 2, 2024
b2f7dee
Update test_utils.py
Cyrilvallez Oct 2, 2024
b549290
Update test_utils.py
Cyrilvallez Oct 2, 2024
f052bed
Update causal mask generation in case of DynamicSlidingCache (only Mi…
Cyrilvallez Oct 3, 2024
e091f4d
Improve tests
Cyrilvallez Oct 3, 2024
9a30ad4
improve cache
Cyrilvallez Oct 8, 2024
8202a19
add exceptions
Cyrilvallez Oct 8, 2024
55a39a6
Update utils.py
Cyrilvallez Oct 8, 2024
9caf947
Update test_utils.py
Cyrilvallez Oct 8, 2024
1404cec
Update test_utils.py
Cyrilvallez Oct 8, 2024
4f3ba86
Update test_utils.py
Cyrilvallez Oct 8, 2024
44331f1
Update test_utils.py
Cyrilvallez Oct 8, 2024
b5ebae2
Update test_utils.py
Cyrilvallez Oct 9, 2024
7e78258
Update 4d mask creation in Mistral
Cyrilvallez Oct 10, 2024
301f7f2
fix missed conflict
Cyrilvallez Oct 10, 2024
be18801
Apply to other models
Cyrilvallez Oct 10, 2024
734e3fe
Add required arg in prepare_inputs
Cyrilvallez Oct 10, 2024
106c410
Update test_utils.py
Cyrilvallez Oct 10, 2024
0d8e9ac
Update test_utils.py
Cyrilvallez Oct 10, 2024
8509053
Fix kv_seq_length and rotary_seq_length
Cyrilvallez Oct 10, 2024
2ae645f
up
Cyrilvallez Oct 10, 2024
8d539e6
up
Cyrilvallez Oct 10, 2024
e808fa5
up
Cyrilvallez Oct 10, 2024
8499f94
up
Cyrilvallez Oct 10, 2024
6879866
CIs
Cyrilvallez Oct 11, 2024
fe8a625
improve sdpa is_causal escape
Cyrilvallez Oct 11, 2024
21 changes: 14 additions & 7 deletions docs/source/en/internal/generation_utils.md
@@ -362,22 +362,29 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] DynamicCache
- update
- get_seq_length
- get_past_seen_tokens
- reorder_cache
- to_legacy_cache
- from_legacy_cache

[[autodoc]] DynamicSlidingWindowCache
- update
- get_past_seen_tokens
- reorder_cache
- to_legacy_cache
- from_legacy_cache

[[autodoc]] QuantizedCache
- update
- get_seq_length
- get_past_seen_tokens

[[autodoc]] QuantoQuantizedCache

[[autodoc]] HQQQuantizedCache

[[autodoc]] SinkCache
- update
- get_seq_length
- get_past_seen_tokens
- reorder_cache

[[autodoc]] OffloadedCache
@@ -387,25 +394,25 @@ A [`Constraint`] can be used to force the generation to include specific tokens

[[autodoc]] StaticCache
- update
- get_seq_length
- get_past_seen_tokens
- reset

[[autodoc]] OffloadedStaticCache
- update
- get_seq_length
- get_past_seen_tokens
- reset

[[autodoc]] HybridCache
- update
- get_seq_length
- get_past_seen_tokens
- reset

[[autodoc]] SlidingWindowCache
- update
- reset

[[autodoc]] EncoderDecoderCache
- get_seq_length
- get_past_seen_tokens
- to_legacy_cache
- from_legacy_cache
- reset
4 changes: 2 additions & 2 deletions examples/modular-transformers/modeling_dummy.py
@@ -906,7 +906,7 @@ def forward(
)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
@@ -997,7 +997,7 @@ def _update_causal_mask(
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
4 changes: 2 additions & 2 deletions examples/modular-transformers/modeling_my_new_model2.py
@@ -784,7 +784,7 @@ def forward(
)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
@@ -879,7 +879,7 @@ def _update_causal_mask(
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
2 changes: 1 addition & 1 deletion examples/modular-transformers/modeling_super.py
@@ -902,7 +902,7 @@ def _update_causal_mask(
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
@@ -1268,6 +1268,7 @@
"Cache",
"CacheConfig",
"DynamicCache",
"DynamicSlidingWindowCache",
"EncoderDecoderCache",
"HQQQuantizedCache",
"HybridCache",
@@ -6156,6 +6157,7 @@
Cache,
CacheConfig,
DynamicCache,
DynamicSlidingWindowCache,
EncoderDecoderCache,
HQQQuantizedCache,
HybridCache,
133 changes: 133 additions & 0 deletions src/transformers/cache_utils.py
@@ -64,6 +64,12 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# TODO: deprecate this function in favor of `cache_position`
raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")

def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the number of already processed tokens. For all Cache classes except SlidingWindow caches, this is the same as
`get_seq_length()`. However, with sliding window we can process more tokens than the cache size. A layer index can be optionally passed.
"""
return self.get_seq_length(layer_idx)

    # Deprecate in favor of max-cache-shape because we want to be specific about what we mean with "max_length"
    # Previously, some cache objects didn't have "max_length" (SlidingWindowCache or SinkCache) because the cache object technically handles
    # an infinite amount of tokens. In the codebase what we really need to check is the max capacity of certain cache instances, so
@@ -545,6 +551,133 @@ def batch_select_indices(self, indices: torch.Tensor):
self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]


# TODO: (cyril) Make this the default for models with sliding window once `generate` no longer returns Cache as tuples
class DynamicSlidingWindowCache(DynamicCache):
"""
    A cache that grows dynamically as more tokens are generated, but stops growing once the sequence length exceeds the sliding window.
    This will be the default for generative models with sliding window attention (except for assisted decoding, where `DynamicCache` is used).

It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`, capped at `[batch_size, num_heads, sliding_window-1, head_dim]` once `seq_len >= sliding_window-1`.

    Note: Since we keep at most `sliding_window-1` tokens in the cache, once this size is reached the cache can no
    longer be rolled back to previous states without losing information. For this reason, it should not be used with
    assisted decoding (or with contrastive search when using `low_memory=True`).

Example:

```python
>>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicSlidingWindowCache

>>> model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.1")

>>> inputs = tokenizer(text="My name is Mistral", return_tensors="pt")

>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = DynamicSlidingWindowCache(model.config.sliding_window)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
DynamicSlidingWindowCache()
```
"""

def __init__(self, sliding_window: int) -> None:
super().__init__()
self.sliding_window = sliding_window
# We overwrite the field and maintain a list of size `num_hidden_layers` to accurately reflect the seen tokens at each layer during `update`
self._seen_tokens = []

def get_past_seen_tokens(self, layer_idx: Optional[int] = 0) -> int:
"""This needs to be overriden because the number of processed tokens may be larger than the cache length."""
if len(self._seen_tokens) <= layer_idx:
return 0
else:
return self._seen_tokens[layer_idx]

def update(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Discard previous
tokens according to the sliding window if needed.

Parameters:
key_states (`torch.Tensor`):
The new key states to cache.
value_states (`torch.Tensor`):
The new value states to cache.
layer_idx (`int`):
The index of the layer to cache the states for.
cache_kwargs (`Dict[str, Any]`, `optional`):
Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

Return:
A tuple containing the updated key and value states.
"""
if len(self.key_cache) <= layer_idx:
# Update the number of seen tokens
self._seen_tokens.append(key_states.shape[-2])
            # Keep only the last `sliding_window - 1` tokens if the prefill is larger
self.key_cache.append(key_states[..., -self.sliding_window + 1 :, :])
self.value_cache.append(value_states[..., -self.sliding_window + 1 :, :])
# We should return full states during prefill even though we only save up to sliding window-1
return key_states, value_states
else:
self._seen_tokens[layer_idx] += key_states.shape[-2]
# We may need to return longer states (e.g. to continue generation with previous cache, with added tokens), but we only keep
# the last `sliding_window-1` states in the cache for next forward
full_key_states = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
full_value_states = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
self.key_cache[layer_idx] = full_key_states[..., -self.sliding_window + 1 :, :]
self.value_cache[layer_idx] = full_value_states[..., -self.sliding_window + 1 :, :]
return full_key_states, full_value_states

def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
"""Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
`_split_model_inputs()` in `generation.utils`"""
out = []
for i in range(0, full_batch_size, split_size):
current_split = DynamicSlidingWindowCache(self.sliding_window)
current_split._seen_tokens = self._seen_tokens
current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
out.append(current_split)
return out

@classmethod
def from_batch_splits(cls, splits: List["DynamicSlidingWindowCache"]) -> "DynamicSlidingWindowCache":
"""This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
`generation.utils`"""
cache = cls(splits[0].sliding_window)
for idx in range(len(splits[0])):
            key_cache = [current.key_cache[idx] for current in splits if current.key_cache[idx] != []]
            value_cache = [current.value_cache[idx] for current in splits if current.value_cache[idx] != []]
if key_cache != []:
layer_keys = torch.cat(key_cache, dim=0)
layer_values = torch.cat(value_cache, dim=0)
cache.update(layer_keys, layer_values, idx)

# We need this because _seen_tokens may be bigger than what will be automatically set with `update` (if cache > sliding_window)
cache._seen_tokens = splits[0]._seen_tokens
return cache

def crop(self, max_length: int):
if self.get_past_seen_tokens() >= self.sliding_window - 1:
raise RuntimeError(
"The current DynamicSlidingWindowCache is full. It cannot be cropped as this would mean losing past states."
)
else:
super().crop(max_length)

from_legacy_cache = None
to_legacy_cache = None


class OffloadedCache(DynamicCache):
"""
A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
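To make the sliding-window bookkeeping of `DynamicSlidingWindowCache` concrete, here is a minimal standalone sketch of the same idea (our own illustration, not the PR's code): the cache stores at most `sliding_window - 1` tokens per layer, while the seen-token counter keeps growing, which is exactly why `get_past_seen_tokens()` is split off from `get_seq_length()`.

```python
# Minimal sketch (illustrative only; names are ours, not the library's) of the
# sliding-window cache bookkeeping: store at most `sliding_window - 1` tokens,
# but keep counting every token that was processed.
from typing import List, Tuple

import torch


class TinySlidingCache:
    def __init__(self, sliding_window: int):
        self.sliding_window = sliding_window
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self._seen_tokens: List[int] = []

    def get_past_seen_tokens(self, layer_idx: int = 0) -> int:
        # Number of tokens processed so far; may exceed what is stored.
        return self._seen_tokens[layer_idx] if layer_idx < len(self._seen_tokens) else 0

    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        keep = self.sliding_window - 1  # at most `sliding_window - 1` tokens are kept
        if len(self.key_cache) <= layer_idx:  # prefill: first call for this layer
            self._seen_tokens.append(key_states.shape[-2])
            self.key_cache.append(key_states[..., -keep:, :])
            self.value_cache.append(value_states[..., -keep:, :])
            return key_states, value_states  # return the full prefill states
        # decoding: return the concatenated states, but only store the tail
        self._seen_tokens[layer_idx] += key_states.shape[-2]
        full_k = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
        full_v = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        self.key_cache[layer_idx] = full_k[..., -keep:, :]
        self.value_cache[layer_idx] = full_v[..., -keep:, :]
        return full_k, full_v


cache = TinySlidingCache(sliding_window=4)
prefill = torch.randn(1, 2, 10, 8)  # [batch_size, num_heads, seq_len, head_dim]
cache.update(prefill, prefill, layer_idx=0)
step = torch.randn(1, 2, 1, 8)
cache.update(step, step, layer_idx=0)
print(cache.get_past_seen_tokens(0))  # 11 -> the seen-token count keeps growing
print(cache.key_cache[0].shape[-2])   # 3  -> storage is capped at sliding_window - 1
```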
23 changes: 15 additions & 8 deletions src/transformers/generation/utils.py
@@ -28,6 +28,7 @@
from ..cache_utils import (
Cache,
DynamicCache,
DynamicSlidingWindowCache,
EncoderDecoderCache,
OffloadedCache,
QuantizedCacheConfig,
@@ -1532,8 +1533,8 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
past_length = 0
if not isinstance(cache, Cache):
past_length = cache[0][0].shape[2]
elif hasattr(cache, "get_seq_length") and cache.get_seq_length() is not None:
past_length = cache.get_seq_length()
elif hasattr(cache, "get_past_seen_tokens") and cache.get_past_seen_tokens() is not None:
past_length = cache.get_past_seen_tokens()

# TODO(joao): this is not torch.compile-friendly, find a work-around. If the cache is not empty,
# end-to-end compilation will yield bad results because `cache_position` will be incorrect.
@@ -2130,6 +2131,8 @@ def generate(
raise ValueError(
f"assisted generation is not supported with stateful models, such as {self.__class__.__name__}"
)
if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache):
raise ValueError("DynamicSlidingWindowCache cannot be used in assisted generation.")

# 11. Get the candidate generator, given the parameterization
candidate_generator = self._get_candidate_generator(
@@ -2179,6 +2182,12 @@
raise ValueError(
f"contrastive search is not supported with stateful models, such as {self.__class__.__name__}"
)
if isinstance(getattr(model_kwargs, "past_key_values", None), DynamicSlidingWindowCache) and getattr(
generation_config, "low_memory", False
):
raise ValueError(
"DynamicSlidingWindowCache cannot be used in contrastive generation with `low_memory=True`."
)

result = self._contrastive_search(
input_ids,
@@ -2764,7 +2773,7 @@ def _contrastive_search(
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
if model_kwargs.get("past_key_values") is None or (
isinstance(model_kwargs["past_key_values"], (Cache, EncoderDecoderCache))
and model_kwargs["past_key_values"].get_seq_length() == 0
and model_kwargs["past_key_values"].get_past_seen_tokens() == 0
):
# prepare inputs
model_kwargs["use_cache"] = True
@@ -4166,7 +4175,7 @@ def _assisted_decoding(
isinstance(past_key_values, EncoderDecoderCache)
and isinstance(past_key_values.self_attention_cache, DynamicCache)
):
if past_key_values.get_seq_length() == 0:
if past_key_values.get_past_seen_tokens() == 0:
start_from_empty_dynamic_cache = True

this_peer_finished = False
@@ -4603,10 +4612,8 @@ def _concat(data):
if isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
# New cache format
elif isinstance(data[0], DynamicCache):
return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], EncoderDecoderCache):
return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], (DynamicCache, EncoderDecoderCache)):
return data[0].__class__.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0][0], tuple):
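For context on how these `generate`-side guards and defaults are meant to be exercised, here is a hedged usage sketch, assuming the API introduced in this PR lands as shown (`DynamicSlidingWindowCache` exported from `transformers` and passed to `generate` via `past_key_values`):

```python
# Usage sketch, assuming this PR's API is available; not a definitive recipe.
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicSlidingWindowCache

model_id = "mistralai/Mistral-7B-Instruct-v0.1"  # any model with `config.sliding_window`
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("My name is Mistral", return_tensors="pt")

# KV memory stops growing once `sliding_window - 1` tokens are cached per layer.
past_key_values = DynamicSlidingWindowCache(model.config.sliding_window)
out = model.generate(**inputs, past_key_values=past_key_values, max_new_tokens=64)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```

With assisted decoding, or contrastive search under `low_memory=True`, the new guards above raise a `ValueError` instead, since a full sliding-window cache cannot be rolled back to an earlier state.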
6 changes: 3 additions & 3 deletions src/transformers/modeling_attn_mask_utils.py
@@ -275,13 +275,13 @@ def _ignore_causal_mask_sdpa(
if (
(is_training or not is_tracing)
and (query_length == 1 or key_value_length == query_length)
and (sliding_window is None or key_value_length < sliding_window)
and (sliding_window is None or key_value_length <= sliding_window)
):
ignore_causal_mask = True
elif sliding_window is None or key_value_length < sliding_window:
elif sliding_window is None or key_value_length <= sliding_window:
if len(attention_mask.shape) == 4:
return False
elif not is_tracing and torch.all(attention_mask == 1):
elif not is_tracing and torch.all(attention_mask[:, -key_value_length:] == 1):
if query_length == 1 or key_value_length == query_length:
# For query_length == 1, causal attention and bi-directional attention are the same.
ignore_causal_mask = True
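A small self-contained check of the boundary this hunk relaxes from `<` to `<=` (our own illustration, not code from the PR): under the convention that the window includes the current token, a sliding-window causal mask is still identical to a plain causal mask when `key_value_length == sliding_window`, and only starts to differ beyond it, so SDPA's `is_causal` fast path can still be taken at equality.

```python
# Illustration (our own, with an assumed mask convention) of why `key_value_length <= sliding_window`
# is safe: the sliding-window mask only diverges from the plain causal mask past the window.
import torch


def masks_differ(key_value_length: int, sliding_window: int) -> bool:
    i = torch.arange(key_value_length).unsqueeze(-1)  # query positions
    j = torch.arange(key_value_length).unsqueeze(0)   # key positions
    causal = j <= i
    sliding = causal & (j > i - sliding_window)  # window includes the current token
    return not torch.equal(causal, sliding)


for kv_len in (3, 4, 5):
    print(kv_len, masks_differ(kv_len, sliding_window=4))
# 3 False, 4 False (length == window: still identical), 5 True (explicit mask now needed)
```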
4 changes: 2 additions & 2 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -646,7 +646,7 @@ def forward(
)

batch_size, seq_length, _ = inputs_embeds.shape
past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
past_length = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
seq_length_with_past = seq_length + past_length
if cache_position is None:
cache_position = torch.arange(past_length, past_length + seq_length, device=inputs_embeds.device)
@@ -747,7 +747,7 @@ def _update_causal_mask(
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
past_seen_tokens = past_key_values.get_past_seen_tokens() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)

# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward