@@ -10,16 +10,19 @@
 #include "quant/qdq_6.cuh"
 #include "quant/qdq_8.cuh"

-#define BLOCK_KN_SIZE 128
-#define BLOCK_M_SIZE_MAX 8
-#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
+#define GPTQ_BLOCK_KN_SIZE 128
+#define GPTQ_BLOCK_M_SIZE_MAX 8
+#define GPTQ_MAX_GROUPS_IN_BLOCK (GPTQ_BLOCK_KN_SIZE / 32)
+
+#define EXL2_BLOCK_KN_SIZE 64
+#define EXL2_BLOCK_M_SIZE_MAX 8
+#define EXL2_MAX_GROUPS_IN_BLOCK (EXL2_BLOCK_KN_SIZE / 32)
+
 #define CLEAR_N_SIZE 256

 #include "q_gemm_kernel.cuh"
 #include "q_gemm_kernel_gptq.cuh"

-#include "compat_gemm.cuh"
-
 void gemm_half_q_half_cuda_part
 (
     const half* a,
@@ -29,20 +32,23 @@ void gemm_half_q_half_cuda_part
     int size_n,
     int size_k,
     int m_count,
-    bool clear
+    bool clear,
+    const half* r_weights,
+    int r_weights_stride,
+    bool mul_r_weights
 )
 {
     if (!b->is_gptq)
     {
         dim3 blockDim, gridDim;
-        blockDim.x = BLOCK_KN_SIZE;
+        blockDim.x = EXL2_BLOCK_KN_SIZE;
         blockDim.y = 1;
         blockDim.z = 1;
-        gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
+        gridDim.x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);
         gridDim.y = DIVIDE(size_m, m_count);
-        gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
+        gridDim.z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE);

-        fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(true, m_count);
+        fp_gemm_half_q_half_kernel kernel = pick_gemm_half_q_half_kernel(m_count, r_weights != NULL, mul_r_weights);

         kernel<<<gridDim, blockDim>>>
         (
@@ -55,32 +61,35 @@ void gemm_half_q_half_cuda_part
             size_n,
             size_k,
             b->groups,
-            b->groupsize,
+            b->cuda_q_group_map,
             b->cuda_q_perm,
             b->rows_8,
             b->rows_6,
             b->rows_5,
             b->rows_4,
             b->rows_3,
             b->rows_2,
-            clear
+            clear,
+            r_weights,
+            r_weights_stride
         );
     }
     else
     {
         dim3 blockDim, gridDim;
-        blockDim.x = BLOCK_KN_SIZE;
+        blockDim.x = GPTQ_BLOCK_KN_SIZE;
         blockDim.y = 1;
         blockDim.z = 1;
-        gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
+        gridDim.x = DIVIDE(size_n, GPTQ_BLOCK_KN_SIZE * 4);
         gridDim.y = DIVIDE(size_m, m_count);
-        gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);
+        gridDim.z = DIVIDE(size_k, GPTQ_BLOCK_KN_SIZE);

-        fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);
+        fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(m_count, r_weights != NULL, mul_r_weights);

-        // DBGX((uint64_t) b->cuda_q_perm);
-        // DBGI(b->rows_4);
-        // DBGI(b->height);
+        // DBGX((uint64_t) r_weights);
+        // if (r_weights)
+        //     print_global_mem(r_weights, 1, 1, 1);
+        // DBGI(r_weights_stride);

         kernel<<<gridDim, blockDim>>>
         (
@@ -93,10 +102,12 @@ void gemm_half_q_half_cuda_part
             size_n,
             size_k,
             b->groups,
-            b->groupsize,
+            b->gptq_groupsize,
             b->cuda_q_perm,
             b->rows_4,
-            clear
+            clear,
+            r_weights,
+            r_weights_stride
         );
     }
 }
@@ -112,13 +123,14 @@ void gemm_half_q_half_cuda
     int size_k,
     bool clear,
     half* temp_dq,
-    bool force_cuda
+    bool force_cuda,
+    const half* r_weights,
+    const int r_weights_stride,
+    bool mul_r_weights
 )
 {
     if (size_m > MAX_Q_GEMM_ROWS && !force_cuda)
     {
-        // printf("cublas\n");
-
         // Reconstruct FP16 matrix, then cuBLAS

         if (!temp_dq) temp_dq = b->temp_dq;
@@ -139,12 +151,12 @@ void gemm_half_q_half_cuda
         // const float alpha = 1.0f;
         // const float beta = clear ? 0.0f : 1.0f;
         // cublasSgemmEx(cublas_handle,
-        //               CUBLAS_OP_N,
-        //               CUBLAS_OP_N,
-        //               size_n, size_m, size_k,
-        //               &alpha, temp_dq, CUDA_R_16F, size_n,
-        //               a, CUDA_R_16F, size_k,
-        //               &beta, c, CUDA_R_16F, size_n);
+        //                CUBLAS_OP_N,
+        //                CUBLAS_OP_N,
+        //                size_n, size_m, size_k,
+        //                &alpha, temp_dq, CUDA_R_16F, size_n,
+        //                a, CUDA_R_16F, size_k,
+        //                &beta, c, CUDA_R_16F, size_n);

         // const float alpha = 1.0f;
         // const float beta = clear ? 0.0f : 1.0f;
@@ -158,24 +170,21 @@ void gemm_half_q_half_cuda
     }
     else
     {
-        // printf("cuda\n");
-
         // Quantized matmul

-        // if (clear) clear_tensor_cuda(c, size_m, size_n);
-
-        int max_chunks = size_m / BLOCK_M_SIZE_MAX;
-        int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
+        int block_m_size_max = b->is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;
+        int max_chunks = size_m / block_m_size_max;
+        int last_chunk = max_chunks * block_m_size_max;
         int last_chunk_size = size_m - last_chunk;

         if (max_chunks)
         {
-            gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX, clear);
+            gemm_half_q_half_cuda_part(a, b, c, last_chunk, size_n, size_k, block_m_size_max, clear, r_weights, r_weights_stride, mul_r_weights);
         }

         if (last_chunk_size)
         {
-            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear);
+            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b, c + last_chunk * size_n, last_chunk_size, size_n, size_k, last_chunk_size, clear, r_weights, r_weights_stride, mul_r_weights);
         }
     }
 }
@@ -201,11 +210,10 @@ void clear_tensor_cuda
     int size_n
 )
 {
-    return;
-    dim3 blockDim, gridDim;
-    blockDim.x = CLEAR_N_SIZE;
-    blockDim.y = 1;
-    gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
-    gridDim.y = size_m;
-    clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
+    // dim3 blockDim, gridDim;
+    // blockDim.x = CLEAR_N_SIZE;
+    // blockDim.y = 1;
+    // gridDim.x = DIVIDE(size_n / 8, CLEAR_N_SIZE);
+    // gridDim.y = size_m;
+    // clear_kernel<<<gridDim, blockDim>>>(c, size_m, size_n);
 }
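
A note on the dispatch change above: `gemm_half_q_half_cuda` now picks the M-chunk size per format (`GPTQ_BLOCK_M_SIZE_MAX` vs. `EXL2_BLOCK_M_SIZE_MAX`) instead of a single `BLOCK_M_SIZE_MAX`. The sketch below is a minimal, host-only rendering of that chunking arithmetic; the constants come from the diff, while the format flag and row count are hypothetical values chosen for illustration.

```cpp
#include <cstdio>

#define GPTQ_BLOCK_M_SIZE_MAX 8
#define EXL2_BLOCK_M_SIZE_MAX 8

int main()
{
    bool is_gptq = false;  // hypothetical: an EXL2-format matrix
    int size_m = 21;       // hypothetical row count, not a multiple of the block size

    // Same arithmetic as in gemm_half_q_half_cuda: full chunks of
    // block_m_size_max rows, then one launch for the remainder.
    int block_m_size_max = is_gptq ? GPTQ_BLOCK_M_SIZE_MAX : EXL2_BLOCK_M_SIZE_MAX;
    int max_chunks      = size_m / block_m_size_max;      // 21 / 8 = 2
    int last_chunk      = max_chunks * block_m_size_max;  // 2 * 8 = 16
    int last_chunk_size = size_m - last_chunk;            // 21 - 16 = 5

    if (max_chunks)      printf("full-chunk launch covers rows [0, %d)\n", last_chunk);
    if (last_chunk_size) printf("remainder launch covers rows [%d, %d)\n", last_chunk, size_m);
    return 0;
}
```

Note that the remainder launch passes `last_chunk_size` as `m_count`, so no padding rows are processed.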
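Similarly, the launch geometry in `gemm_half_q_half_cuda_part` now depends on the format-specific `*_BLOCK_KN_SIZE`. Assuming `DIVIDE` is the usual round-up integer division helper (its definition is not shown in this diff), the EXL2 branch sizes its grid as in this sketch; the problem shape is hypothetical.

```cpp
#include <cstdio>

// Assumption: DIVIDE rounds up, as is conventional for launch-size math.
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

#define EXL2_BLOCK_KN_SIZE 64

int main()
{
    int size_m = 16, size_n = 4096, size_k = 4096, m_count = 8;  // hypothetical shape

    // Mirrors the EXL2 branch of gemm_half_q_half_cuda_part:
    // each block covers EXL2_BLOCK_KN_SIZE * 4 output columns,
    // m_count rows of M, and EXL2_BLOCK_KN_SIZE rows of K.
    int grid_x = DIVIDE(size_n, EXL2_BLOCK_KN_SIZE * 4);  // 4096 / 256 = 16
    int grid_y = DIVIDE(size_m, m_count);                 //   16 /   8 =  2
    int grid_z = DIVIDE(size_k, EXL2_BLOCK_KN_SIZE);      // 4096 /  64 = 64

    printf("grid = (%d, %d, %d), block = (%d, 1, 1)\n",
           grid_x, grid_y, grid_z, EXL2_BLOCK_KN_SIZE);
    return 0;
}
```

Halving the EXL2 K/N block size from 128 to 64 doubles `gridDim.x` and `gridDim.z` for the same shape, trading smaller thread blocks for more of them.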