
Gemma capping #34282

Open · wants to merge 49 commits into base: main
Conversation

@ArthurZucker (Collaborator) commented on Oct 21, 2024:

What does this PR do?

Adds logit soft-capping support for Gemma 2 (via FlexAttention), fixes #32877
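
For context, Gemma 2 caps its attention logits with a tanh ("soft-capping"), which the fused SDPA kernel cannot express. A minimal sketch of the operation as it appears in eager attention (illustrative only; `attn_logit_softcapping` mirrors the Gemma 2 config attribute, the commented-out lines are placeholders):

```python
import torch

def soft_cap(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Tanh soft-capping: smoothly squashes raw attention scores into
    # (-softcap, softcap) instead of hard clipping them.
    return softcap * torch.tanh(scores / softcap)

# Illustrative use inside an eager attention forward pass:
# scores = (query @ key.transpose(-1, -2)) * scaling
# scores = soft_cap(scores, config.attn_logit_softcapping)  # e.g. 50.0 for Gemma 2
# probs = torch.softmax(scores + causal_mask, dim=-1)
# attn_output = probs @ value
```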

@Cyrilvallez (Member) left a comment:

There are a lot of edge cases in imports that are very hard to handle with the proposed approach. I think a simpler and more general approach is to do it the other way around:

  • dump all imports from the modular_xxx.py as is
  • dump all imports from the dependency files as is (this is currently the case)
  • then, in the PostModularConverterCleaner, clean up the imports (perhaps only the protected imports, letting ruff remove the other unused, non-protected ones; see the sketch after this list)

This approach is much simpler and more versatile because in the Cleaner we have access to the final source code, which is not the case when visiting the modular_xxx.py file (there we only see the modular file plus the dependencies, and it is hard to check imports against only the part of the dependency files that actually gets copied into the final file). It would thus guarantee that all needed imports are present (we would never hit a weird edge case when trying to match imports as we do currently), and we could correctly remove imports that were wrongly added from the dependency files (e.g. the duplicate import in Glm caused by the Phi3 dependency).
In my opinion this would also greatly reduce the code complexity.
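
A minimal sketch of what that cleanup step could look like, assuming the cleaner simply shells out to ruff and uses its F401 rule ("imported but unused") to prune imports copied over from dependency files; the function name is a placeholder, not the converter's actual API:

```python
import subprocess

def clean_unused_imports(generated_file: str) -> None:
    # Let ruff delete unused imports in the generated modeling file.
    # Protected imports (e.g. those behind `if is_torch_available():`)
    # would still need dedicated handling in the cleaner itself.
    subprocess.run(
        ["ruff", "check", generated_file, "--fix", "--select", "F401"],
        check=True,
    )
```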

utils/modular_model_converter.py: several review threads (outdated, resolved)

Diff under discussion:
- attn_output = torch.nn.functional.scaled_dot_product_attention(
+ attn_output = flex_attention(
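
For reference, a hedged sketch of how flex_attention can express the soft-capping through its score_mod hook (requires a recent PyTorch that ships torch.nn.attention.flex_attention; the softcap value and tensor shapes are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def make_softcap_score_mod(softcap: float):
    # score_mod runs on each raw attention score before softmax, which is
    # exactly where Gemma 2's tanh soft-capping needs to be applied.
    def score_mod(score, batch, head, q_idx, kv_idx):
        return softcap * torch.tanh(score / softcap)

    return score_mod

# Illustrative call, mirroring the diff above (50.0 matches Gemma 2's default
# attn_logit_softcapping):
# attn_output = flex_attention(query, key, value, score_mod=make_softcap_score_mod(50.0))
```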
A contributor left a comment:
Isn't it a bit misleading to use flex attention when we have attn_implementation="sdpa"? My concerns would be:

  • People who previously used SDPA (forced or not) will suddenly have different torch requirements.
  • SDPA != FlexAttention IMO; it's a different API, a different name, and potentially slightly different behaviour.
  • Are the slow tests still passing? We should ensure it still behaves roughly the same as eager.

What do you think about adding a separate attn_implementation option for flex attention specifically? Not sure if this goes beyond the goal of the PR, but control over the specific implementation is always appreciated.

Overall excited to see this, great work!

@ArthurZucker (Collaborator, Author) replied:

The SDPA version of Gemma never really "worked", TBH!
I'll probably add a new class for flex attention; this was simpler for testing.

@ArthurZucker (Collaborator, Author) commented:

Okay @Cyrilvallez, good point regarding cleaning! It makes more sense indeed; will update to fix 😉

@Cyrilvallez (Member) left a comment:

Very nice approach! Much simpler IMO 🤗 Just added some nits for clarity.

utils/modular_model_converter.py: several review threads (outdated, resolved)
@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Cyrilvallez (Member) left a comment:

LGTM, I actually love it! I think it's much better to use different attention functions instead of different attention classes: it's clearer, there is less duplicated code, and we can easily switch between implementations even after the model has been instantiated (see the sketch below).
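
A minimal sketch of the "attention functions instead of attention classes" idea; the registry name, decorator, and signatures below are illustrative placeholders, not the actual transformers API:

```python
from typing import Callable, Dict

import torch

# Hypothetical registry mapping an attn_implementation string to a function.
ATTENTION_FUNCTIONS: Dict[str, Callable] = {}

def register_attention(name: str):
    def decorator(fn: Callable) -> Callable:
        ATTENTION_FUNCTIONS[name] = fn
        return fn
    return decorator

@register_attention("eager")
def eager_attention(query, key, value, scaling, softcap=None):
    # Plain eager attention with optional tanh soft-capping
    # (masking and dropout omitted for brevity).
    scores = (query @ key.transpose(-1, -2)) * scaling
    if softcap is not None:
        scores = softcap * torch.tanh(scores / softcap)
    probs = torch.softmax(scores, dim=-1)
    return probs @ value

# A module would pick the function at call time, so the implementation can be
# swapped even after the model has been instantiated, e.g.:
# attn_fn = ATTENTION_FUNCTIONS[config._attn_implementation]
# attn_output = attn_fn(query, key, value, scaling, softcap=config.attn_logit_softcapping)
```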

src/transformers/models/gemma2/modular_gemma2.py: several review threads (outdated, resolved)
Successfully merging this pull request may close these issues.

Add logit scaling sdpa using FlexAttention for Gemma2
5 participants