diff --git a/csrc/rocm/attention.cu b/csrc/rocm/attention.cu
index 4d5beb922b285..efda714f53c6c 100644
--- a/csrc/rocm/attention.cu
+++ b/csrc/rocm/attention.cu
@@ -316,12 +316,14 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
   int vphysical_blocks[VBLOCKS];
 
   const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE;
+  if constexpr (GQA_RATIO < 12) {
 #pragma unroll
-  for (int b = 0; b < VBLOCKS; b++) {
-    const int vblock_idx = warp_start_block_idx + b;
-    const int vblock_idx_ctx =
-        (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
-    vphysical_blocks[b] = block_table[vblock_idx_ctx];
+    for (int b = 0; b < VBLOCKS; b++) {
+      const int vblock_idx = warp_start_block_idx + b;
+      const int vblock_idx_ctx =
+          (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
+      vphysical_blocks[b] = block_table[vblock_idx_ctx];
+    }
   }
 
   // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
@@ -379,6 +381,17 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
     }
   }
 
+  // fetch vphysical block numbers up front
+  if constexpr (GQA_RATIO >= 12) {
+  #pragma unroll
+    for (int b = 0; b < VBLOCKS; b++) {
+      const int vblock_idx = warp_start_block_idx + b;
+      const int vblock_idx_ctx =
+          (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
+      vphysical_blocks[b] = block_table[vblock_idx_ctx];
+    }
+  }
+
   const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
   if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
     const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
@@ -1023,8 +1036,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
       context_lens_ptr, max_num_partitions, fp8_out_scale_ptr);
 
 template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
-          int BLOCK_SIZE, int HEAD_SIZE, typename OUTT,
-          int PARTITION_SIZE = 512>
+          int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE>
 void paged_attention_custom_launcher(
     torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
     torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
@@ -1069,7 +1081,6 @@ void paged_attention_custom_launcher(
   const int gqa_ratio = num_heads / num_kv_heads;
   assert(num_heads % num_kv_heads == 0);
   assert(head_size == HEAD_SIZE);
-  assert(max_num_partitions <= 256);
 
   constexpr int NTHR = PARTITION_SIZE;
   dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
@@ -1129,9 +1140,6 @@ void paged_attention_custom_launcher(
       TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio);
       break;
   }
-  // dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG);
-  // dim3 block2(1024);
-  // LAUNCH_CUSTOM_ATTENTION2;
 
   // reduction kernel is only required if max_context_len > partition size,
   // otherwise main kernel writes directly to final output
@@ -1142,6 +1150,7 @@ void paged_attention_custom_launcher(
   dim3 reduce_grid(num_heads, num_seqs);
   dim3 reduce_block(head_size);
   const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE);
+  // support up to 8*64*256=128K context length
   switch (npar_loops) {
     case 1:
       LAUNCH_CUSTOM_REDUCTION(1);
      break;
@@ -1155,6 +1164,18 @@ void paged_attention_custom_launcher(
    case 4:
      LAUNCH_CUSTOM_REDUCTION(4);
      break;
+    case 5:
+      LAUNCH_CUSTOM_REDUCTION(5);
+      break;
+    case 6:
+      LAUNCH_CUSTOM_REDUCTION(6);
+      break;
+    case 7:
+      LAUNCH_CUSTOM_REDUCTION(7);
+      break;
+    case 8:
+      LAUNCH_CUSTOM_REDUCTION(8);
+      break;
    default:
      TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops);
      break;
@@ -1162,20 +1183,44 @@ void paged_attention_custom_launcher(
   }
 }
 
-#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT) \
-  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
-                                  OUTT>( \
-      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
-      num_kv_heads, scale, block_tables, context_lens, max_context_len, \
+#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
+                             PSIZE) \
+  paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
+                                  OUTT, PSIZE>( \
+      out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
+      num_kv_heads, scale, block_tables, context_lens, max_context_len, \
       alibi_slopes, k_scale, v_scale, fp8_out_scale);
 
-#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
-  if (fp8_out_scale) { \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, uint8_t); \
-  } else { \
-    CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \
+#define CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
+                                   OUTT) \
+  switch (partition_size) { \
+    case 256: \
+      CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); \
+      break; \
+    case 512: \
+      CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 512); \
+      break; \
+    default: \
+      TORCH_CHECK(false, "Unsupported partition size: ", partition_size); \
+      break; \
   }
+#if defined(__HIPCC__) && defined(__gfx90a__)
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
+    if (fp8_out_scale) { \
+      TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
+    } else { \
+      CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \
+    }
+#else
+  #define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
+    if (fp8_out_scale) { \
+      CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
+                                 uint8_t); \
+    } else { \
+      CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \
+    }
+#endif
 
 #define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
   switch (block_size) { \
     case 16: \
@@ -1219,7 +1264,7 @@ void paged_attention(
     int64_t block_size, int64_t max_context_len,
     const c10::optional<torch::Tensor>& alibi_slopes,
     const std::string& kv_cache_dtype, double k_scale, double v_scale,
-    const c10::optional<torch::Tensor>& fp8_out_scale) {
+    const c10::optional<torch::Tensor>& fp8_out_scale, int64_t partition_size) {
   const int head_size = query.size(2);
   if (kv_cache_dtype == "auto") {
     if (query.dtype() == at::ScalarType::Half) {
diff --git a/csrc/rocm/ops.h b/csrc/rocm/ops.h
index 9549cfa5dae85..d825686a6ced4 100644
--- a/csrc/rocm/ops.h
+++ b/csrc/rocm/ops.h
@@ -11,14 +11,12 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
 void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
              const int64_t N_in, const int64_t CuCount);
 
-void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
-                     torch::Tensor& max_logits, torch::Tensor& tmp_out,
-                     torch::Tensor& query, torch::Tensor& key_cache,
-                     torch::Tensor& value_cache, int64_t num_kv_heads,
-                     double scale, torch::Tensor& block_tables,
-                     torch::Tensor& context_lens, int64_t block_size,
-                     int64_t max_context_len,
-                     const c10::optional<torch::Tensor>& alibi_slopes,
-                     const std::string& kv_cache_dtype, double k_scale,
-                     double v_scale,
-                     const c10::optional<torch::Tensor>& fp8_out_scale);
+void paged_attention(
+    torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
+    torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
+    torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
+    torch::Tensor& block_tables, torch::Tensor& context_lens,
+    int64_t block_size, int64_t max_context_len,
+    const c10::optional<torch::Tensor>& alibi_slopes,
+    const std::string& kv_cache_dtype, double k_scale, double v_scale,
+    const c10::optional<torch::Tensor>& fp8_out_scale, int64_t partition_size);
diff --git a/csrc/rocm/torch_bindings.cpp b/csrc/rocm/torch_bindings.cpp
index 4d21ea944ee41..6402a3b2b2b60 100644
--- a/csrc/rocm/torch_bindings.cpp
+++ b/csrc/rocm/torch_bindings.cpp
@@ -36,7 +36,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
       "  Tensor? alibi_slopes,"
       "  str kv_cache_dtype,"
      "  float k_scale, float v_scale,"
-      "  Tensor? fp8_out_scale) -> ()");
+      "  Tensor? fp8_out_scale,"
+      "  int partition_size) -> ()");
  rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
  rocm_ops.def(
      "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in,"
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index c6e5bed5ad9a3..619973e0282b2 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -180,13 +180,14 @@ def paged_attention_rocm(
     k_scale: float,
     v_scale: float,
     fp8_out_scale: Optional[torch.Tensor],
+    partition_size: int,
 ) -> None:
     torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out,
                                       query, key_cache, value_cache,
                                       num_kv_heads, scale, block_tables,
                                       seq_lens, block_size, max_seq_len,
                                       alibi_slopes, kv_cache_dtype, k_scale, v_scale,
-                                      fp8_out_scale)
+                                      fp8_out_scale, partition_size)
 
 
 # pos encoding ops
diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 660f26c71c381..9407194cf3fc5 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -20,7 +20,7 @@
 
 logger = init_logger(__name__)
 
-_PARTITION_SIZE_ROCM = 512
+_PARTITION_SIZE_ROCM = 256
 _ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
 
 
@@ -798,6 +798,7 @@ def forward(
                 k_scale,
                 v_scale,
                 fp8_out_scale if cpa_fp8_out else None,
+                _PARTITION_SIZE_ROCM,
             )
             if cpa_fp8_out:
                 return out.view(num_seqs, num_heads * head_size)
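
Sanity check (not part of the patch): a small Python sketch of the arithmetic behind the new "support up to 8*64*256=128K context length" comment and the extra LAUNCH_CUSTOM_REDUCTION cases. It assumes WARP_SIZE is the 64-wide wavefront used on AMD CDNA GPUs and that max_num_partitions is the context length divided by the partition size, rounded up; the helper names below are hypothetical and only illustrate the bound.

    from math import ceil

    WARP_SIZE = 64  # assumed wavefront size on AMD CDNA GPUs

    def npar_loops(max_context_len: int, partition_size: int) -> int:
        # Mirrors DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE) in the launcher,
        # assuming max_num_partitions = ceil(max_context_len / partition_size).
        max_num_partitions = ceil(max_context_len / partition_size)
        return ceil(max_num_partitions / WARP_SIZE)

    def max_supported_context_len(partition_size: int, max_loops: int = 8) -> int:
        # The reduce kernel now covers up to max_loops groups of WARP_SIZE partitions.
        return max_loops * WARP_SIZE * partition_size

    assert max_supported_context_len(256) == 131072  # 8 * 64 * 256 = 128K tokens
    assert npar_loops(131072, 256) == 8              # exercises the new case 8 branch
    assert npar_loops(131072, 512) == 4              # 512-token partitions stop at 4 loops

With _PARTITION_SIZE_ROCM dropping from 512 to 256, the same context length needs twice as many partitions, which is why the reduction switch gains cases 5 through 8.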