From af5c1394cc9f8ed4be85935e76b625136325d1b9 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Tue, 2 Jun 2026 12:35:22 +0800 Subject: [PATCH 01/29] remove unsed para in produce_kv_blockwise --- .../append_attn/append_attention_func.cuh | 3 --- .../multiquery_attention_c16_impl.cuh | 24 ------------------- 2 files changed, 27 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index ed40c5e8adb..2b2f90f0ac9 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -274,9 +274,6 @@ __device__ __forceinline__ void produce_kv_blockwise( smem_t smem, uint32_t* smem_offset, T** gptr, // [max_block_num, num_heads, block_size, head_dim] - const uint32_t kv_head_idx, - const uint32_t kv_n_stride, - const uint32_t kv_h_stride, const uint32_t kv_b_stride, const uint32_t kv_idx_base, const uint32_t kv_len) { diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index c82ba20a4c6..c2d8d9cd040 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -221,9 +221,6 @@ __global__ void multi_query_append_attention_kernel( NUM_WARP_Q>(k_smem, &kv_smem_offset_w, &cache_k_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, kv_b_stride, kv_idx_base, chunk_end); @@ -236,9 +233,6 @@ __global__ void multi_query_append_attention_kernel( NUM_WARP_Q>(v_smem, &kv_smem_offset_w, &cache_v_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, kv_b_stride, kv_idx_base, chunk_end); @@ -286,9 +280,6 @@ __global__ void multi_query_append_attention_kernel( NUM_WARP_Q>(k_smem, &kv_smem_offset_w, &cache_k_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, kv_b_stride, kv_idx_base, chunk_end); @@ -310,9 +301,6 @@ __global__ void multi_query_append_attention_kernel( NUM_WARP_Q>(v_smem, &kv_smem_offset_w, &cache_v_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, kv_b_stride, kv_idx_base, chunk_end); @@ -597,9 +585,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel( NUM_WARP_Q>(k_smem, &kv_smem_offset_w, &cache_k_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, kv_b_stride, kv_idx_base, chunk_end); @@ -613,9 +598,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel( NUM_WARP_Q>(v_smem, &kv_smem_offset_w, &cache_v_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, kv_b_stride, kv_idx_base, chunk_end); @@ -665,9 +647,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel( NUM_WARP_Q>(k_smem, &kv_smem_offset_w, &cache_k_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, kv_b_stride, kv_idx_base, chunk_end); @@ -689,9 +668,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel( NUM_WARP_Q>(v_smem, &kv_smem_offset_w, &cache_v_now, - kv_head_idx, - kv_n_stride, - kv_h_stride, kv_b_stride, kv_idx_base, chunk_end); From 0447b7a16ea4e5b9c11f38815e005f85acdd7016 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Tue, 2 Jun 2026 14:47:53 +0800 Subject: [PATCH 02/29] remove unsed para in produce_kv_blockwise --- .../append_attn/append_attention_func.cuh | 2 +- .../multiquery_attention_c16_impl.cuh | 176 +++++++++--------- 2 files changed, 89 insertions(+), 89 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 2b2f90f0ac9..a54af15e748 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -270,7 +270,7 @@ template -__device__ __forceinline__ void produce_kv_blockwise( +__device__ __forceinline__ void produce_kv_blockwise_c16( smem_t smem, uint32_t* smem_offset, T** gptr, // [max_block_num, num_heads, block_size, head_dim] diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index c2d8d9cd040..f074c12f796 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -213,29 +213,29 @@ __global__ void multi_query_append_attention_kernel( const T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; const T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; - produce_kv_blockwise(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_b_stride, - kv_idx_base, - chunk_end); + produce_kv_blockwise_c16(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); - produce_kv_blockwise(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_b_stride, - kv_idx_base, - chunk_end); + produce_kv_blockwise_c16(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { @@ -272,17 +272,17 @@ __global__ void multi_query_append_attention_kernel( block_id = 0; } cache_k_now = cache_k + block_id * kv_n_stride + const_offset; - produce_kv_blockwise(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_b_stride, - kv_idx_base, - chunk_end); + produce_kv_blockwise_c16(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); wait_group<1>(); __syncthreads(); @@ -293,17 +293,17 @@ __global__ void multi_query_append_attention_kernel( __syncthreads(); cache_v_now = cache_v + block_id * kv_n_stride + const_offset; - produce_kv_blockwise(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_b_stride, - kv_idx_base, - chunk_end); + produce_kv_blockwise_c16(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); } wait_group<0>(); @@ -577,30 +577,30 @@ __global__ void multi_query_append_attention_warp1_4_kernel( T *cache_k_now = cache_k + block_id * kv_n_stride + const_offset; T *cache_v_now = cache_v + block_id * kv_n_stride + const_offset; - produce_kv_blockwise(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_b_stride, - kv_idx_base, - chunk_end); + produce_kv_blockwise_c16(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); - produce_kv_blockwise(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_b_stride, - kv_idx_base, - chunk_end); + produce_kv_blockwise_c16(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); #pragma unroll 1 @@ -639,17 +639,17 @@ __global__ void multi_query_append_attention_warp1_4_kernel( block_id = 0; } cache_k_now = cache_k + block_id * kv_n_stride + const_offset; - produce_kv_blockwise(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_b_stride, - kv_idx_base, - chunk_end); + produce_kv_blockwise_c16(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); wait_group<1>(); __syncthreads(); @@ -660,17 +660,17 @@ __global__ void multi_query_append_attention_warp1_4_kernel( __syncthreads(); cache_v_now = cache_v + block_id * kv_n_stride + const_offset; - produce_kv_blockwise(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_b_stride, - kv_idx_base, - chunk_end); + produce_kv_blockwise_c16(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); } wait_group<0>(); From ed50b12b612ea2249c93d297f4dd9ac8a611f4ba Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Tue, 2 Jun 2026 17:08:20 +0800 Subject: [PATCH 03/29] remove unsed para in produce_kv_blockwise --- .../gpu_ops/append_attn/append_attention_func.cuh | 15 ++++++--------- .../append_attn/multiquery_attention_c16_impl.cuh | 12 +----------- .../append_attn/multiquery_attention_c4_impl.cuh | 6 +----- .../append_attn/multiquery_attention_c8_impl.cuh | 13 ++----------- 4 files changed, 10 insertions(+), 36 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index a54af15e748..26bfb5895f8 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -2276,11 +2276,9 @@ __global__ void merge_multi_chunks_decoder_kernel( const float quant_max_bound, const float quant_min_bound, const float in_scale, - const int max_seq_len, const int num_chunks, const int num_heads, - const int chunk_size, - const int head_dim) { + const int chunk_size) { const int vid = threadIdx.x, ty = threadIdx.y; const int bid = blockIdx.x, hid = blockIdx.y; __shared__ T smem[bdy * HEAD_DIM]; @@ -2336,7 +2334,7 @@ __global__ void merge_multi_chunks_decoder_kernel( const float m_now = multi_m[offset]; const float d_now = multi_d[offset]; m = max(m_prev, m_now); - offset = offset * head_dim + vid * vec_size; + offset = offset * HEAD_DIM + vid * vec_size; Load(&multi_out[offset], &load_vec); const float scale1 = __expf(m_prev - m), scale2 = __expf(m_now - m); const T scale1_T = static_cast(scale1), @@ -2348,7 +2346,7 @@ __global__ void merge_multi_chunks_decoder_kernel( } } // store ty res - Store(res_vec, &smem[ty * head_dim + vid * vec_size]); + Store(res_vec, &smem[ty * HEAD_DIM + vid * vec_size]); md_smem[2 * ty] = m; md_smem[2 * ty + 1] = d; __syncthreads(); @@ -2358,7 +2356,7 @@ __global__ void merge_multi_chunks_decoder_kernel( st.init(); #pragma unroll for (int i = 0; i < bdy; i++) { - Load(&smem[i * head_dim + vid * vec_size], &load_vec); + Load(&smem[i * HEAD_DIM + vid * vec_size], &load_vec); const float m_tmp = md_smem[2 * i], d_tmp = md_smem[2 * i + 1]; st.merge(load_vec, m_tmp, d_tmp); } @@ -2369,7 +2367,7 @@ __global__ void merge_multi_chunks_decoder_kernel( st.normalize(); } - const uint32_t shift_smooth_offset = hid * head_dim + vid * vec_size; + const uint32_t shift_smooth_offset = hid * HEAD_DIM + vid * vec_size; AlignedVector shift_bias_vec; AlignedVector smooth_weight_vec; AlignedVector out_vec; @@ -2391,7 +2389,7 @@ __global__ void merge_multi_chunks_decoder_kernel( } Store( out_vec, - &out[(start_token_idx * num_heads + hid) * head_dim + vid * vec_size]); + &out[(start_token_idx * num_heads + hid) * HEAD_DIM + vid * vec_size]); } #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); @@ -2422,7 +2420,6 @@ __global__ void merge_multi_chunks_v2_kernel( const float quant_max_bound, const float quant_min_bound, const float in_scale, - const int max_seq_len, const int num_chunks, const int num_heads, const int chunk_size, diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index f074c12f796..0b9eefec482 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -45,7 +45,6 @@ __global__ void multi_query_append_attention_kernel( const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] const int *__restrict__ mask_offset, - const int max_seq_len, const int max_block_num_per_seq, const float scale, const float quant_max_bound, @@ -430,7 +429,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const int *__restrict__ block_table, // [bsz, block_num_per_seq] const int *__restrict__ mask_offset, const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask - const int max_seq_len, const int max_block_num_per_seq, const float scale, const float quant_max_bound, @@ -923,7 +921,6 @@ void MultiQueryAppendAttention( cu_seqlens_q.data(), block_table.data(), meta_data.mask_offset, - max_seq_len, max_block_num_per_seq, scale, quant_max_bound, @@ -991,7 +988,6 @@ void MultiQueryAppendAttention( cu_seqlens_q.data(), block_table.data(), meta_data.mask_offset, - max_seq_len, max_block_num_per_seq, scale, quant_max_bound, @@ -1047,7 +1043,6 @@ void MultiQueryAppendAttention( quant_max_bound, quant_min_bound, in_scale, - max_seq_len, num_chunks, num_heads, chunk_size, @@ -1152,7 +1147,6 @@ void MultiQueryAppendAttention( meta_data.mask_offset, attn_mask ? const_cast(attn_mask.get().data()) : nullptr, - max_seq_len, max_block_num_per_seq, scale, quant_max_bound, @@ -1210,7 +1204,6 @@ void MultiQueryAppendAttention( meta_data.mask_offset, attn_mask ? const_cast(attn_mask.get().data()) : nullptr, - max_seq_len, max_block_num_per_seq, scale, quant_max_bound, @@ -1266,11 +1259,9 @@ void MultiQueryAppendAttention( quant_max_bound, quant_min_bound, in_scale, - max_seq_len, num_chunks, num_heads, - chunk_size, - HEAD_DIM); + chunk_size); } else { constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx; @@ -1310,7 +1301,6 @@ void MultiQueryAppendAttention( quant_max_bound, quant_min_bound, in_scale, - max_seq_len, num_chunks, num_heads, chunk_size, diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh index e9465e84cb5..8b9bade7ee0 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh @@ -1298,7 +1298,6 @@ void MultiQueryAppendC4Attention( quant_max_bound, quant_min_bound, in_scale, - max_seq_len, num_chunks, num_heads, chunk_size, @@ -1563,11 +1562,9 @@ void MultiQueryAppendC4Attention( quant_max_bound, quant_min_bound, in_scale, - max_seq_len, num_chunks, num_heads, - chunk_size, - HEAD_DIM); + chunk_size); } else { constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx; @@ -1606,7 +1603,6 @@ void MultiQueryAppendC4Attention( quant_max_bound, quant_min_bound, in_scale, - max_seq_len, num_chunks, num_heads, chunk_size, diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh index 2376c317051..9bc8651ea3a 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh @@ -52,7 +52,6 @@ __global__ void multi_query_append_attention_c8_kernel( const int *__restrict__ cu_seqlens_q, const int *__restrict__ block_table, // [bsz, block_num_per_seq] const int *__restrict__ mask_offset, - const int max_seq_len, const int max_dec_len, const int max_block_num_per_seq, const float scale, @@ -587,7 +586,6 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( const int *__restrict__ block_table, // [bsz, block_num_per_seq] const int *__restrict__ mask_offset, const bool *__restrict__ attn_mask, // [bsz, max_q, max_q] for tree-mask - const int max_seq_len, const int max_dec_len, const int max_block_num_per_seq, const float scale, @@ -1294,7 +1292,6 @@ void MultiQueryAppendC8Attention( cu_seqlens_q.data(), block_table.data(), meta_data.mask_offset, - max_seq_len, max_dec_len, max_block_num_per_seq, scale, @@ -1363,7 +1360,6 @@ void MultiQueryAppendC8Attention( cu_seqlens_q.data(), block_table.data(), meta_data.mask_offset, - max_seq_len, max_dec_len, max_block_num_per_seq, scale, @@ -1418,7 +1414,6 @@ void MultiQueryAppendC8Attention( quant_max_bound, quant_min_bound, in_scale, - max_seq_len, num_chunks, num_heads, chunk_size, @@ -1568,7 +1563,6 @@ void MultiQueryAppendC8Attention( meta_data.mask_offset, attn_mask ? const_cast(attn_mask.get().data()) : nullptr, - max_seq_len, max_dec_len, max_block_num_per_seq, scale, @@ -1654,7 +1648,7 @@ void MultiQueryAppendC8Attention( meta_data.mask_offset, attn_mask ? const_cast(attn_mask.get().data()) : nullptr, - max_seq_len, + max_dec_len, max_block_num_per_seq, scale, @@ -1710,11 +1704,9 @@ void MultiQueryAppendC8Attention( quant_max_bound, quant_min_bound, in_scale, - max_seq_len, num_chunks, num_heads, - chunk_size, - HEAD_DIM); + chunk_size); } else { constexpr int blockx = HEAD_DIM / vec_size; constexpr int blocky = (128 + blockx - 1) / blockx; @@ -1753,7 +1745,6 @@ void MultiQueryAppendC8Attention( quant_max_bound, quant_min_bound, in_scale, - max_seq_len, num_chunks, num_heads, chunk_size, From 635b4f36c2ca453130c2fe2381e3d3793be98ed0 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Tue, 2 Jun 2026 22:47:11 +0800 Subject: [PATCH 04/29] simplify code in produce_kv_blockwise_c16 --- .../append_attn/append_attention_func.cuh | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 26bfb5895f8..3c1e64d1c04 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -279,13 +279,14 @@ __device__ __forceinline__ void produce_kv_blockwise_c16( const uint32_t kv_len) { constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); - constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; const uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check + static_assert(block_size % (4 * num_warps) == 0, ""); + static_assert(head_dim % 64 == 0, ""); #pragma unroll - for (uint32_t i = 0; i < NUM_WARP_KV * num_frags_z * 4 / num_warps; ++i) { + for (uint32_t i = 0; i < block_size / (num_warps * 4); ++i) { #pragma unroll - for (uint32_t j = 0; j < num_frags_y / 4; ++j) { + for (uint32_t j = 0; j < head_dim / 64; ++j) { smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = smem.advance_offset_by_column<8>(*smem_offset, j); *gptr += 8 * num_elems_per_128b(); @@ -293,12 +294,11 @@ __device__ __forceinline__ void produce_kv_blockwise_c16( kv_idx += num_warps * 4; *smem_offset = smem.advance_offset_by_row( *smem_offset) - - 2 * num_frags_y; // num_frags_y / 4 * 8 - *gptr += - num_warps * 4 * kv_b_stride - 2 * num_frags_y * num_elems_per_128b(); + head_dim / 8; + *gptr += num_warps * 4 * kv_b_stride - head_dim; } - *gptr -= NUM_WARP_KV * num_frags_z * 16 * kv_b_stride; - *smem_offset -= NUM_WARP_KV * num_frags_z * 16 * num_vecs_per_head; + *gptr -= block_size * kv_b_stride; + *smem_offset -= block_size * num_vecs_per_head; } template Date: Tue, 2 Jun 2026 22:52:39 +0800 Subject: [PATCH 05/29] simplify code in produce_kv_blockwise_c16 --- .../append_attn/append_attention_func.cuh | 5 +- .../multiquery_attention_c16_impl.cuh | 88 +++++-------------- 2 files changed, 25 insertions(+), 68 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 3c1e64d1c04..f61809003d1 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -266,18 +266,15 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale( template __device__ __forceinline__ void produce_kv_blockwise_c16( smem_t smem, uint32_t* smem_offset, T** gptr, // [max_block_num, num_heads, block_size, head_dim] - const uint32_t kv_b_stride, const uint32_t kv_idx_base, const uint32_t kv_len) { - constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); const uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index 0b9eefec482..b3adc39a301 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -215,26 +215,16 @@ __global__ void multi_query_append_attention_kernel( produce_kv_blockwise_c16(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_b_stride, - kv_idx_base, - chunk_end); + HEAD_DIM, + NUM_WARP_Q>( + k_smem, &kv_smem_offset_w, &cache_k_now, kv_idx_base, chunk_end); commit_group(); produce_kv_blockwise_c16(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_b_stride, - kv_idx_base, - chunk_end); + HEAD_DIM, + NUM_WARP_Q>( + v_smem, &kv_smem_offset_w, &cache_v_now, kv_idx_base, chunk_end); commit_group(); #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { @@ -274,14 +264,9 @@ __global__ void multi_query_append_attention_kernel( produce_kv_blockwise_c16(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_b_stride, - kv_idx_base, - chunk_end); + HEAD_DIM, + NUM_WARP_Q>( + k_smem, &kv_smem_offset_w, &cache_k_now, kv_idx_base, chunk_end); commit_group(); wait_group<1>(); __syncthreads(); @@ -295,14 +280,9 @@ __global__ void multi_query_append_attention_kernel( produce_kv_blockwise_c16(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_b_stride, - kv_idx_base, - chunk_end); + HEAD_DIM, + NUM_WARP_Q>( + v_smem, &kv_smem_offset_w, &cache_v_now, kv_idx_base, chunk_end); commit_group(); } wait_group<0>(); @@ -578,27 +558,17 @@ __global__ void multi_query_append_attention_warp1_4_kernel( produce_kv_blockwise_c16(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_b_stride, - kv_idx_base, - chunk_end); + HEAD_DIM, + NUM_WARP_Q>( + k_smem, &kv_smem_offset_w, &cache_k_now, kv_idx_base, chunk_end); commit_group(); produce_kv_blockwise_c16(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_b_stride, - kv_idx_base, - chunk_end); + HEAD_DIM, + NUM_WARP_Q>( + v_smem, &kv_smem_offset_w, &cache_v_now, kv_idx_base, chunk_end); commit_group(); #pragma unroll 1 @@ -640,14 +610,9 @@ __global__ void multi_query_append_attention_warp1_4_kernel( produce_kv_blockwise_c16(k_smem, - &kv_smem_offset_w, - &cache_k_now, - kv_b_stride, - kv_idx_base, - chunk_end); + HEAD_DIM, + NUM_WARP_Q>( + k_smem, &kv_smem_offset_w, &cache_k_now, kv_idx_base, chunk_end); commit_group(); wait_group<1>(); __syncthreads(); @@ -661,14 +626,9 @@ __global__ void multi_query_append_attention_warp1_4_kernel( produce_kv_blockwise_c16(v_smem, - &kv_smem_offset_w, - &cache_v_now, - kv_b_stride, - kv_idx_base, - chunk_end); + HEAD_DIM, + NUM_WARP_Q>( + v_smem, &kv_smem_offset_w, &cache_v_now, kv_idx_base, chunk_end); commit_group(); } wait_group<0>(); From ced1aa7c87bed4bca8836e25a86b2e549ce6b00a Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Tue, 2 Jun 2026 23:00:14 +0800 Subject: [PATCH 06/29] simplify code in produce_kv_blockwise_c16 --- custom_ops/gpu_ops/append_attn/append_attention_func.cuh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index f61809003d1..6354b4bcc04 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -280,6 +280,8 @@ __device__ __forceinline__ void produce_kv_blockwise_c16( uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check static_assert(block_size % (4 * num_warps) == 0, ""); static_assert(head_dim % 64 == 0, ""); + const int32_t kv_b_stride = head_dim; + #pragma unroll for (uint32_t i = 0; i < block_size / (num_warps * 4); ++i) { #pragma unroll From d055b0bd827393b10a1d6fef2fd36f3888228e89 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Tue, 2 Jun 2026 23:23:52 +0800 Subject: [PATCH 07/29] simplify code in produce_kv_blockwise_c16 --- custom_ops/gpu_ops/append_attn/append_attention_func.cuh | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 6354b4bcc04..d43fd690bd8 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -280,7 +280,6 @@ __device__ __forceinline__ void produce_kv_blockwise_c16( uint32_t kv_idx = kv_idx_base + ty * 4 + tx / 8; // kv_idx used to check static_assert(block_size % (4 * num_warps) == 0, ""); static_assert(head_dim % 64 == 0, ""); - const int32_t kv_b_stride = head_dim; #pragma unroll for (uint32_t i = 0; i < block_size / (num_warps * 4); ++i) { @@ -294,9 +293,9 @@ __device__ __forceinline__ void produce_kv_blockwise_c16( *smem_offset = smem.advance_offset_by_row( *smem_offset) - head_dim / 8; - *gptr += num_warps * 4 * kv_b_stride - head_dim; + *gptr += num_warps * 4 * head_dim - head_dim; } - *gptr -= block_size * kv_b_stride; + *gptr -= block_size * head_dim; *smem_offset -= block_size * num_vecs_per_head; } From 0f81ab8e8d9e48d9801af77145add3a32f96e406 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Wed, 3 Jun 2026 12:30:45 +0800 Subject: [PATCH 08/29] simplify code in produce_kv_blockwise_c16 --- .../multiquery_attention_c16_impl.cuh | 68 ++++++++++++++----- 1 file changed, 50 insertions(+), 18 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index b3adc39a301..1db58747046 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -60,7 +60,6 @@ __global__ void multi_query_append_attention_kernel( const int speculate_max_draft_token_num = 5, const int sliding_window = 0, const int sink_size = 0) { - constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; const uint32_t kv_num_heads = gridDim.z; const uint32_t q_num_heads = kv_num_heads * GROUP_SIZE; @@ -153,6 +152,7 @@ __global__ void multi_query_append_attention_kernel( mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; smem_t qo_smem(smem); + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 @@ -216,15 +216,23 @@ __global__ void multi_query_append_attention_kernel( NUM_WARPS, BLOCK_SIZE, HEAD_DIM, - NUM_WARP_Q>( - k_smem, &kv_smem_offset_w, &cache_k_now, kv_idx_base, chunk_end); + NUM_WARP_Q>(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); produce_kv_blockwise_c16( - v_smem, &kv_smem_offset_w, &cache_v_now, kv_idx_base, chunk_end); + NUM_WARP_Q>(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { @@ -265,8 +273,12 @@ __global__ void multi_query_append_attention_kernel( NUM_WARPS, BLOCK_SIZE, HEAD_DIM, - NUM_WARP_Q>( - k_smem, &kv_smem_offset_w, &cache_k_now, kv_idx_base, chunk_end); + NUM_WARP_Q>(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); wait_group<1>(); __syncthreads(); @@ -281,8 +293,12 @@ __global__ void multi_query_append_attention_kernel( NUM_WARPS, BLOCK_SIZE, HEAD_DIM, - NUM_WARP_Q>( - v_smem, &kv_smem_offset_w, &cache_v_now, kv_idx_base, chunk_end); + NUM_WARP_Q>(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); } wait_group<0>(); @@ -425,7 +441,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const uint32_t attn_mask_len = -1, const int sliding_window = 0, const int sink_size = 0) { - constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); static_assert(NUM_WARP_Q == 1, "NUM_WARP_Q must be 1"); static_assert(NUM_WARP_KV == 4, "NUM_WARP_KV must be 4"); const uint32_t btid = blockIdx.x, kv_head_idx = blockIdx.z; @@ -494,6 +509,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel( mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; smem_t qo_smem(smem); + constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); uint32_t q_smem_offset_r = smem_t::get_permuted_offset( tid % 16, tid / 16); // 16 * 16 @@ -559,16 +575,24 @@ __global__ void multi_query_append_attention_warp1_4_kernel( NUM_WARPS, BLOCK_SIZE, HEAD_DIM, - NUM_WARP_Q>( - k_smem, &kv_smem_offset_w, &cache_k_now, kv_idx_base, chunk_end); + NUM_WARP_Q>(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); produce_kv_blockwise_c16( - v_smem, &kv_smem_offset_w, &cache_v_now, kv_idx_base, chunk_end); + NUM_WARP_Q>(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); #pragma unroll 1 @@ -611,8 +635,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel( NUM_WARPS, BLOCK_SIZE, HEAD_DIM, - NUM_WARP_Q>( - k_smem, &kv_smem_offset_w, &cache_k_now, kv_idx_base, chunk_end); + NUM_WARP_Q>(k_smem, + &kv_smem_offset_w, + &cache_k_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); wait_group<1>(); __syncthreads(); @@ -627,8 +655,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel( NUM_WARPS, BLOCK_SIZE, HEAD_DIM, - NUM_WARP_Q>( - v_smem, &kv_smem_offset_w, &cache_v_now, kv_idx_base, chunk_end); + NUM_WARP_Q>(v_smem, + &kv_smem_offset_w, + &cache_v_now, + kv_b_stride, + kv_idx_base, + chunk_end); commit_group(); } wait_group<0>(); From 6d27f9ec1028aefaeab76971715392e8314beae2 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Wed, 3 Jun 2026 13:03:02 +0800 Subject: [PATCH 09/29] simplify code in produce_kv_blockwise_c16 --- custom_ops/gpu_ops/append_attn/append_attention_func.cuh | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index d43fd690bd8..c97cd1be383 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -273,6 +273,7 @@ __device__ __forceinline__ void produce_kv_blockwise_c16( smem_t smem, uint32_t* smem_offset, T** gptr, // [max_block_num, num_heads, block_size, head_dim] + const uint32_t kv_b_stride, const uint32_t kv_idx_base, const uint32_t kv_len) { constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); @@ -293,9 +294,9 @@ __device__ __forceinline__ void produce_kv_blockwise_c16( *smem_offset = smem.advance_offset_by_row( *smem_offset) - head_dim / 8; - *gptr += num_warps * 4 * head_dim - head_dim; + *gptr += num_warps * 4 * kv_b_stride - head_dim; } - *gptr -= block_size * head_dim; + *gptr -= block_size * kv_b_stride; *smem_offset -= block_size * num_vecs_per_head; } From b1cdc850b316b53fe1caa83a0abe0320ac83eed5 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Wed, 3 Jun 2026 15:07:57 +0800 Subject: [PATCH 10/29] simplify code in produce_kv_blockwise_c16 --- .../append_attn/append_attention_func.cuh | 20 ++++++++----------- .../multiquery_attention_c16_impl.cuh | 19 ++++++++---------- .../multiquery_attention_c4_impl.cuh | 19 ++++++++---------- .../multiquery_attention_c8_impl.cuh | 19 ++++++++---------- 4 files changed, 32 insertions(+), 45 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index c97cd1be383..be6d43d5e47 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -122,7 +122,6 @@ __device__ __forceinline__ void init_states(float (*o_frag)[num_frags_y][8], template __device__ __forceinline__ void load_q_global_smem_multi_warps( @@ -134,23 +133,22 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps( const uint32_t qo_h_stride) { constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + static_assert(HEAD_DIM % 64 == 0, ""); + const uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t q_smem_offset_w = // [NUM_WARP_Q, num_frags_x, 16, head_dim] smem_t::get_permuted_offset(ty * 4 + tx / 8, tx % 8); // 4 * 64 - - const uint32_t tx_offset = tx / 8; + // 4 warps will load (fx * 16) * (HEAD_DIM) data! #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - const uint32_t base_offset = q_idx_base + fx * 16 + tx_offset; #pragma unroll - const int j = ty; - const uint32_t offset_now = base_offset + j * 4; + const uint32_t offset_now = q_idx_base + fx * 16 + ty * 4 + tx / 8; const uint32_t n_offset = offset_now / group_size; const uint32_t h_offset = offset_now % group_size; T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { + for (uint32_t fyo = 0; fyo < HEAD_DIM / 64; ++fyo) { q_smem->load_128b_async( q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); q_smem_offset_w = @@ -159,7 +157,7 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps( } q_smem_offset_w = q_smem->advance_offset_by_row<16, num_vecs_per_head>(q_smem_offset_w) - - 2 * num_frags_y; + HEAD_DIM / 8; } } @@ -209,16 +207,14 @@ __device__ __forceinline__ void load_q_global_smem( } } -template +template __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps( - smem_t* q_smem, // [num_frags_x * 16, num_frags_y * 16] + smem_t* q_smem, // [num_frags_x * 16, head_dim] const float sm_scale) { constexpr int vec_size = 16 / sizeof(T); using LoadT = AlignedVector; LoadT tmp_vec; const uint32_t tx = threadIdx.x, ty = threadIdx.y; - constexpr uint32_t head_dim = num_frags_y * 16; - constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); #pragma unroll for (uint32_t i = 0; i < num_frags_x * 16 * head_dim / 1024; ++i) { diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index 1db58747046..8d4b3259d38 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -517,21 +517,18 @@ __global__ void multi_query_append_attention_warp1_4_kernel( cudaGridDependencySynchronize(); #endif - load_q_global_smem_multi_warps(q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_len, - q_ori_n_stride, - HEAD_DIM); + load_q_global_smem_multi_warps( + q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_len, + q_ori_n_stride, + HEAD_DIM); commit_group(); wait_group<0>(); __syncthreads(); - q_smem_inplace_multiply_sm_scale_multi_warps( + q_smem_inplace_multiply_sm_scale_multi_warps( &qo_smem, scale); static_assert(num_rows_per_block == num_frags_x * 16); diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh index 8b9bade7ee0..8d4111a0a29 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh @@ -686,21 +686,18 @@ __global__ void multi_query_append_attention_c4_warp1_4_kernel( uint32_t q_smem_offset_r = smem_t::get_permuted_offset(tid % 16, tid / 16); - load_q_global_smem_multi_warps(q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_end, - q_ori_n_stride, - HEAD_DIM); + load_q_global_smem_multi_warps( + q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); commit_group(); wait_group<0>(); __syncthreads(); - q_smem_inplace_multiply_sm_scale_multi_warps( + q_smem_inplace_multiply_sm_scale_multi_warps( &qo_smem, scale); T cache_k_scale_frag[num_frags_y][4]; diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh index 9bc8651ea3a..611b5d66435 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh @@ -735,21 +735,18 @@ __global__ void multi_query_append_attention_c8_warp1_4_kernel( uint32_t q_smem_offset_r = smem_t::get_permuted_offset( tid % 16, tid / 16); // 16 * 16 - load_q_global_smem_multi_warps(q_base_ptr, - &qo_smem, - q_base_seq_id_this_block, - q_end, - q_ori_n_stride, - HEAD_DIM); + load_q_global_smem_multi_warps( + q_base_ptr, + &qo_smem, + q_base_seq_id_this_block, + q_end, + q_ori_n_stride, + HEAD_DIM); commit_group(); wait_group<0>(); __syncthreads(); - q_smem_inplace_multiply_sm_scale_multi_warps( + q_smem_inplace_multiply_sm_scale_multi_warps( &qo_smem, scale); smem_t k_smem(smem + num_frags_x * 16 * HEAD_DIM * sizeof(T)), From 3a3680685efa73f1fc69d4028b3298dfa954e8ce Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Wed, 3 Jun 2026 17:36:11 +0800 Subject: [PATCH 11/29] simplify code in produce_kv_blockwise_c16 --- .../gpu_ops/append_attn/append_attention_func.cuh | 10 +++++----- .../append_attn/multiquery_attention_c16_impl.cuh | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index be6d43d5e47..90c1ab8f8ed 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -774,7 +774,7 @@ __device__ __forceinline__ void produce_kv(smem_t smem, } template __device__ __forceinline__ void compute_qk(smem_t* q_smem, @@ -782,12 +782,12 @@ __device__ __forceinline__ void compute_qk(smem_t* q_smem, smem_t* k_smem, uint32_t* k_smem_offset_r, float (*s_frag)[num_frags_z][8]) { - constexpr uint32_t head_dim = num_frags_y * 16; + static_assert(head_dim % 16 == 0, ""); constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); uint32_t a_frag[num_frags_x][4], b_frag[4]; // compute q*k^T #pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { // k + for (uint32_t fy = 0; fy < head_dim / 16; ++fy) { // k #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // m q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx]); @@ -819,8 +819,8 @@ __device__ __forceinline__ void compute_qk(smem_t* q_smem, k_smem->advance_offset_by_column<2>(*k_smem_offset_r, fy) - num_frags_z * 16 * num_vecs_per_head; } - *q_smem_offset_r -= num_frags_y * 2; - *k_smem_offset_r -= num_frags_y * 2; + *q_smem_offset_r -= head_dim / 8; + *k_smem_offset_r -= head_dim / 8; } template ( + compute_qk( &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); // mask according to kv_idx and q_idx if (iter >= mask_check_iteration || sliding_window > 0) { @@ -598,7 +598,7 @@ __global__ void multi_query_append_attention_warp1_4_kernel( __syncthreads(); // s = qk - compute_qk( + compute_qk( &qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); // mask according to kv_idx and q_idx if (iter >= mask_check_iteration || sliding_window > 0) { From 0dc803df1f5c7e5eed17c074d6a8cefc54222331 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Wed, 3 Jun 2026 19:34:46 +0800 Subject: [PATCH 12/29] simplify code in produce_kv_blockwise_c16 --- custom_ops/gpu_ops/append_attention.cu | 16 +++++++++++----- custom_ops/gpu_ops/cpp_extensions.cc | 3 ++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index c1586945cc5..5c9bd05b1ff 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -91,7 +91,8 @@ void AppendAttentionKernel( const bool causal, const bool speculate_decoder, const int sliding_window, - const int sink_size = 0) { + const int sink_size = 0, + const bool only_do_attn = false) { typedef PDTraits traits_; typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; @@ -505,7 +506,8 @@ std::vector AppendAttention( const bool causal, const bool speculate_decoder, const int sliding_window, - const int sink_size = 0) { + const int sink_size = 0, + const bool only_do_attn = false) { AppendAttnMetaData meta_data; const auto& qkv_dims = qkv.dims(); @@ -638,7 +640,8 @@ std::vector AppendAttention( causal, speculate_decoder, sliding_window, - sink_size); + sink_size, + only_do_attn); }; phi::dtype::float16 fp16_dtype; @@ -888,7 +891,8 @@ std::vector> AppendAttentionInferShape( const bool causal, const bool speculate_decoder, const int sliding_window, - const int sink_size) { + const int sink_size, + const bool only_do_attn) { const int token_num = qkv_shape[0]; const int kv_num_heads = key_cache_shape[1]; int head_dim = key_cache_shape[3]; @@ -954,7 +958,8 @@ std::vector AppendAttentionInferDtype( const bool causal, const bool speculate_decoder, const int sliding_window, - const int sink_size) { + const int sink_size, + const bool only_do_attn) { if (compute_dtype == "bf16") { if (out_linear_in_scale > 0.0) { if (fabs(quant_max_bound - 127.0f) < 0.000001) { @@ -1161,6 +1166,7 @@ PD_BUILD_STATIC_OP(append_attention) "speculate_decoder: bool", "sliding_window: int", "sink_size: int", + "only_do_attn: bool", }) .SetKernelFn(PD_KERNEL(AppendAttention)) .SetInferShapeFn(PD_INFER_SHAPE(AppendAttentionInferShape)) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 6718e97d56c..911695e7412 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -130,7 +130,8 @@ std::vector AppendAttention( const bool causal, const bool speculate_decoder, const int sliding_window, - const int sink_size); + const int sink_size, + const bool only_do_attn); std::vector AppendAttentionWithOutput( const paddle::Tensor& qkv, From 9162f8d656ca9bbc1a8a23d9a8817da8e351e1b7 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 4 Jun 2026 10:56:20 +0800 Subject: [PATCH 13/29] simplify code in produce_kv_blockwise_c16 --- .../model_executor/layers/attention/ops/append_attention.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fastdeploy/model_executor/layers/attention/ops/append_attention.py b/fastdeploy/model_executor/layers/attention/ops/append_attention.py index 8b36ffa85b0..a53175b5090 100644 --- a/fastdeploy/model_executor/layers/attention/ops/append_attention.py +++ b/fastdeploy/model_executor/layers/attention/ops/append_attention.py @@ -85,6 +85,7 @@ def append_attention( sliding_window: int = 0, sink_size: int = 0, head_wise_full_hidden: int = 0, + only_do_attn: bool = False, ) -> paddle.Tensor: """ append_attention @@ -147,6 +148,7 @@ def append_attention( speculate_decoder, sliding_window, sink_size, + only_do_attn, ) sliding_window = 0 sink_size = 0 @@ -206,6 +208,7 @@ def append_attention( speculate_decoder, sliding_window, sink_size, + only_do_attn, ) if head_wise_full_hidden > 0: out_swa[:, :head_wise_full_hidden] = out[:, :head_wise_full_hidden] From 9d9eea3fe487a6075ca449211db167828d066e80 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 4 Jun 2026 16:31:02 +0800 Subject: [PATCH 14/29] support only_do_attn --- custom_ops/gpu_ops/append_attention.cu | 251 +++++++++++++------------ 1 file changed, 134 insertions(+), 117 deletions(-) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 5c9bd05b1ff..09fb505fa08 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -187,6 +187,8 @@ void AppendAttentionKernel( auto dispatch_EncoderWriteCacheWithRopeKernel = [&](auto temp_args) -> void { + if (only_do_attn) return; + DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, { EncoderWriteCacheWithRopeKernel( - meta_data, - qkv, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache), - q_norm_weight, - k_norm_weight, - rms_norm_eps); - } else { - SpeculateWriteCacheWithRoPEKernel( - meta_data, - qkv_out, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - batch_id_per_token, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache), - q_norm_weight, - k_norm_weight, - rms_norm_eps); - } - } else { - if (qkv_out_scales) { - DecoderWriteCacheWithRoPEKernel( - meta_data, - qkv, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache), - q_norm_weight, - k_norm_weight, - rms_norm_eps); + + if (!only_do_attn) { + DISPATCH_BOOL_DTYPE(enforce_fmul_rn, EnforceFmulRN, { + if (speculate_decoder) { + if (qkv_out_scales) { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + } else { + SpeculateWriteCacheWithRoPEKernel( + meta_data, + qkv_out, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + batch_id_per_token, + cu_seqlens_q, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + } } else { - DecoderWriteCacheWithRoPEKernel( - meta_data, - qkv_out, // [token_num, num_heads, head_dim] - seq_lens_decoder, - seq_lens_encoder, - cu_seqlens_q, - block_tables, - rotary_embs, - qkv_out_scales, - qkv_bias, - cache_k_quant_scales, - cache_v_quant_scales, - cache_k_zp, - cache_v_zp, - cache_quant_type_str, - use_neox_rotary_style, - rope_3d, - max_input_length, - exec_stream, - &qkv_out, - const_cast(&key_cache), - const_cast(&value_cache), - q_norm_weight, - k_norm_weight, - rms_norm_eps); + if (qkv_out_scales) { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + } else { + DecoderWriteCacheWithRoPEKernel( + meta_data, + qkv_out, // [token_num, num_heads, head_dim] + seq_lens_decoder, + seq_lens_encoder, + cu_seqlens_q, + block_tables, + rotary_embs, + qkv_out_scales, + qkv_bias, + cache_k_quant_scales, + cache_v_quant_scales, + cache_k_zp, + cache_v_zp, + cache_quant_type_str, + use_neox_rotary_style, + rope_3d, + max_input_length, + exec_stream, + &qkv_out, + const_cast(&key_cache), + const_cast(&value_cache), + q_norm_weight, + k_norm_weight, + rms_norm_eps); + } } - } - }) + }) + } if (out_linear_in_scale > 0.0) { switch (fmha_out.dtype()) { @@ -515,13 +520,25 @@ std::vector AppendAttention( meta_data.token_nums = qkv_dims[0]; meta_data.kv_num_heads = key_cache_dims[1]; meta_data.head_dims = key_cache_dims[3]; + + meta_data.head_dims_v = value_cache.dims()[3]; + + PADDLE_ENFORCE(key_cache_dims[0] == value_cache.dims()[0], "Unmatched shape"); + PADDLE_ENFORCE(key_cache_dims[1] == value_cache.dims()[1], "Unmatched shape"); + PADDLE_ENFORCE(key_cache_dims[2] == value_cache.dims()[2], "Unmatched shape"); + // TODO: trick method support c4, add attr head_dims in the future if (cache_quant_type_str == "cache_int4_zp") { meta_data.head_dims *= 2; } - const int total_num_head = - qkv_dims[qkv_dims.size() - 1] / meta_data.head_dims; - meta_data.q_num_heads = total_num_head - 2 * meta_data.kv_num_heads; + + meta_data.q_num_heads = + (qkv_dims[qkv_dims.size() - 1] - + meta_data.kv_num_heads * (meta_data.head_dims + meta_data.head_dims_v)) / + meta_data.head_dims; + + std::cout << "meta_data.q_num_heads" << meta_data.q_num_heads << std::endl; + std::cout << "meta_data.head_dims_v" << meta_data.head_dims_v << std::endl; meta_data.max_blocks_per_seq = block_tables.dims()[1]; meta_data.block_size = key_cache.dims()[2]; @@ -563,12 +580,12 @@ std::vector AppendAttention( if (out_linear_in_scale > 0.0) { if (fabs(quant_max_bound - 127.0f) < 0.000001) { fmha_out = paddle::zeros( - {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v}, paddle::DataType::INT8, qkv.place()); } else if (fabs(quant_max_bound - 448.0f) < 0.000001) { fmha_out = paddle::zeros( - {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v}, paddle::DataType::FLOAT8_E4M3FN, qkv.place()); } else { @@ -576,7 +593,7 @@ std::vector AppendAttention( } } else { fmha_out = paddle::zeros( - {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims}, + {meta_data.token_nums, meta_data.q_num_heads * meta_data.head_dims_v}, dtype_id, qkv.place()); } From 6daeb2ecd9ca2c0a97969708a3db69a679ef8dfe Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 4 Jun 2026 16:33:17 +0800 Subject: [PATCH 15/29] support only_do_attn --- custom_ops/gpu_ops/append_attention.cu | 3 --- 1 file changed, 3 deletions(-) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 09fb505fa08..00622556183 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -537,9 +537,6 @@ std::vector AppendAttention( meta_data.kv_num_heads * (meta_data.head_dims + meta_data.head_dims_v)) / meta_data.head_dims; - std::cout << "meta_data.q_num_heads" << meta_data.q_num_heads << std::endl; - std::cout << "meta_data.head_dims_v" << meta_data.head_dims_v << std::endl; - meta_data.max_blocks_per_seq = block_tables.dims()[1]; meta_data.block_size = key_cache.dims()[2]; meta_data.batch_size = seq_lens_this_time.dims()[0]; From a18cb11f3ee5ce51e4301d347e916083a07270ad Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 4 Jun 2026 16:38:59 +0800 Subject: [PATCH 16/29] support only_do_attn --- custom_ops/gpu_ops/append_attention.cu | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 00622556183..09bf6157161 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -532,10 +532,14 @@ std::vector AppendAttention( meta_data.head_dims *= 2; } - meta_data.q_num_heads = - (qkv_dims[qkv_dims.size() - 1] - - meta_data.kv_num_heads * (meta_data.head_dims + meta_data.head_dims_v)) / - meta_data.head_dims; + auto q_size = + qkv_dims[qkv_dims.size() - 1] - + meta_data.kv_num_heads * (meta_data.head_dims + meta_data.head_dims_v); + + PADDLE_ENFORCE(q_size % meta_data.head_dims == 0, "Unmatched shape"); + PADDLE_ENFORCE(q_size > 0, "Unmatched shape"); + + meta_data.q_num_heads = q_size / meta_data.head_dims; meta_data.max_blocks_per_seq = block_tables.dims()[1]; meta_data.block_size = key_cache.dims()[2]; From c23d1a43d9cc7bc38059eee7a221c5a1ef2551db Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 4 Jun 2026 16:42:23 +0800 Subject: [PATCH 17/29] support only_do_attn --- custom_ops/gpu_ops/append_attention.cu | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 09bf6157161..0c2edcfb4fd 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -545,6 +545,10 @@ std::vector AppendAttention( meta_data.block_size = key_cache.dims()[2]; meta_data.batch_size = seq_lens_this_time.dims()[0]; + PADDLE_ENFORCE( + max_input_length == meta_data.block_size * meta_data.max_blocks_per_seq, + "Unmatched shape"); + // template dtype generation phi::DataType dtype_id; switch (qkv.dtype()) { From 459e85e906e5c7bf16a0b8da34d4167d6899ccc9 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 4 Jun 2026 16:51:19 +0800 Subject: [PATCH 18/29] support only_do_attn --- custom_ops/gpu_ops/append_attention.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 0c2edcfb4fd..c5f64992159 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -418,7 +418,7 @@ void AppendAttentionKernel( decoder_tile_ids_per_batch, decoder_num_blocks_data, decoder_block_shape_q, - max_kv_len_this_time, + -1, // useless !speculate_decoder, !speculate_decoder, exec_stream); @@ -431,7 +431,7 @@ void AppendAttentionKernel( decoder_tile_ids_per_batch, decoder_num_blocks_data, decoder_block_shape_q, - max_kv_len_this_time, + -1, // useless !speculate_decoder, !speculate_decoder, exec_stream); @@ -445,7 +445,7 @@ void AppendAttentionKernel( decoder_tile_ids_per_batch, decoder_num_blocks_data, decoder_block_shape_q, - max_kv_len_this_time, + -1, // useless !speculate_decoder, !speculate_decoder, exec_stream); From 8451480e39cbb88c57a33527905c60d3cf310c99 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 4 Jun 2026 16:52:05 +0800 Subject: [PATCH 19/29] support only_do_attn --- custom_ops/gpu_ops/append_attention.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index c5f64992159..5f4159c1d3a 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -102,7 +102,6 @@ void AppendAttentionKernel( const int max_dec_len_this_time = set_max_lengths.data()[2]; const int max_enc_dec_len_this_time = set_max_lengths.data()[3]; const int max_just_dec_len_this_time = set_max_lengths.data()[4]; - const int max_kv_len_this_time = set_max_lengths.data()[5]; auto main_stream = qkv.stream(); static cudaEvent_t main_event; From 899ceb7e10085de1feb76ad951b8ec8da2fd51b3 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 4 Jun 2026 17:25:23 +0800 Subject: [PATCH 20/29] support only_do_attn --- custom_ops/gpu_ops/append_attention.cu | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 5f4159c1d3a..15c1fb46fe9 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -546,7 +546,12 @@ std::vector AppendAttention( PADDLE_ENFORCE( max_input_length == meta_data.block_size * meta_data.max_blocks_per_seq, - "Unmatched shape"); + "Unmatched shape: ", + max_input_length, + " ", + meta_data.block_size, + " ", + meta_data.max_blocks_per_seq); // template dtype generation phi::DataType dtype_id; From f8d7da5ae6d3513af6748b7f8f255f9d28e3e729 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 4 Jun 2026 18:07:22 +0800 Subject: [PATCH 21/29] support only_do_attn --- custom_ops/gpu_ops/append_attention.cu | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 15c1fb46fe9..342f7d3e915 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -97,7 +97,6 @@ void AppendAttentionKernel( typedef typename traits_::DataType DataType_; typedef typename traits_::data_t data_t; - const int max_len_this_time = set_max_lengths.data()[0]; const int max_enc_len_this_time = set_max_lengths.data()[1]; const int max_dec_len_this_time = set_max_lengths.data()[2]; const int max_enc_dec_len_this_time = set_max_lengths.data()[3]; @@ -544,14 +543,10 @@ std::vector AppendAttention( meta_data.block_size = key_cache.dims()[2]; meta_data.batch_size = seq_lens_this_time.dims()[0]; - PADDLE_ENFORCE( - max_input_length == meta_data.block_size * meta_data.max_blocks_per_seq, - "Unmatched shape: ", - max_input_length, - " ", - meta_data.block_size, - " ", - meta_data.max_blocks_per_seq); + // PADDLE_ENFORCE( + // max_input_length == meta_data.block_size * + // meta_data.max_blocks_per_seq, "Unmatched shape: ", max_input_length, " + // ", meta_data.block_size, " ", meta_data.max_blocks_per_seq); // template dtype generation phi::DataType dtype_id; @@ -654,7 +649,7 @@ std::vector AppendAttention( cache_quant_type_str, use_neox_rotary_style, rope_3d, - max_input_length, + meta_data.block_size * meta_data.max_blocks_per_seq, quant_max_bound, quant_min_bound, out_linear_in_scale, From c28b6d649390f3856066702f242319cacfac547c Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 4 Jun 2026 18:50:48 +0800 Subject: [PATCH 22/29] support only_do_attn --- custom_ops/gpu_ops/append_attention.cu | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/custom_ops/gpu_ops/append_attention.cu b/custom_ops/gpu_ops/append_attention.cu index 342f7d3e915..4f02e51e111 100644 --- a/custom_ops/gpu_ops/append_attention.cu +++ b/custom_ops/gpu_ops/append_attention.cu @@ -98,7 +98,6 @@ void AppendAttentionKernel( typedef typename traits_::data_t data_t; const int max_enc_len_this_time = set_max_lengths.data()[1]; - const int max_dec_len_this_time = set_max_lengths.data()[2]; const int max_enc_dec_len_this_time = set_max_lengths.data()[3]; const int max_just_dec_len_this_time = set_max_lengths.data()[4]; @@ -649,7 +648,7 @@ std::vector AppendAttention( cache_quant_type_str, use_neox_rotary_style, rope_3d, - meta_data.block_size * meta_data.max_blocks_per_seq, + max_input_length, quant_max_bound, quant_min_bound, out_linear_in_scale, From 9edfe263e0d615e3fd4776e149b78d967af54773 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 4 Jun 2026 20:21:26 +0800 Subject: [PATCH 23/29] support only_do_attn --- .../multiquery_attention_c16_impl.cuh | 138 ++++++------------ 1 file changed, 45 insertions(+), 93 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index 3991b1e63ff..473e89c3f95 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -27,8 +27,7 @@ template + typename OutT = T> __global__ void multi_query_append_attention_kernel( const T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * // head_dim] @@ -83,18 +82,7 @@ __global__ void multi_query_append_attention_kernel( return; } - uint32_t kv_len = seq_lens_kv[batch_id]; - if (ENABLE_PREFILL) { - kv_len += q_len; - if (kv_len <= 0) { - return; - } - } else { - if (kv_len <= 0) { - return; - } - kv_len += q_len; - } + const uint32_t kv_len = seq_lens_kv[batch_id] + q_len; const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); if (chunk_idx >= num_chunks_this_seq) { @@ -134,17 +122,9 @@ __global__ void multi_query_append_attention_kernel( T *o_base_ptr_T = nullptr; OutT *o_base_ptr_int8 = nullptr; if constexpr (partition_kv) { - if (ENABLE_PREFILL) { - o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } else { - o_base_ptr_T = - tmp_workspace + - batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } + o_base_ptr_T = tmp_workspace + q_start_seq_id * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); } else { o_base_ptr_int8 = out + o_offset; } @@ -173,28 +153,27 @@ __global__ void multi_query_append_attention_kernel( q_smem_inplace_multiply_sm_scale(&qo_smem, scale); - smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), - v_smem(smem + (NUM_WARPS * num_frags_x + num_frags_z) * 16 * HEAD_DIM * - sizeof(T)); + smem_t k_smem(smem + num_rows_per_block * HEAD_DIM * sizeof(T)), + v_smem(smem + (num_rows_per_block + BLOCK_SIZE) * HEAD_DIM * sizeof(T)); const uint32_t num_iterations = div_up( CAUSAL - ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), - chunk_start))) + ? min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + + div_up((tile_id + 1) * num_rows_per_block, GROUP_SIZE), + chunk_start)) : chunk_len, - num_frags_z * 16); + BLOCK_SIZE); const uint32_t mask_check_iteration = - (CAUSAL ? (min(chunk_len, - sub_if_greater_or_zero( - kv_len - q_len + - tile_id * num_rows_per_block / GROUP_SIZE, - chunk_start))) + (CAUSAL + ? min(chunk_len, + sub_if_greater_or_zero( + kv_len - q_len + tile_id * num_rows_per_block / GROUP_SIZE, + chunk_start)) : mask_offset ? 0 : chunk_len) / - (num_frags_z * 16); + BLOCK_SIZE; uint32_t k_smem_offset_r = smem_t::get_permuted_offset( 8 * (tid / 16) + tid % 8, (tid % 16) / 8); @@ -263,7 +242,7 @@ __global__ void multi_query_append_attention_kernel( s_frag, o_frag, m_frag, d_frag); __syncthreads(); - kv_idx_base += num_frags_z * 16; + kv_idx_base += BLOCK_SIZE; block_id = __ldg(&block_table_now[kv_idx_base / BLOCK_SIZE]); if (block_id < 0) { block_id = 0; @@ -372,18 +351,8 @@ __global__ void multi_query_append_attention_kernel( const uint32_t qo_head_idx = q_head_idx + qo_idx_now % GROUP_SIZE; const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; if (qo_idx - q_start_seq_id < q_len) { - uint32_t offset; - if (ENABLE_PREFILL) { - offset = - (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; - } else { - offset = ((batch_id * speculate_max_draft_token_num + - qo_idx_now / GROUP_SIZE) * - num_chunks + - chunk_idx) * - q_num_heads + - qo_head_idx; - } + const uint32_t offset = + (qo_idx * num_chunks + chunk_idx) * q_num_heads + qo_head_idx; tmp_m[offset] = m_frag[fx][j]; tmp_d[offset] = d_frag[fx][j]; } @@ -824,6 +793,7 @@ void MultiQueryAppendAttention( constexpr uint32_t smem_size = (num_warps * num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * HEAD_DIM * sizeof(T); + auto split_kv_kernel = multi_query_append_attention_kernel; + OUT_NV_TYPE>; if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(split_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -866,20 +835,18 @@ void MultiQueryAppendAttention( dim3 grids(num_blocks_x_cpu, num_chunks, kv_num_heads); dim3 blocks(32, num_warps); if (num_chunks <= 1 || force_no_partition) { - auto nosplit_kv_kernel = - multi_query_append_attention_kernel; + auto nosplit_kv_kernel = multi_query_append_attention_kernel; if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(nosplit_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -927,30 +894,15 @@ void MultiQueryAppendAttention( } else { phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (ENABLE_PREFILL) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(token_num * num_chunks * num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(token_num * num_chunks * num_heads)); launchWithPdlWhenEnabled( split_kv_kernel, From 29914ab9c28fe60f6bd1e9e03cff61316790d891 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 4 Jun 2026 22:10:21 +0800 Subject: [PATCH 24/29] support only_do_attn --- .../append_attn/multiquery_attention_c16_impl.cuh | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index 473e89c3f95..e3864222f94 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -50,7 +50,6 @@ __global__ void multi_query_append_attention_kernel( const float quant_min_bound, const float in_scale, const uint32_t chunk_size, - const int num_blocks_x_cpu, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, // num_heads, head_dim] float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] @@ -72,16 +71,7 @@ __global__ void multi_query_append_attention_kernel( const uint32_t num_rows_per_block = NUM_WARPS * num_frags_x * 16; const int *block_table_now = block_table + batch_id * max_block_num_per_seq; - // When cudagraph capture prefill, may launch more gridDim.x - if (btid >= static_cast(num_blocks_x_cpu)) { - return; - } - const uint32_t q_len = seq_lens[batch_id]; - if (q_len <= 0) { - return; - } - const uint32_t kv_len = seq_lens_kv[batch_id] + q_len; const uint32_t num_chunks_this_seq = div_up(kv_len, chunk_size); @@ -883,7 +873,6 @@ void MultiQueryAppendAttention( quant_min_bound, in_scale, chunk_size, - num_blocks_x_cpu, nullptr, nullptr, nullptr, @@ -935,7 +924,6 @@ void MultiQueryAppendAttention( quant_min_bound, in_scale, chunk_size, - num_blocks_x_cpu, reinterpret_cast(tmp_workspace->ptr()), static_cast(tmp_m->ptr()), static_cast(tmp_d->ptr()), From 3be8f990b4dd367ec999b05c68a3d2283c6532a9 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 4 Jun 2026 22:17:50 +0800 Subject: [PATCH 25/29] support only_do_attn --- .../gpu_ops/append_attn/multiquery_attention_c16_impl.cuh | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index e3864222f94..c18a71eea36 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -772,7 +772,6 @@ void MultiQueryAppendAttention( constexpr uint32_t NUM_WARP_KV = num_warps / NUM_WARP_Q; constexpr uint32_t num_frags_x = BLOCK_SHAPE_Q / (16 * NUM_WARP_Q); // 1 or 2 constexpr uint32_t num_frags_y = HEAD_DIM / 16; - constexpr uint32_t num_qrow_per_block = NUM_WARP_Q * num_frags_x * 16; auto *allocator = paddle::GetAllocator(qkv.place()); @@ -781,8 +780,7 @@ void MultiQueryAppendAttention( if constexpr (NUM_WARP_Q == 4) { constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; constexpr uint32_t smem_size = - (num_warps * num_frags_x + NUM_WARP_KV * num_frags_z * 2) * 16 * - HEAD_DIM * sizeof(T); + (num_rows_per_block + BLOCK_SIZE * 2) * HEAD_DIM * sizeof(T); auto split_kv_kernel = multi_query_append_attention_kernel Date: Thu, 4 Jun 2026 22:38:07 +0800 Subject: [PATCH 26/29] support only_do_attn --- .../gpu_ops/append_attn/multiquery_attention_c16_impl.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index c18a71eea36..0bdf88e0a56 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -780,7 +780,7 @@ void MultiQueryAppendAttention( if constexpr (NUM_WARP_Q == 4) { constexpr uint32_t num_frags_z = BLOCK_SIZE / 16; constexpr uint32_t smem_size = - (num_rows_per_block + BLOCK_SIZE * 2) * HEAD_DIM * sizeof(T); + (BLOCK_SHAPE_Q + BLOCK_SIZE * 2) * HEAD_DIM * sizeof(T); auto split_kv_kernel = multi_query_append_attention_kernel Date: Sat, 6 Jun 2026 10:40:12 +0800 Subject: [PATCH 27/29] support only_do_attn --- custom_ops/gpu_ops/append_attn/append_attention_func.cuh | 9 ++++----- .../append_attn/multiquery_attention_c16_impl.cuh | 5 ++--- .../gpu_ops/append_attn/multiquery_attention_c4_impl.cuh | 5 ++--- .../gpu_ops/append_attn/multiquery_attention_c8_impl.cuh | 5 ++--- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh index 90c1ab8f8ed..b7066172ff8 100644 --- a/custom_ops/gpu_ops/append_attn/append_attention_func.cuh +++ b/custom_ops/gpu_ops/append_attn/append_attention_func.cuh @@ -163,7 +163,6 @@ __device__ __forceinline__ void load_q_global_smem_multi_warps( template __device__ __forceinline__ void load_q_global_smem( @@ -175,6 +174,7 @@ __device__ __forceinline__ void load_q_global_smem( const uint32_t qo_h_stride) { constexpr uint32_t num_vecs_per_head = HEAD_DIM / num_elems_per_128b(); + static_assert(HEAD_DIM % 64 == 0, ""); const uint32_t tx = threadIdx.x, ty = threadIdx.y; uint32_t q_smem_offset_w = // [NUM_WARP_Q, num_frags_x, 16, head_dim] @@ -193,7 +193,7 @@ __device__ __forceinline__ void load_q_global_smem( const T* q_ptr = q_ptr_base + n_offset * qo_n_stride + h_offset * qo_h_stride; #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { + for (uint32_t fyo = 0; fyo < HEAD_DIM / 64; ++fyo) { q_smem->load_128b_async( q_smem_offset_w, q_ptr, n_offset < qo_upper_bound); q_smem_offset_w = @@ -202,7 +202,7 @@ __device__ __forceinline__ void load_q_global_smem( } q_smem_offset_w = q_smem->advance_offset_by_row<4, num_vecs_per_head>(q_smem_offset_w) - - 2 * num_frags_y; // num_frags_y / 4 * 8 + HEAD_DIM / 8; } } } @@ -228,7 +228,7 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale_multi_warps( } } -template +template __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale( smem_t* q_smem, // [num_frags_x * 16, num_frags_y * 16] const float sm_scale) { @@ -236,7 +236,6 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale( using LoadT = AlignedVector; LoadT tmp_vec; const uint32_t tx = threadIdx.x, ty = threadIdx.y; - constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t num_vecs_per_head = head_dim / num_elems_per_128b(); #pragma unroll diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index 0bdf88e0a56..cb51f35fe4b 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -129,7 +129,7 @@ __global__ void multi_query_append_attention_kernel( #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaGridDependencySynchronize(); #endif - load_q_global_smem( + load_q_global_smem( q_base_ptr, &qo_smem, q_base_seq_id_this_block, @@ -140,8 +140,7 @@ __global__ void multi_query_append_attention_kernel( wait_group<0>(); __syncthreads(); - q_smem_inplace_multiply_sm_scale(&qo_smem, - scale); + q_smem_inplace_multiply_sm_scale(&qo_smem, scale); smem_t k_smem(smem + num_rows_per_block * HEAD_DIM * sizeof(T)), v_smem(smem + (num_rows_per_block + BLOCK_SIZE) * HEAD_DIM * sizeof(T)); diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh index 8d4111a0a29..ee1821f13c1 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c4_impl.cuh @@ -190,7 +190,7 @@ __global__ void multi_query_append_attention_c4_kernel( uint32_t q_smem_offset_r = smem_t::get_permuted_offset( wid * num_frags_x * 16 + tid % 16, tid / 16); - load_q_global_smem( + load_q_global_smem( q_base_ptr, &qo_smem, q_base_seq_id_this_block, @@ -201,8 +201,7 @@ __global__ void multi_query_append_attention_c4_kernel( wait_group<0>(); __syncthreads(); - q_smem_inplace_multiply_sm_scale(&qo_smem, - scale); + q_smem_inplace_multiply_sm_scale(&qo_smem, scale); T cache_k_scale_frag[num_frags_y][4]; T cache_k_zp_frag[num_frags_y][4]; diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh index 611b5d66435..ba103a6e2f4 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c8_impl.cuh @@ -198,7 +198,7 @@ __global__ void multi_query_append_attention_c8_kernel( uint32_t q_smem_offset_r = smem_t::get_permuted_offset( wid * num_frags_x * 16 + tid % 16, tid / 16); // 16 * 16 - load_q_global_smem( + load_q_global_smem( q_base_ptr, &qo_smem, q_base_seq_id_this_block, @@ -209,8 +209,7 @@ __global__ void multi_query_append_attention_c8_kernel( wait_group<0>(); __syncthreads(); - q_smem_inplace_multiply_sm_scale(&qo_smem, - scale); + q_smem_inplace_multiply_sm_scale(&qo_smem, scale); smem_t k_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T)), v_smem(smem + NUM_WARPS * num_frags_x * 16 * HEAD_DIM * sizeof(T) + num_frags_z * 16 * HEAD_DIM * sizeof(CacheT)); From ec9e650a06873c305bdf8225f96f34ec786d6e07 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Sat, 6 Jun 2026 15:50:03 +0800 Subject: [PATCH 28/29] support only_do_attn --- custom_ops/gpu_ops/append_attn/template_config.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/custom_ops/gpu_ops/append_attn/template_config.json b/custom_ops/gpu_ops/append_attn/template_config.json index b2590586206..b1932536859 100644 --- a/custom_ops/gpu_ops/append_attn/template_config.json +++ b/custom_ops/gpu_ops/append_attn/template_config.json @@ -90,7 +90,7 @@ ], "dispatch_params": { "GROUP_SIZE": [1, 2, 4, 5, 6, 7, 8, 12, 14, 16, 24], - "HEAD_DIM": [64,128], + "HEAD_DIM": [64,128,192], "BLOCK_SIZE": [64], "CAUSAL": [0, 1], "BLOCK_SHAPE_Q": [16, 32, 64, 128], From f49af7a6d53470244be5ce6a0cc95aeedf512f56 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Sat, 6 Jun 2026 18:05:15 +0800 Subject: [PATCH 29/29] support only_do_attn --- .../gpu_ops/append_attn/multiquery_attention_c16_impl.cuh | 3 --- 1 file changed, 3 deletions(-) diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index cb51f35fe4b..c8399a0e1f7 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -389,7 +389,6 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const float quant_min_bound, const float in_scale, const uint32_t chunk_size, - const int num_blocks_x_cpu, T *__restrict__ tmp_workspace, // split kv [token_num, num_chunks, // num_heads, head_dim] float *__restrict__ tmp_m, // [token_num, num_chunks, num_heads] @@ -1079,7 +1078,6 @@ void MultiQueryAppendAttention( quant_min_bound, in_scale, chunk_size, - num_blocks_x_cpu, nullptr, nullptr, nullptr, @@ -1136,7 +1134,6 @@ void MultiQueryAppendAttention( quant_min_bound, in_scale, chunk_size, - num_blocks_x_cpu, reinterpret_cast(tmp_workspace->ptr()), static_cast(tmp_m->ptr()), static_cast(tmp_d->ptr()),