Skip to content

Commit 564199b

Browse files
feat: update exllamav2 kernels (#1370)
Co-authored-by: Nicolas Patry <[email protected]>
1 parent 987c959 commit 564199b

File tree

17 files changed

+525
-255
lines changed

17 files changed

+525
-255
lines changed

server/exllamav2_kernels/exllamav2_kernels/config.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define _config_h
33

44
#define MAX_Q_GEMM_ROWS 50
5+
#define MAX_Q_GEMM_WEIGHTS 4 // must be <= MAX_Q_GEMM_ROWS
56

67
#define QMODE_2BIT 1
78
#define QMODE_3BIT 1
@@ -10,4 +11,5 @@
1011
#define QMODE_6BIT 0
1112
#define QMODE_8BIT 0
1213

14+
1315
#endif

server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cu

Lines changed: 53 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,19 @@
1010
#include "quant/qdq_6.cuh"
1111
#include "quant/qdq_8.cuh"
1212

13-
#define BLOCK_KN_SIZE 128
14-
#define BLOCK_M_SIZE_MAX 8
15-
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
13+
#define GPTQ_BLOCK_KN_SIZE 128
14+
#define GPTQ_BLOCK_M_SIZE_MAX 8
15+
#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32)
16+
17+
#define EXL2_BLOCK_KN_SIZE 64
18+
#define EXL2_BLOCK_M_SIZE_MAX 8
19+
#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
20+
1621
#define CLEAR_N_SIZE 256
1722

1823
#include "q_gemm_kernel.cuh"
1924
#include "q_gemm_kernel_gptq.cuh"
2025

21-
#include "compat_gemm.cuh"
22-
2326
void gemm_half_q_half_cuda_part
2427
(
2528
const half* a,
@@ -29,20 +32,23 @@ void gemm_half_q_half_cuda_part
2932
int size_n,
3033
int size_k,
3134
int m_count,
32-
bool clear
35+
bool clear,
36+
const half* r_weights,
37+
int r_weights_stride,
38+
bool mul_r_weights
3339
)
3440
{
3541
if (!b->is_gptq)
3642
{
3743
dim3 blockDim, gridDim;
38-
blockDim.x = BLOCK_KN_SIZE;
44+
blockDim.x = EXL2_BLOCK_KN_SIZE;
3945
blockDim.y = 1;
4046
blockDim.z = 1;
41-
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
47+
gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
4248
gridDim.y = DIVIDE(size_m, m_count);
43-
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
49+
gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE);
4450

45-
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
51+
fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);
4652

4753
kernel<<<gridDim, blockDim>>>
4854
(
@@ -55,32 +61,35 @@ void gemm_half_q_half_cuda_part
5561
size_n,
5662
size_k,
5763
b->groups,
58-
b->groupsize,
64+
b->cuda_q_group_map,
5965
b->cuda_q_perm,
6066
b->rows_8,
6167
b->rows_6,
6268
b->rows_5,
6369
b->rows_4,
6470
b->rows_3,
6571
b->rows_2,
66-
clear
72+
clear,
73+
r_weights,
74+
r_weights_stride
6775
);
6876
}
6977
else
7078
{
7179
dim3 blockDim, gridDim;
72-
blockDim.x = BLOCK_KN_SIZE;
80+
blockDim.x = GPTQ_BLOCK_KN_SIZE;
7381
blockDim.y = 1;
7482
blockDim.z = 1;
75-
gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
83+
gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
7684
gridDim.y = DIVIDE(size_m, m_count);
77-
gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
85+
gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);
7886

79-
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
87+
fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights);
8088

81-
// DBGX((uint64_t) b->cuda_q_perm);
82-
// DBGI(b->rows_4);
83-
// DBGI(b->height);
89+
// DBGX((uint64_t) r_weights);
90+
// if (r_weights)
91+
// print_global_mem(r_weights, 1, 1, 1);
92+
// DBGI(r_weights_stride);
8493

