Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache: don't throw warnings on gemma2 when instantiating a new cache #33595

Merged
merged 3 commits into from
Sep 19, 2024

Conversation

gante
Copy link
Member

@gante gante commented Sep 19, 2024

What does this PR do?

Related to #33541

The warning in question should only be thrown in the case we are converting from a legacy cache, which will be deprecated soon. Gemma 2 doesn't support the legacy cache format, so no warning should ever be thrown :)

In the process, updates a few related inconsistencies.


✅ slow gemma2 tests ran locally. There are a few failures (also present on main). Some failures were fixed in this PR.

@gante gante changed the title Cache: don't throw warnings on gemma 2 when instantiating a new cache Cache: don't throw warnings on gemma2 when instantiating a new cache Sep 19, 2024
Comment on lines 1662 to 1666
def get_seq_length(self, layer_idx: Optional[int] = 0):
return None
# Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
# limit the check to the first batch member and head dimension.
# TODO: deprecate this function in favor of `cache_position`
return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HybridCache is a StaticCache with alternating sliding window layers. The method to retrieve the cache length is copy/paste from StaticCache

We will want to use another method in the future, but let's leave this as a copy of StaticCache for now. This method is needed in the updated gemma 2.

raise ValueError("When `past_key_values` is passed, `cache_position` must be too")

# Probably a forward call with caching, so we set up cache for one call only
if use_cache and past_key_values is None and not self.training:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two changes here, both to be consistent with other models:

  1. self.training should not control whether we instantiate a cache
  2. If a user respects the types in the docs, past_key_values is either a Cache or we instantiate a new one for the user without warnings

@@ -840,6 +822,11 @@ def forward(
dtype=inputs_embeds.dtype,
)

if cache_position is None:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy/paste from llama (and other Cache-supporting models)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okey, this should always work actually since the seq length gets layer_idx=0. Just one question, isn't it a bit misleading if some layers will have get_seq_length() number of tokens while others no more than sliding window length?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zucchini-nlp yes, if get_seq_length gets called on the wrong layer we will have problems! I'm going to add an exception if it gets called on layer_idx != 0 (I doubt we need it).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okey sounds good, as long as the function of get_seq_length is transparent for users, to reduce number of cache-related question we get 😄

@@ -1000,8 +1000,16 @@ def forward(
)
use_cache = False

if use_cache and past_key_values is None and not self.training:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if use_cache and not isinstance(past_key_values, Cache):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

copy/paste from llama (and other Cache-supporting models)

@@ -86,10 +86,15 @@ def setUp(self):
def test_model_outputs_equivalence(self, **kwargs):
pass

@parameterized.expand([("float16",), ("bfloat16",), ("float32",)])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

without this parameterized, the intended overwriting was not happening

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@LysandreJik LysandreJik left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! Please merge once @zucchini-nlp has approved as she knows this code more than I.

cc @BenjaminBossan as well

Copy link
Member

@zucchini-nlp zucchini-nlp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for cleaning up warnings! Left one question about HybridCache, since I was reluctant to add seq-length for that cache type where lengths are not consistent over layers

@@ -840,6 +822,11 @@ def forward(
dtype=inputs_embeds.dtype,
)

if cache_position is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okey, this should always work actually since the seq length gets layer_idx=0. Just one question, isn't it a bit misleading if some layers will have get_seq_length() number of tokens while others no more than sliding window length?

@BenjaminBossan
Copy link
Member

I'm not qualified to review this but thanks for addressing this so quickly.

@gante gante merged commit 52920b5 into huggingface:main Sep 19, 2024
23 checks passed
@gante gante deleted the gemma2_warning branch September 19, 2024 16:42
itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
amyeroberts pushed a commit to amyeroberts/transformers that referenced this pull request Oct 2, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants