[AMD] Add support for GGUF quantization on ROCm (#10254)
kliuae authored Nov 23, 2024
1 parent 02a43f8 commit 7c25fe4
Showing 11 changed files with 234 additions and 211 deletions.
1 change: 0 additions & 1 deletion .buildkite/run-amd-test.sh
@@ -85,7 +85,6 @@ if [[ $commands == *" kernels "* ]]; then
--ignore=kernels/test_encoder_decoder_attn.py \
--ignore=kernels/test_flash_attn.py \
--ignore=kernels/test_flashinfer.py \
- --ignore=kernels/test_gguf.py \
--ignore=kernels/test_int8_quant.py \
--ignore=kernels/test_machete_gemm.py \
--ignore=kernels/test_mamba_ssm.py \
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -196,6 +196,7 @@ set(VLLM_EXT_SRC
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/common.cu"
+ "csrc/quantization/gguf/gguf_kernel.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/prepare_inputs/advance_step.cu"
"csrc/torch_bindings.cpp")
@@ -237,7 +238,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
- "csrc/quantization/gguf/gguf_kernel.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu")
2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -128,6 +128,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel,
int64_t thx, int64_t thy);

torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
+ #endif

torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n);
@@ -138,6 +139,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type,
int64_t row);

+ #ifndef USE_ROCM
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability);

void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
17 changes: 16 additions & 1 deletion csrc/quantization/gguf/ggml-common.h
@@ -1,7 +1,7 @@
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
- #define WARP_SIZE 32
+ #define WARP_SIZE_GGUF 32
#define K_SCALE_SIZE 12
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -1112,4 +1112,19 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#endif
return c;
}

+ static __device__ __forceinline__ uint32_t __vcmpeq4(const uint32_t a, const uint32_t b) {
+ uint32_t neq = a^b;
+ return !(neq & 0xff000000) * 0xff000000 |
+ !(neq & 0x00ff0000) * 0x00ff0000 |
+ !(neq & 0x0000ff00) * 0x0000ff00 |
+ !(neq & 0x000000ff) * 0x000000ff;
+ }
+
+ static __device__ __forceinline__ uint32_t __vsub4(const uint32_t a, const uint32_t b) {
+ return (static_cast<uint8_t>(((a & 0xff000000) >> 24) - ((b & 0xff000000) >> 24)) << 24) +
+ (static_cast<uint8_t>(((a & 0x00ff0000) >> 16) - ((b & 0x00ff0000) >> 16)) << 16) +
+ (static_cast<uint8_t>(((a & 0x0000ff00) >> 8) - ((b & 0x0000ff00) >> 8)) << 8) +
+ (static_cast<uint8_t>(((a & 0x000000ff) >> 0) - ((b & 0x000000ff) >> 0)) << 0);
+ }
#endif // defined(USE_ROCM)
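
The two helpers above give the ROCm build byte-wise fallbacks for CUDA's __vcmpeq4 and __vsub4 SIMD-in-register intrinsics, which the GGUF dot-product code relies on. To make the intended per-byte semantics concrete, here is a hypothetical host-side sketch (not part of the commit; the function names and test values are invented) that mirrors the same bit manipulation:

// Hypothetical host-side sketch, not part of this commit: mirrors the
// __vcmpeq4 / __vsub4 fallbacks above to make their per-byte semantics concrete.
#include <cstdint>
#include <cstdio>

// Same expression as the device-side __vcmpeq4 emulation: 0xff per equal byte.
static uint32_t vcmpeq4_emu(uint32_t a, uint32_t b) {
  uint32_t neq = a ^ b;
  return !(neq & 0xff000000) * 0xff000000 |
         !(neq & 0x00ff0000) * 0x00ff0000 |
         !(neq & 0x0000ff00) * 0x0000ff00 |
         !(neq & 0x000000ff) * 0x000000ff;
}

// Scalar reference for __vsub4: four independent byte subtractions with wraparound.
static uint32_t vsub4_ref(uint32_t a, uint32_t b) {
  uint32_t r = 0;
  for (int i = 0; i < 4; ++i) {
    const uint8_t ai = (a >> (8 * i)) & 0xff;
    const uint8_t bi = (b >> (8 * i)) & 0xff;
    r |= uint32_t(uint8_t(ai - bi)) << (8 * i);
  }
  return r;
}

int main() {
  const uint32_t a = 0x05f01122, b = 0x04f01021;
  // Byte lanes are compared/subtracted independently; no borrow crosses lanes.
  std::printf("vcmpeq4(a, b) = 0x%08x\n", vcmpeq4_emu(a, b));  // expect 0x00ff0000
  std::printf("vsub4(a, b)   = 0x%08x\n", vsub4_ref(a, b));    // expect 0x01000101
  return 0;
}
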
6 changes: 4 additions & 2 deletions csrc/quantization/gguf/gguf_kernel.cu
@@ -4,6 +4,8 @@
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

+ #include "cuda_compat.h"
+
#include "ggml-common.h"
#include "vecdotq.cuh"
#include "dequantize.cuh"
@@ -32,8 +34,8 @@ static __global__ void quantize_q8_1(const half* __restrict__ x,

#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
- amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
- sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
+ amax = fmaxf(amax, VLLM_SHFL_XOR_SYNC_WIDTH(amax, mask, 32));
+ sum += VLLM_SHFL_XOR_SYNC_WIDTH(sum, mask, 32);
}

const float d = amax / 127;
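
The hunk above swaps the raw __shfl_xor_sync calls in quantize_q8_1's warp reduction for VLLM_SHFL_XOR_SYNC_WIDTH from the newly included cuda_compat.h, presumably so the same code also builds on ROCm, where __shfl_xor has no _sync variant and wavefronts are 64 lanes wide; the explicit width of 32 keeps the reduction operating on 32-lane groups on both platforms. As a rough illustration only (the real macro lives in vLLM's csrc/cuda_compat.h and may be defined differently), such a compatibility macro typically looks like this:

// Illustrative sketch of a shuffle-compat macro; not the actual vLLM definition.
#ifndef USE_ROCM
  // CUDA: synchronized shuffle with a full-warp mask and an explicit width.
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor_sync(0xffffffffu, (var), (lane_mask), (width))
#else
  // ROCm/HIP: no *_sync variant; __shfl_xor takes the width as its third argument.
  #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \
    __shfl_xor((var), (lane_mask), (width))
#endif
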
70 changes: 35 additions & 35 deletions csrc/quantization/gguf/mmq.cuh
@@ -10,7 +10,7 @@ static __device__ __forceinline__ void mul_mat_q(

const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_col_y = nrows_y / QK8_1;
- const int blocks_per_warp = WARP_SIZE / qi;
+ const int blocks_per_warp = WARP_SIZE_GGUF / qi;

const int & ncols_dst = ncols_y;

@@ -27,10 +27,10 @@

allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);

- __shared__ int tile_y_qs[mmq_x * WARP_SIZE];
- __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1];
+ __shared__ int tile_y_qs[mmq_x * WARP_SIZE_GGUF];
+ __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE_GGUF/QI8_1];

- float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};
+ float sum[mmq_y/WARP_SIZE_GGUF][mmq_x/nwarps] = {{0.0f}};

for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {

@@ -39,26 +39,26 @@

#pragma unroll
for (int ir = 0; ir < qr; ++ir) {
- const int kqs = ir*WARP_SIZE + threadIdx.x;
+ const int kqs = ir*WARP_SIZE_GGUF + threadIdx.x;
const int kbxd = kqs / QI8_1;

#pragma unroll
for (int i = 0; i < mmq_x; i += nwarps) {
const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
- const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE;
+ const int index_y = (threadIdx.y + i) * WARP_SIZE_GGUF + kqs % WARP_SIZE_GGUF;
tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1);
}

#pragma unroll
for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) {
- const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x;
- const int kby = threadIdx.x % (WARP_SIZE/QI8_1);
+ const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE_GGUF/QI8_1)) % mmq_x;
+ const int kby = threadIdx.x % (WARP_SIZE_GGUF/QI8_1);
const int col_y_eff = min(col_y_0 + ids, ncols_y-1);

// if the sum is not needed it's faster to transform the scale to f32 ahead of time
- const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds;
- half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby];
+ const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE_GGUF/QI8_1) + kby].ds;
+ half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE_GGUF/QI8_1) + kby];
if (need_sum) {
*dsi_dst = *dsi_src;
} else {
@@ -70,12 +70,12 @@ static __device__ __forceinline__ void mul_mat_q(
__syncthreads();

// #pragma unroll // unrolling this loop causes too much register pressure
- for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) {
+ for (int k = ir*WARP_SIZE_GGUF/qr; k < (ir+1)*WARP_SIZE_GGUF/qr; k += vdr) {
#pragma unroll
for (int j = 0; j < mmq_x; j += nwarps) {
#pragma unroll
- for (int i = 0; i < mmq_y; i += WARP_SIZE) {
- sum[i/WARP_SIZE][j/nwarps] += vec_dot(
+ for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
+ sum[i/WARP_SIZE_GGUF][j/nwarps] += vec_dot(
tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
threadIdx.x + i, threadIdx.y + j, k);
}
@@ -93,12 +93,12 @@ static __device__ __forceinline__ void mul_mat_q(
}

#pragma unroll
- for (int i = 0; i < mmq_y; i += WARP_SIZE) {
+ for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) {
const int row_dst = row_dst_0 + threadIdx.x + i;
if (row_dst >= nrows_dst) {
continue;
}
- dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE][j/nwarps]);
+ dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE_GGUF][j/nwarps]);
}
}
}
@@ -115,7 +115,7 @@ static __device__ __forceinline__ void mul_mat_q(

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
- __launch_bounds__(WARP_SIZE*NWARPS_Q4_0, 2)
+ __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_0, 2)
#endif
mul_mat_q4_0(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@@ -140,7 +140,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
+ const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@@ -165,7 +165,7 @@

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
- __launch_bounds__(WARP_SIZE*NWARPS_Q4_1, 2)
+ __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_1, 2)
#endif
mul_mat_q4_1(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@@ -190,7 +190,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
+ const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@@ -215,7 +215,7 @@

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
- __launch_bounds__(WARP_SIZE*NWARPS_Q5_0, 2)
+ __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_0, 2)
#endif
mul_mat_q5_0(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@@ -240,7 +240,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
+ const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@@ -265,7 +265,7 @@

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
- __launch_bounds__(WARP_SIZE*NWARPS_Q5_1, 2)
+ __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_1, 2)
#endif
mul_mat_q5_1(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@@ -289,7 +289,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
+ const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@@ -314,7 +314,7 @@

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
- __launch_bounds__(WARP_SIZE*NWARPS_Q8_0, 2)
+ __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q8_0, 2)
#endif
mul_mat_q8_0(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@@ -338,7 +338,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
+ const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@@ -363,7 +363,7 @@

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
- __launch_bounds__(WARP_SIZE*NWARPS_Q2_K, 2)
+ __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q2_K, 2)
#endif
mul_mat_q2_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@@ -387,7 +387,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
+ const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@@ -412,7 +412,7 @@

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
- __launch_bounds__(WARP_SIZE*NWARPS_Q3_K, 2)
+ __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q3_K, 2)
#endif
mul_mat_q3_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@@ -438,7 +438,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
+ const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@@ -463,7 +463,7 @@

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
- __launch_bounds__(WARP_SIZE*NWARPS_Q4_K, 2)
+ __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_K, 2)
#endif
mul_mat_q4_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@@ -487,7 +487,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
+ const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@@ -512,7 +512,7 @@

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
- __launch_bounds__(WARP_SIZE*NWARPS_Q5_K, 2)
+ __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_K, 2)
#endif
mul_mat_q5_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@@ -537,7 +537,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
+ const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
@@ -562,7 +562,7 @@

template <bool need_check> static __global__ void
#if defined(USE_ROCM)
- __launch_bounds__(WARP_SIZE*NWARPS_Q6_K, 2)
+ __launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q6_K, 2)
#endif
mul_mat_q6_K(
const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst,
@@ -586,7 +586,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;
const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
- const dim3 block_dims(WARP_SIZE, nwarps, 1);
+ const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1);

if (nrows_x % mmq_y == 0) {
const bool need_check = false;
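
Every change in this file is the same mechanical rename of the GGUF kernels' tiling constant from WARP_SIZE to WARP_SIZE_GGUF. A likely reason (an assumption, not stated in the diff) is that gguf_kernel.cu now includes cuda_compat.h, which brings its own WARP_SIZE definition (64 on ROCm), so the old "#define WARP_SIZE 32" in ggml-common.h would clash with it; the suffixed name pins the GGUF tiles to 32 lanes regardless of the hardware wavefront width. The sketch below simply replays the launch-geometry arithmetic used by the ggml_mul_mat_*_q8_1_cuda launchers above, with invented mmq_x/mmq_y/nwarps values (the real MMQ_X_*/MMQ_Y_*/NWARPS_* constants are set per quantization type elsewhere in mmq.cuh):

// Standalone sketch of the launcher arithmetic; tile parameters are assumed, not vLLM's.
#include <cstdio>

#define WARP_SIZE_GGUF 32

int main() {
  const int mmq_x = 64, mmq_y = 64, nwarps = 4;  // hypothetical tile shape
  const int nrows_x = 4096, ncols_y = 512;       // hypothetical problem size

  // One thread block computes an mmq_y x mmq_x tile of the output.
  const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y;  // tiles along rows of x
  const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x;  // tiles along columns of y
  // Each block runs nwarps groups of WARP_SIZE_GGUF threads, i.e. the kernels keep a
  // fixed 32-wide tile even where the hardware wavefront is 64 lanes.
  std::printf("grid = (%d, %d, 1), block = (%d, %d, 1)\n",
              block_num_x, block_num_y, WARP_SIZE_GGUF, nwarps);
  return 0;
}
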
4 changes: 2 additions & 2 deletions csrc/quantization/gguf/mmvq.cuh
@@ -28,8 +28,8 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *

// sum up partial sums and write back result
#pragma unroll
- for (int mask = 16; mask > 0; mask >>= 1) {
- tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+ for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
+ tmp += VLLM_SHFL_XOR_SYNC(tmp, mask);
}

if (threadIdx.x == 0) {
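
Unlike the width-32 reduction in quantize_q8_1, the per-row reduction in mul_mat_vec_q now starts its XOR stride at WARP_SIZE/2 and calls VLLM_SHFL_XOR_SYNC without a width, so it reduces across the full hardware warp: 32 lanes on CUDA and 64 on ROCm. A self-contained sketch of that warp-size-agnostic butterfly pattern follows; the macro and function names here are placeholders, not vLLM's:

// Illustrative warp-size-agnostic butterfly reduction; compile with nvcc, or with
// hipcc and -DUSE_ROCM. Names are placeholders, not the vLLM macros.
#ifdef USE_ROCM
  #define WARP_SIZE 64
  #define SHFL_XOR(var, lane_mask) __shfl_xor((var), (lane_mask))
#else
  #define WARP_SIZE 32
  #define SHFL_XOR(var, lane_mask) __shfl_xor_sync(0xffffffffu, (var), (lane_mask))
#endif

__device__ float warp_reduce_sum(float v) {
  // After log2(WARP_SIZE) butterfly steps every lane holds the full warp sum.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
    v += SHFL_XOR(v, mask);
  }
  return v;
}

__global__ void reduce_one_warp(const float* __restrict__ in, float* __restrict__ out) {
  float v = in[threadIdx.x];  // intended to be launched with exactly one warp
  v = warp_reduce_sum(v);
  if (threadIdx.x == 0) {
    *out = v;                 // lane 0 writes the reduced value, as in mul_mat_vec_q
  }
}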