forked from Dao-AILab/flash-attention
-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support page kvcache in AMD ROCm (Dao-AILab#1198)
* Integrate ck branch of ck_tile/fa_bwd_opt * Assume dq and q share the same stride * update ck * Integrate more stride of dq_acc * Revert fwd dropout * Fix paremeter order * Integrate ck with more stride * update the limit of hdim of bwd * Check argument * Add test_flash_attn_causal * Support unpad lse * Add test_flash_attn_varlen_causal, test_flash_attn_race_condition, test_flash_attn_bwd_overflow, test_flash_attn_bwd_transpose, test_flash_attn_bwd_varlen_overflow, test_flash_attn_deterministic, test_flash_attn_varlen_deterministic * Fix stride and Kn0 * Fix CK sync issue * Fix typo * Update CK for changing of fmha_fwd_args * Add kvcache tmp * Add kvcache * Fix comment * Sync behavior with ck * Update CK to develop * remove large test case * Add kvcache test * Fix page_block_size in arg * Minor fix * Fix stride error * Update seqlen of kvcache before splitkv * Fix compile error * Fix bug of hdim is not 8x * Fit ck arg * support adaptive num_splits * add more tests * Refine test tolerance * update CK * Move override_num_splits_if_necessary into cpp * update ck * Update ck * Support different flag for different version of hip * remove coerce-illegal, becasue this is not required in FA * Update ck to fix xcratch memory * Add coerce-illegal in some version * Add compile flag for rtn rounding * remove redundant init * Using env var to switch rounding mode * update ck
- Loading branch information
1 parent
cc1690d
commit e2182cc
Showing
11 changed files
with
1,749 additions
and
131 deletions.
There are no files selected for viewing
Submodule composable_kernel
updated
386 files
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
/****************************************************************************** | ||
* Copyright (c) 2024, Tri Dao. | ||
******************************************************************************/ | ||
|
||
#include "flash_common.hpp" | ||
|
||
namespace flash { | ||
int override_num_splits_if_necessary(int batch, int nhead, int max_seqlen_q, int hdim_v, float p_drop, int num_splits) | ||
{ | ||
int device; | ||
auto status = hipGetDevice(&device); | ||
if(status != hipSuccess) | ||
return num_splits; | ||
|
||
hipDeviceProp_t props{}; | ||
status = hipGetDeviceProperties(&props, device); | ||
if(status != hipSuccess) | ||
return num_splits; | ||
|
||
// TODO - tile size should match the TileFmhaShape, hardcode for now | ||
const int kM0 = 128; | ||
const int kN1 = hdim_v; | ||
|
||
const int num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0; | ||
const int num_n_blocks = (hdim_v + kN1 - 1) / kN1; | ||
|
||
if(num_splits < 1 && p_drop == 0.0f) | ||
return num_splits_heuristic_ck( | ||
batch * nhead * num_m_blocks, props.multiProcessorCount * 2, num_n_blocks, 128); | ||
|
||
return num_splits; | ||
} | ||
|
||
} // namespace flash |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.