Skip to content

Commit

Permalink
Add FlashAttention3 to CP implementations (#1232)
Browse files Browse the repository at this point in the history
* fa2 function import renaming

Signed-off-by: Xiaowei Ren <[email protected]>

* refine fa_fwd_kwargs and fa_bwd_kwargs

Signed-off-by: Xiaowei Ren <[email protected]>

* import FA3 fucntions for CP

Signed-off-by: Xiaowei Ren <[email protected]>

* fix output of FA3 fwd

Signed-off-by: Xiaowei Ren <[email protected]>

* fix rng_state in a2a implementation with FA3

Signed-off-by: Xiaowei Ren <[email protected]>

* hack lse correction for packed lse format

Signed-off-by: Xiaowei Ren <[email protected]>

* make CP thd out correction work with packed lse

Signed-off-by: Xiaowei Ren <[email protected]>

* fix for packed softmax_lse

Signed-off-by: Xiaowei Ren <[email protected]>

* fix softmax_lse shape

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change lse_packed to constexpr

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Xiaowei Ren <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <[email protected]>
  • Loading branch information
3 people authored Oct 11, 2024
1 parent 9ee2dbd commit b36bd0a
Show file tree
Hide file tree
Showing 3 changed files with 353 additions and 246 deletions.
Loading

0 comments on commit b36bd0a

Please sign in to comment.