Cache: don't show warning in forward passes when `past_key_values` is None #33541
```diff
@@ -687,14 +687,18 @@ def forward(
         inputs_embeds = self.word_embeddings(input_ids)

         # kept for BC (non `Cache` `past_key_values` inputs)
-        use_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, Cache) and not self.training:
-            use_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "Using `past_key_values` as a tuple is deprecated and will be removed in v4.45. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+        return_legacy_cache = False
+        if use_cache and not isinstance(past_key_values, Cache):
+            return_legacy_cache = True
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
+                )

         batch_size, seq_length, _ = inputs_embeds.shape
         past_length = past_key_values.get_seq_length() if past_key_values is not None else 0
```

Review comment: bumped the deprecation to v4.47, since we're still missing some key models like T5.
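For readers skimming the diff, the control flow of the new shim can be summarized in a minimal, self-contained sketch. `toy_forward` and the `print` stand-in for `logger.warning_once` are illustrative, not transformers APIs:

```python
from transformers.cache_utils import Cache, DynamicCache


def toy_forward(past_key_values=None, use_cache=True):
    return_legacy_cache = False
    if use_cache and not isinstance(past_key_values, Cache):
        return_legacy_cache = True
        if past_key_values is None:
            # Nothing was passed at all: start a fresh cache silently.
            # This is the case the PR fixes -- no spurious deprecation warning.
            past_key_values = DynamicCache()
        else:
            # A genuine legacy tuple of tuples was passed: convert it and warn.
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            print("warning: tuple `past_key_values` is deprecated")  # stand-in for logger.warning_once

    # ... attention layers would read from / append to `past_key_values` here ...

    next_cache = past_key_values if use_cache else None
    if return_legacy_cache:
        next_cache = next_cache.to_legacy_cache()
    return next_cache
```

In short: passing `None` now builds an empty `DynamicCache` without warning; only an actual legacy tuple triggers the deprecation notice.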
```diff
@@ -765,9 +769,9 @@ def forward(
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)

-        next_cache = None
-        if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()

         if not return_dict:
             return tuple(
```

Review comment (on lines +772 to +774): copy/paste from llama (on some models, this pattern was slightly different).
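The output path above relies on `DynamicCache.to_legacy_cache()` to hand callers back the format they passed in. A small self-contained sketch of that conversion (tensor shapes are illustrative, using the `[batch, num_heads, seq_len, head_dim]` convention):

```python
import torch
from transformers import DynamicCache

cache = DynamicCache()
# What each attention layer does internally: append this step's keys/values.
cache.update(torch.zeros(1, 4, 8, 16), torch.zeros(1, 4, 8, 16), layer_idx=0)

# If the caller passed a legacy cache (or nothing), convert back on the way out:
legacy = cache.to_legacy_cache()
assert isinstance(legacy, tuple) and len(legacy) == 1  # one (key, value) pair per layer
```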
The second file in the diff receives the same treatment:
```diff
@@ -1059,16 +1059,19 @@ def forward(
         inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)

+        # kept for BC (non `Cache` `past_key_values` inputs)
         return_legacy_cache = False
-        if (
-            use_cache and not isinstance(past_key_values, Cache) and not self.training
-        ):  # kept for BC (non `Cache` `past_key_values` inputs)
+        if use_cache and not isinstance(past_key_values, Cache):
             return_legacy_cache = True
-            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            logger.warning_once(
-                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.46. "
-                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
-            )
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            else:
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
+                logger.warning_once(
+                    "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
+                    "will be removed in v4.47. Please use an appropriate `Cache` class "
+                    "(https://huggingface.co/docs/transformers/internal/generation_utils#transformers.Cache)"
+                )

         if cache_position is None:
             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
```

Review comment: (nit, not really related to the PR but to the link, which was already here before) Linking to the […] The more copy-pastable the example, the less friction there will be here.

Reply: @LysandreJik good point! I've added a tiny section to our cache docs about the legacy cache and how to convert it to/from the new format, with an example (cc @zucchini-nlp). This warning now points to that section in the docs. (Will merge after confirming the docs with the doc builder.)

EDIT: for some reason, the doc builder is not updating its contents, despite the doc job being successful 🤔 I'm going to merge and double-check the merged results.

EDIT2: it worked :) https://huggingface.co/docs/transformers/main/en/kv_cache#legacy-cache-format
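The conversion helpers that the linked docs section covers can be exercised in isolation. A rough sketch of the round trip (tensor shapes are illustrative and this is not copied verbatim from the docs):

```python
import torch
from transformers import DynamicCache

# Legacy format: a tuple with one (key, value) tuple per layer.
legacy_cache = tuple(
    (torch.zeros(1, 4, 8, 16), torch.zeros(1, 4, 8, 16)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy_cache)  # legacy -> Cache
assert cache.get_seq_length() == 8

round_tripped = cache.to_legacy_cache()  # Cache -> legacy
assert len(round_tripped) == 2
```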
Review comment: Note: `not self.training` was removed. If we are training and we pass `past_key_values` as a tuple of tuples, we definitely want to see the warning -- the code will break in the near future.
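To make that point concrete, here is a hedged comparison of the old and new guard conditions; the tuple below is only a stand-in for a real legacy cache:

```python
from transformers.cache_utils import Cache

past_key_values = ((None, None),)  # stand-in for a legacy tuple of tuples
use_cache, training = True, True

old_condition = use_cache and not isinstance(past_key_values, Cache) and not training
new_condition = use_cache and not isinstance(past_key_values, Cache)

print(old_condition)  # False: during training, legacy input was silently left unconverted
print(new_condition)  # True: the input is now converted and the deprecation warning fires
```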