Cache: don't show warning in forward passes when past_key_values is None #33541

Merged: 5 commits, Sep 19, 2024
Changes from 3 commits
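In short: before this change, these models emitted the cache deprecation warning whenever `past_key_values` was not a `Cache` instance, which includes the very common case of passing `None`. The new logic only warns on the genuinely deprecated input (a tuple of tuples) and, per the discussion below, now does so even in training mode. A minimal standalone sketch of the new pattern, using stand-in stubs rather than the real `transformers` classes:

```python
import warnings


class Cache:  # stand-in for transformers.Cache
    pass


class DynamicCache(Cache):  # stand-in for transformers.DynamicCache
    @classmethod
    def from_legacy_cache(cls, legacy):
        return cls()

    def to_legacy_cache(self):
        return ()


def forward(past_key_values=None, use_cache=True):
    # kept for BC (non `Cache` `past_key_values` inputs)
    return_legacy_cache = False
    if use_cache and not isinstance(past_key_values, Cache):
        return_legacy_cache = True
        if past_key_values is None:
            # New: `None` is silently upgraded to an empty cache -- no warning.
            past_key_values = DynamicCache()
        else:
            # Only the genuinely deprecated tuple-of-tuples input warns, and it
            # now warns in training mode too (the old `not self.training` guard
            # was removed on purpose).
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            warnings.warn("passing `past_key_values` as a tuple of tuples is deprecated ...")

    # ... decoder layers would run here and update `past_key_values` ...

    next_cache = past_key_values if use_cache else None
    if return_legacy_cache:
        # Callers that did not pass a `Cache` get the legacy tuple format back.
        next_cache = next_cache.to_legacy_cache()
    return next_cache


forward(past_key_values=None)           # quiet
forward(past_key_values=(("k", "v"),))  # warns about the deprecated format
```

As in the diffs below, callers that pass `None` or a legacy tuple still receive the legacy tuple format back; only callers that pass a `Cache` instance get a `Cache` back.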
26 changes: 15 additions & 11 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -687,14 +687,18 @@ def forward(
inputs_embeds = self.word_embeddings(input_ids)

# kept for BC (non `Cache` `past_key_values` inputs)
use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"Using `past_key_values` as a tuple is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
gante (PR author) commented:
Note: `not self.training` was removed.

If we are training and we pass `past_key_values` as a tuple of tuples, we definitely want to see the warning -- the code will break in the near future.

return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please use an appropriate `Cache` class "
gante (PR author) commented:
bumped the deprecation to v4.47, as we are still missing some key models like T5

"(https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)

batch_size, seq_length, _ = inputs_embeds.shape
past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -765,9 +769,9 @@ def forward(
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()
Comment on lines +772 to +774
gante (PR author) commented:
copy/paste from llama

(on some models, this pattern was slightly different)


if not return_dict:
return tuple(
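For context on the `past_key_values = DynamicCache()` branch introduced above: when the caller passes `None`, the model now simply starts from an empty cache that each attention layer fills through `update()`. A small sketch using the real `DynamicCache` (tensor shapes are illustrative only):

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()  # what the model now builds when `past_key_values is None`

# Each attention layer appends its key/value states under its own layer index.
batch, num_heads, seq_len, head_dim = 1, 4, 3, 8
for layer_idx in range(2):
    key = torch.randn(batch, num_heads, seq_len, head_dim)
    value = torch.randn(batch, num_heads, seq_len, head_dim)
    cache.update(key, value, layer_idx)

print(cache.get_seq_length())      # 3 tokens cached so far
legacy = cache.to_legacy_cache()   # tuple with one (key, value) pair per layer
print(len(legacy), legacy[0][0].shape)  # 2 torch.Size([1, 4, 3, 8])
```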
22 changes: 13 additions & 9 deletions src/transformers/models/codegen/modeling_codegen.py
@@ -526,14 +526,18 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)

use_legacy_cache = False
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not self.training:
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)

seq_length = inputs_embeds.shape[1]
@@ -608,9 +612,9 @@ def forward(
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(
19 changes: 11 additions & 8 deletions src/transformers/models/cohere/modeling_cohere.py
@@ -910,16 +910,19 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
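The hunk is truncated here; in these models the lines that follow typically derive `cache_position` from the number of tokens already cached, roughly like this (a sketch of the common pattern, not the exact Cohere code):

```python
import torch
from transformers import DynamicCache


def make_cache_position(past_key_values, inputs_embeds):
    # Tokens already present in the cache (0 when no cache was passed).
    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
    # Absolute positions of the tokens processed in this forward pass.
    return torch.arange(
        past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
    )


embeds = torch.zeros(1, 5, 16)
print(make_cache_position(None, embeds))            # tensor([0, 1, 2, 3, 4])
print(make_cache_position(DynamicCache(), embeds))  # empty cache -> also starts at 0
```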
19 changes: 11 additions & 8 deletions src/transformers/models/dbrx/modeling_dbrx.py
@@ -1059,16 +1059,19 @@ def forward(

inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)

# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
Reviewer (Member) commented:
(nit: not really related to the PR, but to the link, which was already here before)

Linking to the Cache class is cool but you have to scroll down a bit to see an example. Would it be possible to link to a migration doc/example showcasing how a previously written code with past key values as a tuple of tuples can be adapted to be sent to the model?

The more copy-pastable the example, the less friction there will be here

gante (PR author) replied, Sep 18, 2024:
@LysandreJik good point!

I've added a tiny section to our cache docs about the legacy cache and how to convert it to/from the new format, with an example (cc @zucchini-nlp). This warning now points to that section in the docs.

(will merge after confirming the docs with the doc builder)

EDIT: for some reason, the doc builder is not updating its contents, despite the doc job being successful 🤔 I'm going to merge and double-check the merged results

EDIT2: it worked :) https://huggingface.co/docs/transformers/main/en/kv_cache#legacy-cache-format

)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
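A minimal sketch of the legacy-cache conversion that the comment thread above points to (see the linked kv_cache docs for the canonical example; shapes are illustrative):

```python
import torch
from transformers import DynamicCache

# Legacy format: one (key, value) tuple per layer, tensors shaped
# [batch, num_heads, seq_len, head_dim].
legacy_cache = tuple(
    (torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8)) for _ in range(2)
)

# Migrate old code by wrapping the tuples once...
cache = DynamicCache.from_legacy_cache(legacy_cache)
print(type(cache).__name__, cache.get_seq_length())  # DynamicCache 3

# ...and convert back if something downstream still expects tuples.
round_tripped = cache.to_legacy_cache()
assert torch.equal(round_tripped[0][0], legacy_cache[0][0])
```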
24 changes: 14 additions & 10 deletions src/transformers/models/falcon/modeling_falcon.py
@@ -1031,17 +1031,21 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)

# Compute alibi tensor: check build_alibi_tensor documentation
use_legacy_cache = False
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not self.training:
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)

# Compute alibi tensor: check build_alibi_tensor documentation
alibi = None
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
batch_size, seq_length, _ = inputs_embeds.shape
@@ -1126,9 +1130,9 @@ def forward(
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(
15 changes: 11 additions & 4 deletions src/transformers/models/gemma/diff_gemma.py
@@ -476,12 +476,19 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False # noqa: F841
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True # noqa: F841
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
28 changes: 13 additions & 15 deletions src/transformers/models/gemma/modeling_gemma.py
@@ -828,12 +828,19 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

return_legacy_cache = False # noqa: F841
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True # noqa: F841
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)

if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -856,15 +863,6 @@ def forward(
# See https://github.com/huggingface/transformers/pull/29402
normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype)
hidden_states = hidden_states * normalizer
if (
use_cache and not isinstance(past_key_values, Cache) and not self.training
): # kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)

# decoder layers
all_hidden_states = () if output_hidden_states else None
27 changes: 16 additions & 11 deletions src/transformers/models/git/modeling_git.py
@@ -417,14 +417,19 @@ def forward(
)
use_cache = False

use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)

all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
Expand Down Expand Up @@ -463,9 +468,9 @@ def forward(
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(
24 changes: 14 additions & 10 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
@@ -741,14 +741,18 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)

use_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache) and not self.training:
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not self.training:
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)

seq_length = inputs_embeds.shape[1]
Expand Down Expand Up @@ -822,9 +826,9 @@ def forward(
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(
22 changes: 13 additions & 9 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
@@ -943,14 +943,18 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_in(input_ids)

use_legacy_cache = False
# kept for BC (non `Cache` `past_key_values` inputs)
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
use_legacy_cache = True
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if not self.training:
return_legacy_cache = True
if past_key_values is None:
past_key_values = DynamicCache()
else:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
logger.warning_once(
"We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.45. "
"Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
"We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
"will be removed in v4.47. Please use an appropriate `Cache` class "
"(https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
)

seq_length = inputs_embeds.shape[1]
Expand Down Expand Up @@ -1021,9 +1025,9 @@ def forward(
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache if use_cache else None
if return_legacy_cache:
next_cache = next_cache.to_legacy_cache()

if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attentions] if v is not None)
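Finally, the non-deprecated path from a caller's perspective: pass a `Cache` instance (or nothing at all) and reuse whatever the model returns. A hedged end-to-end sketch; `EleutherAI/pythia-70m` is just an example GPT-NeoX checkpoint, and any model already ported to the Cache API should behave the same way:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_id = "EleutherAI/pythia-70m"  # example checkpoint (GPT-NeoX architecture)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Caches are", return_tensors="pt")

# Passing an explicit Cache (or, after this PR, leaving `past_key_values` as None)
# never triggers the deprecation warning; only a tuple of tuples does.
outputs = model(**inputs, use_cache=True, past_key_values=DynamicCache())

# Feed the returned cache back in for the next step.
next_token = outputs.logits[:, -1:].argmax(-1)
attention_mask = torch.cat([inputs["attention_mask"], torch.ones_like(next_token)], dim=-1)
outputs = model(
    input_ids=next_token,
    attention_mask=attention_mask,
    past_key_values=outputs.past_key_values,  # still a DynamicCache, no conversion needed
    use_cache=True,
)
```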