#include <torch/python.h>
#include <cuda.h>
#include <cuda_runtime.h>
+ #include <cuda_fp16.h>

template <typename scalar_t>
__global__ void VecQuant3MatMulKernel(
@@ -14,8 +15,18 @@ __global__ void VecQuant3MatMulKernel(
  int width
);

- const int BLOCKWIDTH = 1024;
- const int BLOCKHEIGHT = 96;
+ __global__ void VecQuant3MatMulKernelFaster(
+   const half2* __restrict__ vec,
+   const int* __restrict__ mat,
+   float* __restrict__ mul,
+   const float* __restrict__ scales,
+   const float* __restrict__ zeros,
+   int height,
+   int width
+ );
+
+ const int BLOCKWIDTH = 256;
+ const int BLOCKHEIGHT = 24;
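The new constants preserve the packing invariant implied by the kernel below: every 32-bit row of `mat` holds 3-bit weights, and one thread block walks BLOCKHEIGHT such rows while covering BLOCKWIDTH input values, so BLOCKHEIGHT * 32 must equal BLOCKWIDTH * 3 (24 * 32 = 768 = 256 * 3, just as 96 * 32 = 1024 * 3 did before). A minimal compile-time check, offered only as a hedged sketch and not part of the committed change, could pin this down:

  // Sketch only (not in the diff): keep packed rows per block in sync with
  // the number of 3-bit values a block consumes.
  static_assert(BLOCKHEIGHT * 32 == BLOCKWIDTH * 3,
                "BLOCKHEIGHT rows of 32-bit words must hold exactly BLOCKWIDTH 3-bit values");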

void vecquant3matmul_cuda(
  torch::Tensor vec,
@@ -29,7 +40,7 @@ void vecquant3matmul_cuda(

  dim3 blocks(
    (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT,
-   (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+   (width + BLOCKWIDTH - 1) / BLOCKWIDTH
  );
  dim3 threads(BLOCKWIDTH);
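For orientation, a hedged worked example of the launch geometry under the new constants (the layer size is assumed for illustration and does not come from this diff): a 4096x4096 linear layer stores 4096 * 3 / 32 = 384 packed int rows and 4096 columns, so `blocks` evaluates to dim3((384 + 23) / 24, (4096 + 255) / 256) = dim3(16, 16) with 256 threads per block. Each block then reduces 256 input values against 256 output columns, and blocks along the same column accumulate into `mul[col]` through the atomicAdd at the end of the kernels.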
@@ -44,6 +55,32 @@ void vecquant3matmul_cuda(
  );
}

+ void vecquant3matmul_faster_cuda(
+   torch::Tensor vec,
+   torch::Tensor mat,
+   torch::Tensor mul,
+   torch::Tensor scales,
+   torch::Tensor zeros
+ ) {
+   int height = mat.size(0);
+   int width = mat.size(1);
+
+   dim3 blocks(
+     (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT,
+     (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+   );
+   dim3 threads(BLOCKWIDTH);
+
+   VecQuant3MatMulKernelFaster<<<blocks, threads>>>(
+     (half2*) vec.data_ptr(),
+     mat.data_ptr<int>(),
+     mul.data_ptr<float>(),
+     scales.data_ptr<float>(),
+     zeros.data_ptr<float>(),
+     height, width
+   );
+ }
+
__device__ inline unsigned int as_unsigned(int i) {
  return *reinterpret_cast<unsigned int*>(&i);
}
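This commit only adds the CUDA-side entry point; a hedged sketch of how `vecquant3matmul_faster_cuda` might be exposed to Python through the usual torch extension binding is shown below. The companion C++ file and the binding name are assumptions for illustration, not part of this diff.

  // Hedged sketch (assumed quant_cuda.cpp-style companion file); not part of
  // this commit.
  #include <torch/extension.h>

  // CUDA wrapper added in this diff.
  void vecquant3matmul_faster_cuda(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros
  );

  void vecquant3matmul_faster(
    torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
    torch::Tensor scales, torch::Tensor zeros
  ) {
    // vec is consumed as half2 pairs, so it should be fp16 with an even number
    // of elements; mul, scales and zeros are read/written as fp32.
    vecquant3matmul_faster_cuda(vec, mat, mul, scales, zeros);
  }

  PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("vecquant3matmul_faster", &vecquant3matmul_faster,
          "3-bit quantized matrix-vector product, fp16 activations (CUDA)");
  }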
@@ -126,3 +163,82 @@ __global__ void VecQuant3MatMulKernel(

  atomicAdd(&mul[col], res);
}
+
+ __global__ void VecQuant3MatMulKernelFaster(
+   const half2* __restrict__ vec,
+   const int* __restrict__ mat,
+   float* __restrict__ mul,
+   const float* __restrict__ scales,
+   const float* __restrict__ zeros,
+   int height,
+   int width
+ ) {
+   const int blockwidth2 = BLOCKWIDTH / 2;
+
+   int row = BLOCKHEIGHT * blockIdx.x;
+   int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+   __shared__ half2 blockvec[blockwidth2];
+   if (threadIdx.x < blockwidth2)
+     blockvec[threadIdx.x] = vec[(row / BLOCKHEIGHT) * blockwidth2 + threadIdx.x];
+
+   __shared__ half2 deq2[64][32];
+   int val = threadIdx.x / 32;
+   int off = threadIdx.x % 32;
+   for (; val < 64; val += BLOCKWIDTH / 32) {
+     deq2[val][off] = __halves2half2(
+       __int2half_rn(val & 0x7), __int2half_rn(val >> 3)
+     );
+   }
+
+   half2 scale = __float2half2_rn(scales[col]);
+   half2 zero = __float2half2_rn(-zeros[col]);
+
+   int i = width * row + col;
+   int k = 0;
+
+   float res = 0;
+   half2 res2;
+
+   unsigned int tmp1;
+   unsigned int tmp2;
+   unsigned int tmp;
+
+   __syncthreads();
+
+   while (k < blockwidth2) {
+     res2 = {};
+     tmp1 = as_unsigned(mat[i]);
+     res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
+     res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
+     res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
+     res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);
+     res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
+     i += width;
+     tmp2 = as_unsigned(mat[i]);
+     tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x3c);
+     res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 5], res2);
+     tmp2 >>= 4;
+     k += 6;
+     res2 = __hfma2(__hfma2(deq2[(tmp2 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
+     res2 = __hfma2(__hfma2(deq2[(tmp2 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
+     res2 = __hfma2(__hfma2(deq2[(tmp2 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
+     res2 = __hfma2(__hfma2(deq2[(tmp2 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);
+     i += width;
+     tmp1 = as_unsigned(mat[i]);
+     tmp = (tmp2 >> 24) | ((tmp1 << 4) & 0x30);
+     res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 4], res2);
+     tmp1 >>= 2;
+     k += 5;
+     res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
+     res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
+     res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
+     res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);
+     res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
+     i += width;
+     k += 5;
+     res += __half2float(res2.x) + __half2float(res2.y);
+   }
+
+   atomicAdd(&mul[col], res);
+ }
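To unpack the shift/mask pattern above: each group of 32 consecutive 3-bit weights of one column spans three 32-bit words stored in three consecutive rows of `mat` (hence the three `i += width` steps per loop iteration), and `deq2` maps a 6-bit chunk to two dequantized halves at once, pairing each even-indexed value (low 3 bits) with the following odd-indexed one (high 3 bits) to match the half2 layout of `blockvec`. The two word boundaries fall inside values 10 and 21 of the group, which is why the loop has the two 6-bit recombination steps `(tmp1 >> 30) | ((tmp2 << 2) & 0x3c)` and `(tmp2 >> 24) | ((tmp1 << 4) & 0x30)`. The host-side reference below restates the same bit layout in plain C++; it is a hedged sketch for illustration only, and the function name is made up.

  // Hedged reference sketch (not part of the diff): unpack the 32 3-bit values
  // that the kernel reads from three packed 32-bit words w0, w1, w2.
  #include <cstdint>
  #include <vector>

  std::vector<int> unpack_32x3bit(uint32_t w0, uint32_t w1, uint32_t w2) {
    std::vector<int> out;
    out.reserve(32);
    // Values 0..20 (63 bits) fit entirely in w0 plus the low bits of w1.
    uint64_t lo = (uint64_t)w0 | ((uint64_t)w1 << 32);
    for (int j = 0; j < 21; ++j)
      out.push_back((int)((lo >> (3 * j)) & 0x7));
    // Value 21 straddles w1/w2 (bit 31 of w1 plus bits 0-1 of w2); values
    // 22..31 come from the remaining bits of w2.
    uint64_t hi = ((uint64_t)w1 >> 31) | ((uint64_t)w2 << 1);
    for (int j = 0; j < 11; ++j)
      out.push_back((int)((hi >> (3 * j)) & 0x7));
    return out;
  }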