Custom PA Partition size 256 to improve performance (#238)
* Add an option to adjust the partition size

* Change the custom paged attention (CPA) partition size to 256 in the ROCm attention backend

* Support context lengths up to 128K with partition size 256 (see the capacity sketch below)
sanyalington authored Oct 22, 2024
1 parent e0b6bb4 commit 1eefd1e
Showing 5 changed files with 82 additions and 36 deletions.
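
The arithmetic behind these bullets, for reference: the main kernel launches one thread block per (sequence, partition, KV head), so halving the partition size from 512 to 256 doubles the number of partitions covering a given context and roughly doubles the available parallelism. To keep the 128K context ceiling, the reduction kernel gains cases 5 through 8 and can now fold up to 8 * 64 = 512 partitions per sequence; the previous limit was assert(max_num_partitions <= 256) at partition size 512, i.e. the same 128K. The Python sketch below only restates this arithmetic; the constant and helper names are illustrative and not part of the change.

WARP_SIZE = 64        # CDNA wavefront width; each reduction loop covers one warp's worth of partitions
MAX_NPAR_LOOPS = 8    # largest LAUNCH_CUSTOM_REDUCTION case after this commit

def max_context_len(partition_size: int, max_npar_loops: int = MAX_NPAR_LOOPS) -> int:
    # Largest context length the reduction kernel can cover.
    return max_npar_loops * WARP_SIZE * partition_size

def num_thread_blocks(num_seqs: int, num_kv_heads: int,
                      context_len: int, partition_size: int) -> int:
    # Grid of the main kernel: dim3 grid(num_seqs, max_num_partitions, num_kv_heads).
    max_num_partitions = -(-context_len // partition_size)  # ceiling division
    return num_seqs * max_num_partitions * num_kv_heads

print(max_context_len(256))                 # 8 * 64 * 256 = 131072 (128K)
print(num_thread_blocks(1, 8, 32768, 512))  # 512 blocks at the old partition size
print(num_thread_blocks(1, 8, 32768, 256))  # 1024 blocks -> about twice the parallelism
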
89 changes: 67 additions & 22 deletions csrc/rocm/attention.cu
@@ -316,12 +316,14 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
int vphysical_blocks[VBLOCKS];

const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE;
if constexpr (GQA_RATIO < 12) {
#pragma unroll
for (int b = 0; b < VBLOCKS; b++) {
const int vblock_idx = warp_start_block_idx + b;
const int vblock_idx_ctx =
(vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
vphysical_blocks[b] = block_table[vblock_idx_ctx];
for (int b = 0; b < VBLOCKS; b++) {
const int vblock_idx = warp_start_block_idx + b;
const int vblock_idx_ctx =
(vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
vphysical_blocks[b] = block_table[vblock_idx_ctx];
}
}

// each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
@@ -379,6 +381,17 @@ __global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
}
}

// fetch vphysical block numbers up front
if constexpr (GQA_RATIO >= 12) {
#pragma unroll
for (int b = 0; b < VBLOCKS; b++) {
const int vblock_idx = warp_start_block_idx + b;
const int vblock_idx_ctx =
(vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
vphysical_blocks[b] = block_table[vblock_idx_ctx];
}
}

const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
@@ -1023,8 +1036,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
context_lens_ptr, max_num_partitions, fp8_out_scale_ptr);

template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT,
int PARTITION_SIZE = 512>
int BLOCK_SIZE, int HEAD_SIZE, typename OUTT, int PARTITION_SIZE>
void paged_attention_custom_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
@@ -1069,7 +1081,6 @@ void paged_attention_custom_launcher(
const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE);
assert(max_num_partitions <= 256);

constexpr int NTHR = PARTITION_SIZE;
dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
@@ -1129,9 +1140,6 @@ void paged_attention_custom_launcher(
TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio);
break;
}
// dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG);
// dim3 block2(1024);
// LAUNCH_CUSTOM_ATTENTION2;

// reduction kernel is only required if max_context_len > partition size,
// otherwise main kernel writes directly to final output
@@ -1142,6 +1150,7 @@ void paged_attention_custom_launcher(
dim3 reduce_grid(num_heads, num_seqs);
dim3 reduce_block(head_size);
const int npar_loops = DIVIDE_ROUND_UP(max_num_partitions, WARP_SIZE);
// support up to 8*64*256 = 128K context length
switch (npar_loops) {
case 1:
LAUNCH_CUSTOM_REDUCTION(1);
@@ -1155,27 +1164,63 @@ void paged_attention_custom_launcher(
case 4:
LAUNCH_CUSTOM_REDUCTION(4);
break;
case 5:
LAUNCH_CUSTOM_REDUCTION(5);
break;
case 6:
LAUNCH_CUSTOM_REDUCTION(6);
break;
case 7:
LAUNCH_CUSTOM_REDUCTION(7);
break;
case 8:
LAUNCH_CUSTOM_REDUCTION(8);
break;
default:
TORCH_CHECK(false, "Unsupported npar_loops: ", npar_loops);
break;
}
}
}

#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT) \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
OUTT>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, max_context_len, \
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
PSIZE) \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, \
PSIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, max_context_len, \
alibi_slopes, k_scale, v_scale, fp8_out_scale);

#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
if (fp8_out_scale) { \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, uint8_t); \
} else { \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \
#define CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
OUTT) \
switch (partition_size) { \
case 256: \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 256); \
break; \
case 512: \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, OUTT, 512); \
break; \
default: \
TORCH_CHECK(false, "Unsupported partition size: ", partition_size); \
break; \
}

#if defined(__HIPCC__) && defined(__gfx90a__)
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
if (fp8_out_scale) { \
TORCH_CHECK(false, "fp8 out scale unsupported for gfx90a"); \
} else { \
CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \
}
#else
#define CALL_CUSTOM_LAUNCHER_OUT(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
if (fp8_out_scale) { \
CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, \
uint8_t); \
} else { \
CALL_CUSTOM_LAUNCHER_PSIZE(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE, T); \
}
#endif
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
switch (block_size) { \
case 16: \
@@ -1219,7 +1264,7 @@ void paged_attention(
int64_t block_size, int64_t max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const c10::optional<torch::Tensor>& fp8_out_scale) {
const c10::optional<torch::Tensor>& fp8_out_scale, int64_t partition_size) {
const int head_size = query.size(2);
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Half) {
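A hedged Python analogy for the new CALL_CUSTOM_LAUNCHER_PSIZE macro above: the kernel needs PARTITION_SIZE as a compile-time template parameter, so the runtime partition_size argument is mapped onto one of a fixed set of pre-instantiated launchers, and anything else fails loudly (TORCH_CHECK in the C++, ValueError here). Every name in this sketch is illustrative; none of it is vLLM API.

from typing import Callable, Dict

def _make_launcher(psize: int) -> Callable[..., None]:
    # Stand-in for a launcher specialized at compile time for one partition size.
    def launch(*args, **kwargs) -> None:
        print(f"custom paged attention launched with PARTITION_SIZE={psize}")
    return launch

_SUPPORTED_PARTITION_SIZES: Dict[int, Callable[..., None]] = {
    256: _make_launcher(256),
    512: _make_launcher(512),
}

def paged_attention_dispatch(partition_size: int, *args, **kwargs) -> None:
    try:
        launcher = _SUPPORTED_PARTITION_SIZES[partition_size]
    except KeyError:
        # mirrors TORCH_CHECK(false, "Unsupported partition size: ", partition_size)
        raise ValueError(f"Unsupported partition size: {partition_size}")
    launcher(*args, **kwargs)

paged_attention_dispatch(256)  # ok; 128 would raise ValueError
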
20 changes: 9 additions & 11 deletions csrc/rocm/ops.h
@@ -11,14 +11,12 @@ void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c,
const int64_t N_in, const int64_t CuCount);

void paged_attention(torch::Tensor& out, torch::Tensor& exp_sums,
torch::Tensor& max_logits, torch::Tensor& tmp_out,
torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads,
double scale, torch::Tensor& block_tables,
torch::Tensor& context_lens, int64_t block_size,
int64_t max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale,
double v_scale,
const c10::optional<torch::Tensor>& fp8_out_scale);
void paged_attention(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
int64_t block_size, int64_t max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale,
const c10::optional<torch::Tensor>& fp8_out_scale, int64_t partition_size);
3 changes: 2 additions & 1 deletion csrc/rocm/torch_bindings.cpp
@@ -36,7 +36,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, rocm_ops) {
" Tensor? alibi_slopes,"
" str kv_cache_dtype,"
" float k_scale, float v_scale,"
" Tensor? fp8_out_scale) -> ()");
" Tensor? fp8_out_scale,"
" int partition_size) -> ()");
rocm_ops.impl("paged_attention", torch::kCUDA, &paged_attention);
rocm_ops.def(
"wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in,"
3 changes: 2 additions & 1 deletion vllm/_custom_ops.py
@@ -180,13 +180,14 @@ def paged_attention_rocm(
k_scale: float,
v_scale: float,
fp8_out_scale: Optional[torch.Tensor],
partition_size: int,
) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads,
scale, block_tables, seq_lens,
block_size, max_seq_len, alibi_slopes,
kv_cache_dtype, k_scale, v_scale,
fp8_out_scale)
fp8_out_scale, partition_size)


# pos encoding ops
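Since partition_size is now a required trailing argument of the wrapper above, call sites also need to size the partition-dependent scratch tensors consistently with it. A minimal sketch, assuming the (num_seqs, num_heads, max_num_partitions, head_size) layout for tmp_out and (num_seqs, num_heads, max_num_partitions) for exp_sums/max_logits that the ROCm backend allocates; treat the helper name and shapes as assumptions, not documented API.

import torch

def alloc_paged_attention_workspace(num_seqs: int, num_heads: int, head_size: int,
                                    max_seq_len: int, partition_size: int,
                                    dtype: torch.dtype = torch.float16,
                                    device: str = "cuda"):
    # One partition per partition_size tokens of context, rounded up.
    max_num_partitions = (max_seq_len + partition_size - 1) // partition_size
    tmp_out = torch.empty(num_seqs, num_heads, max_num_partitions, head_size,
                          dtype=dtype, device=device)
    exp_sums = torch.empty(num_seqs, num_heads, max_num_partitions,
                           dtype=torch.float32, device=device)
    max_logits = torch.empty_like(exp_sums)
    return tmp_out, exp_sums, max_logits

# With partition_size=256, a 128K context gives 512 partitions per sequence,
# exactly the 8 * 64 ceiling the extended reduction kernel handles.
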
3 changes: 2 additions & 1 deletion vllm/attention/backends/rocm_flash_attn.py
@@ -20,7 +20,7 @@

logger = init_logger(__name__)

_PARTITION_SIZE_ROCM = 512
_PARTITION_SIZE_ROCM = 256
_ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName


@@ -798,6 +798,7 @@ def forward(
k_scale,
v_scale,
fp8_out_scale if cpa_fp8_out else None,
_PARTITION_SIZE_ROCM,
)
if cpa_fp8_out:
return out.view(num_seqs, num_heads * head_size)
