Gemma capping #34282
base: main
Conversation
There are a lot of edge cases in imports which are very hard to deal with using the proposed approach. I think a simpler and more general approach is to do it the other way around:
- dump all imports from the modular_xxx.py as is
- dump all imports from the dependency files as is (this is currently the case)
- then, in the PostModularConverterCleaner, clean the imports (maybe even only clean the protected imports, and let ruff remove the other unused, non-protected imports)

This approach is much easier and more versatile because in the Cleaner we have access to the final source code, which is not the case when visiting the modular_xxx.py file (we only see the modular file plus the dependencies, and it is hard to check imports relative to only the part of the dependency files that we copy into the final file). Thus, it would ensure that all needed imports are present (i.e. we would never hit a weird edge case when trying to match the imports as we do currently), and we could correctly remove imports that were wrongly added from the dependency files (e.g. the duplicate import in Glm due to the Phi3 dependency).
This would also greatly simplify the code, in my opinion.
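To make the suggestion concrete, here is a rough sketch of what such a cleanup pass could look like. This is purely illustrative: the function name, the `protected_names` argument, and the overall shape are assumptions, not the converter's actual API.

```python
import ast

def find_unused_protected_imports(final_source: str, protected_names: set) -> set:
    """Return protected names that are imported in the fully generated file but never used.

    Running on the final source means every import dumped from modular_xxx.py and its
    dependency files is visible at once, so no matching against partial files is needed.
    """
    tree = ast.parse(final_source)
    imported, used = set(), set()
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            for alias in node.names:
                # the bound name is the alias if present, otherwise the top-level module name
                imported.add(alias.asname or alias.name.split(".")[0])
        elif isinstance(node, ast.Name):
            used.add(node.id)
    return (imported & protected_names) - used
```

The cleaner would then drop the corresponding import statements and leave the remaining unused, non-protected imports for ruff to remove.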
- attn_output = torch.nn.functional.scaled_dot_product_attention(
+ attn_output = flex_attention(
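For context on the changed lines, here is a minimal sketch of why flex attention is appealing here: Gemma 2 applies a tanh soft cap to attention scores, which `scaled_dot_product_attention` has no hook for, while `flex_attention` accepts a `score_mod` callback. The tensor shapes and the 50.0 cap below are illustrative assumptions, and `flex_attention` requires a recent PyTorch (2.5+).

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

def soft_cap_score_mod(cap: float):
    # flex_attention calls score_mod on every raw attention score before softmax
    def score_mod(score, batch, head, q_idx, kv_idx):
        return cap * torch.tanh(score / cap)
    return score_mod

# (batch, num_heads, seq_len, head_dim) -- illustrative shapes only
query = key = value = torch.randn(1, 8, 128, 64)
attn_output = flex_attention(query, key, value, score_mod=soft_cap_score_mod(50.0))
```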
Isn't it a bit misleading to use flex attn when we have attn_implementation="sdpa"? My concerns would be:
- People that previously used sdpa (forced or not) will suddenly have different torch requirements
- SDPA != flex attn imo: it's a different API, a different name, and potentially slightly different behaviour
- Are the slow tests still passing? We should ensure it still behaves roughly the same as eager

Wdyt about making another attn implementation option for flex attn specifically? Not sure if this goes beyond the goal of the PR, but control over the specific implementation is always appreciated.
Overall excited to see this, great work!
The SDPA version of Gemma never "worked", TBH!
I'll probably add a new class for flex attention; this was simpler for testing.
Okay @Cyrilvallez good point regarding cleaning! Makes more sense indeed, will update to fix 😉
Very nice approach! Much simpler IMO 🤗 just added some nits for clarity
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM, I actually love it, I think it's much better to use different attention functions instead of different attention classes (clearer, less duplicated code, and we can easily switch between implementations even after the model has been instantiated)
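A hypothetical illustration of that design, to show why switching implementations after instantiation becomes trivial; the names below are made up for the example and are not the PR's actual symbols.

```python
import torch
import torch.nn.functional as F

def eager_attention(query, key, value, attention_mask=None, scaling=1.0):
    # explicit matmul-softmax-matmul attention
    scores = torch.matmul(query, key.transpose(-2, -1)) * scaling
    if attention_mask is not None:
        scores = scores + attention_mask
    probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(query.dtype)
    return torch.matmul(probs, value)

def sdpa_attention(query, key, value, attention_mask=None, scaling=1.0):
    # same signature, different backend
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, scale=scaling)

# one dispatch table instead of one attention class per backend
ATTENTION_FUNCTIONS = {"eager": eager_attention, "sdpa": sdpa_attention}
```

The attention module would look the table up at forward time, e.g. `ATTENTION_FUNCTIONS[self.config._attn_implementation]`, so changing the implementation on an already-instantiated model is just changing that key.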
What does this PR do?
Adds capping for gemma2, fixes #32877
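For reference, the "capping" in question is the tanh soft capping described for Gemma 2: logits are squashed into (-cap, cap) instead of being left unbounded. A minimal sketch; the example cap values reflect the Gemma 2 config defaults as I understand them, so treat them as assumptions.

```python
import torch

def soft_cap(logits: torch.Tensor, cap: float) -> torch.Tensor:
    # smoothly bounds logits to the range (-cap, cap)
    return cap * torch.tanh(logits / cap)

# applied to attention scores (attn_logit_softcapping, e.g. 50.0)
# and to the final LM head logits (final_logit_softcapping, e.g. 30.0)
```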