8594
kernel<<<gridDim, blockDim>>>
8695
(
@@ -93,10 +102,12 @@ void gemm_half_q_half_cuda_part
93102
size_n,
94103
size_k,
95104
b->groups,
96-
b->groupsize,
105+
b->gptq_groupsize,
97106
b->cuda_q_perm,
98107
b->rows_4,
99-
clear
108+
clear,
109+
r_weights,
110+
r_weights_stride
100111
);
101112
}
102113
}
@@ -112,13 +123,14 @@ void gemm_half_q_half_cuda
112123
int size_k,
113124
bool clear,
114125
half* temp_dq,
115-
bool force_cuda
126+
bool force_cuda,
127+
const half* r_weights,
128+
const int r_weights_stride,
129+
bool mul_r_weights
116130
)
117131
{
118132
if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
119133
{
120-
//printf("cublas\n");
121-
122134
// Reconstruct FP16 matrix, then cuBLAS
123135

124136
if (!temp_dq) temp_dq = b->temp_dq;
@@ -139,12 +151,12 @@ void gemm_half_q_half_cuda
139151
//const float alpha = 1.0f;
140152
//const float beta = clear ? 0.0f : 1.0f;
141153
//cublasSgemmEx(cublas_handle,
142-
// CUBLAS_OP_N,
143-
// CUBLAS_OP_N,
144-
// size_n, size_m, size_k,
145-
// &alpha, temp_dq, CUDA_R_16F, size_n,
146-
// a, CUDA_R_16F, size_k,
147-
// &beta, c, CUDA_R_16F, size_n);
154+
// CUBLAS_OP_N,
155+
// CUBLAS_OP_N,
156+
// size_n, size_m, size_k,
157+
// &alpha, temp_dq, CUDA_R_16F, size_n,
158+
// a, CUDA_R_16F, size_k,
159+
// &beta, c, CUDA_R_16F, size_n);
148160

149161
//const float alpha = 1.0f;
150162
//const float beta = clear ? 0.0f : 1.0f;
@@ -158,24 +170,21 @@ void gemm_half_q_half_cuda
158170
}
159171
else
160172
{
161-
//printf("cuda\n");
162-
163173
// Quantized matmul
164174

165-
//if (clear) clear_tensor_cuda(c, size_m, size_n);
166-
167-
int max_chunks = size_m / BLOCK_M_SIZE_MAX;
168-
int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
175+
int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;
176+
int max_chunks = size_m / block_m_size_max;
177+
int last_chunk = max_chunks * block_m_size_max;
169178
int last_chunk_size = size_m - last_chunk;
170179

171180
if (max_chunks)
172181
{
173-
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
182+
gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights);
174183
}
175184

176185
if (last_chunk_size)
177186
{
178-
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
187+
gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights);
179188
}
180189
}
181190
}
@@ -201,11 +210,10 @@ void clear_tensor_cuda
201210
int size_n
202211
)
203212
{
204-
return;
205-
dim3 blockDim, gridDim;
206-
blockDim.x = CLEAR_N_SIZE;
207-
blockDim.y = 1;
208-
gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
209-
gridDim.y = size_m;
210-
clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
213+
// dim3 blockDim, gridDim;
214+
// blockDim.x = CLEAR_N_SIZE;
215+
// blockDim.y = 1;
216+
// gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
217+
// gridDim.y = size_m;
218+
// clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
211219
}

server/exllamav2_kernels/exllamav2_kernels/cuda/q_gemm.cuh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ void gemm_half_q_half_cuda
2020
int size_k,
2121
bool clear = false,
2222
half* reconstruct = NULL,
23-
bool force_cuda = false
23+
bool force_cuda = false,
24+
const half* r_weights = NULL,
25+
const int r_weights_stride = 0,
26+
bool mul_r_weights = false
2427
);
2528

2629
void clear_tensor_cuda

0 commit comments

Comments
 (0)