flash-attention-3 #33522
Conversation
Force-pushed from b6afd63 to 0976545.
Force-pushed from 5aa58ab to 7ae105e.
All models supporting FAv2 should now have FAv3 classes.
Force-pushed from 7ae105e to bd6e9e7.
All occurrences of … Documentation and tests will be done next.
Force-pushed from bd6e9e7 to fbf9bec.
Some documentation and all tests are updated for FAv3. I'll run the tests on an H100 instance, then mark this as ready for (initial) review.
Force-pushed from fbf9bec to 27edb62.
The FAv3 tests generally fail due to the small model configurations used. Instead, I've tested the majority of models via their examples, with a few exceptions like Gemma and Mistral, which I need to request access to, and particularly large models such as Jamba, which my instance doesn't have space to download. All of the tested models with examples are OK, with the exception of …
However, this error also occurs with … StableLM models are currently not supported due to `num_attention_heads`/… I've attached test reports; the numerical accuracy failures may need special care as per …
Wowowo super nice initiative, thanks! 🔥
IMO since we already abstracted the flash attention API, let's try to keep it in `flashAttentionLlama`, but maybe support `flash_attention_3` in the `attn_implementation`, for example! WDYT?
```python
value_states = value_states.to(target_dtype)

# TODO: get `use_fp8` to here, add attention_kwargs or something
attn_output = _flash_attention_3_forward(
```
Hey! As far as I can tell, the only diff is the forward function, right?
Yeah, the difference between the FlashAttention2 classes and FlashAttention3 is just the forward function, plus the lack of dropout/sliding window/softcap in FAv3. As you suggest, we could instead support v3 in the existing classes, using `config.attn_implementation` to select the appropriate function; happy to make this change if you think that's better.
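A rough sketch of that dispatch idea (hypothetical names and stand-in functions, not the actual transformers code):

```python
# Hypothetical sketch: keep a single flash-attention class and pick the
# kernel wrapper from config._attn_implementation, as discussed above.

def _flash_attention_forward(*args, **kwargs):
    # Stand-in for the real FAv2 forward wrapper.
    return "fa2"

def _flash_attention_3_forward(*args, **kwargs):
    # Stand-in for the real FAv3 forward wrapper.
    return "fa3"

class DummyConfig:
    def __init__(self, attn_implementation):
        self._attn_implementation = attn_implementation

def select_flash_forward(config):
    """Return the forward callable matching the configured implementation."""
    if getattr(config, "_attn_implementation", "") == "flash_attention_3":
        return _flash_attention_3_forward
    return _flash_attention_forward
```

This keeps a single attention class per model and confines the v2/v3 difference to one branch.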
```python
return (query, key, value, indices_q, (cu_seq_lens, cu_seq_lens), (max_length, max_length))


def _flash_attention_3_forward(
```
Let's maybe replace `flash_attention_forward` with this one when Flash Attention 3 is available, WDYT?
AFAIK FAv3 will be for Hopper GPUs only.
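A capability gate along those lines could be sketched as follows (hypothetical helper name; Hopper corresponds to CUDA compute capability 9.x):

```python
def is_fa3_capable(major: int, minor: int) -> bool:
    """Hypothetical capability gate: FAv3 targets Hopper GPUs, whose CUDA
    compute capability is 9.x. In a real check the (major, minor) pair
    would come from torch.cuda.get_device_capability()."""
    return major == 9
```

So a check like this would admit an H100 (9.0) but reject an A100 (8.0).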
Force-pushed from 27edb62 to ba268ef.
I've replaced the … I've renamed the … Note that while I was checking all … In … we could simplify the changes to … Checks like …
Force-pushed from 4473129 to 4a34da8.
Sliding window is now supported.
Very great work 🚀 just a passerby who looked into the code :)
I'd be very pro this. It kinda looks misleading now with …
Personal preference: I'd go a step further and move the …
Seems reasonable to me. Makes the code less verbose too. Lastly, …
Edit: Maybe raising a ValueError / warning if dropout or similar values are passed would also be nice, since right now they're just silently ignored.
Force-pushed from 4a34da8 to 9bcbe3f.
I've removed … I'll wait for input from a maintainer on changing the checks (…).
I assume this won't be merged until FAv3 is out of beta, at which point dropout and softcap should hopefully be supported; if not, then I agree we should add an error/warning if they're used with FAv3.
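If such a guard were added, a minimal sketch might look like this (hypothetical function name; the real hook would sit wherever FAv3 kwargs are dropped):

```python
import warnings

def check_fa3_kwargs(dropout: float = 0.0, softcap=None):
    """Hypothetical guard: warn about arguments the FAv3 beta does not
    support, instead of silently ignoring them."""
    unsupported = []
    if dropout != 0.0:
        unsupported.append("dropout")
    if softcap is not None:
        unsupported.append("softcap")
    if unsupported:
        warnings.warn(
            "Flash Attention 3 does not support "
            + ", ".join(unsupported)
            + "; these arguments will be ignored."
        )
    return unsupported
```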
Force-pushed from 9bcbe3f to 2146a74.
What does this PR do?
This PR adds preliminary support for Flash Attention 3.
- `is_flash_attn_3_available` required a workaround in `_is_package_available`, as `package_version = importlib.metadata.version(pkg_name)` fails with `importlib.metadata.PackageNotFoundError: No package metadata was found for flash_attn_interface`.
- `_supports_flash_attn_3` and `_check_and_enable_flash_attn_3` added to `modeling_utils.py`, a near duplicate of `_check_and_enable_flash_attn_2`.
- `_flash_attention_3_forward` implemented in `modeling_flash_attention_3_utils.py`. `_flash_attention_forward` is now a unified interface for FAv2 and FAv3, controlled by `use_flash_attn_3`, which is passed from FlashAttention classes based on `config._attn_implementation == "flash_attention_3"`.
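The `_is_package_available` workaround mentioned above could look roughly like this (a sketch, not the actual transformers implementation; the premise is that `flash_attn_interface` is importable but ships no distribution metadata):

```python
import importlib.metadata
import importlib.util

def is_package_available_sketch(pkg_name: str):
    """Sketch of the described workaround: when a package is importable but
    has no distribution metadata, importlib.metadata.version() raises
    PackageNotFoundError, so fall back to "N/A" instead of failing."""
    exists = importlib.util.find_spec(pkg_name) is not None
    version = "N/A"
    if exists:
        try:
            version = importlib.metadata.version(pkg_name)
        except importlib.metadata.PackageNotFoundError:
            pass  # importable, but no metadata: keep "N/A"
    return exists, version
```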
- FAv3 lacks support for dropout, sliding window (edit: sliding window is now supported) and softcap, and in FAv3 `flash_attn_func`/`flash_attn_varlen_func` return a tuple.
- The `attention_mask is not None` and `position_ids is not None` paths depend on `_upad_input` and `prepare_fa2_from_position_ids` respectively; these are duplicated from `modeling_flash_attention_utils.py` and are not included in the FAv3 package, so FAv3 depends on `flash_attn`. This is reflected in `is_flash_attn_3_available`, which checks `is_flash_attn_2_available`.
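Since, as noted above, the FAv3 functions return a tuple where FAv2 returns the output tensor directly, a small shim can keep downstream code uniform (a hypothetical sketch, not the PR's actual code):

```python
def normalize_fa_output(result):
    """FAv2's flash_attn_func returns the attention output directly, while
    (per the description above) FAv3's flash_attn_func/flash_attn_varlen_func
    return a tuple whose first element is the output. This shim hides the
    difference from callers."""
    return result[0] if isinstance(result, tuple) else result
```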
- … `FLASH_ATTENTION_3_FP8` for this purpose; we can probably add something like `attention_kwargs` to model forwards to control this, or maybe another `_attn_implementation` type, `flash_attention_3_fp8`; best to get reviews and consensus on the best way to do it first.[1]
- `flash_attention_3` is added to Llama with `LlamaFlashAttention3`, similar to `LlamaFlashAttention2` but with unsupported options like dropout and sliding window removed. See comment below.
- `_update_causal_mask` is updated in various models due to `utils/check_copies.py`, and `_supports_flash_attn_3` is already added to some other models for the same reason.

[1] Edit: added to other models, see comment below.

Fixes #33373
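As an aside, the `FLASH_ATTENTION_3_FP8` environment-variable control mentioned in the description might be read like this (the accepted truthy values here are an assumption, not the PR's actual parsing):

```python
import os

def fa3_fp8_enabled() -> bool:
    """Hypothetical reading of the FLASH_ATTENTION_3_FP8 environment
    variable; which values count as truthy is an assumption."""
    return os.environ.get("FLASH_ATTENTION_3_FP8", "0").lower() in ("1", "true", "yes")
```

An `attention_kwargs` argument or a dedicated `flash_attention_3_fp8` implementation string would avoid this kind of out-of-band toggle.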
Todo

- Test `attention_mask is not None` and `position_ids is not None` paths
- Implement FlashAttention3 classes for other models. Done.
- Documentation. Partly done.

Notes
Llama tested on H100 SXM with:
(shortened) responses
FP16:
FP8:
All other models will be tested after I've finished adding FlashAttention3 classes. Edit: other models have been tested, see comment below.

Who can review?
cc @ArthurZucker