
Commit 54d35a8

Implemented faster kernels optimized for A100 GPUs

1 parent 669e515

5 files changed (+165, -16 lines)


opt.py

Lines changed: 7 additions & 3 deletions

@@ -229,7 +229,7 @@ def forward(self, inp, **kwargs):
 def opt_pack3(model, quantizers):
     layers = find_layers(model)
     layers = {n: layers[n] for n in quantizers}
-    make_quant3(model, quantizers)
+    make_quant3(model, quantizers, faster=args.faster_kernel)
     qlayers = find_layers(model, [Quant3Linear])
     print('Packing ...')
     for name in qlayers:
@@ -258,7 +258,7 @@ def noop(*args, **kwargs):
     for name in ['model.decoder.project_out', 'model.decoder.project_in', 'lm_head']:
         if name in layers:
             del layers[name]
-    make_quant3(model, layers)
+    make_quant3(model, layers, faster=args.faster_kernel)

     print('Loading model ...')
     model.load_state_dict(torch.load(checkpoint))
@@ -416,7 +416,11 @@ def sync():
     )
     parser.add_argument(
         '--new-eval', action='store_true',
-        help='Whether to use the new PTB and C4 eval'
+        help='Whether to use the new PTB and C4 eval.'
+    )
+    parser.add_argument(
+        '--faster-kernel', action='store_true',
+        help='Whether to use the new faster kernel for benchmarking.'
     )

     args = parser.parse_args()
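
For reference, the new flag only changes how quantized layers are constructed. A minimal, self-contained sketch of the resulting behavior, assuming quant.py from this repo is importable (the Tiny module and its sizes are hypothetical; make_quant3 and Quant3Linear are the names from this commit):

import torch.nn as nn
from quant import make_quant3, Quant3Linear

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1024, 1024)

model = Tiny()
# opt.py now forwards faster=args.faster_kernel at this point:
make_quant3(model, {'fc'}, faster=True)
# The nn.Linear has been swapped for a Quant3Linear whose forward() will
# dispatch to quant_cuda.vecquant3matmul_faster on half-precision inputs.
assert isinstance(model.fc, Quant3Linear) and model.fc.faster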

quant.py

Lines changed: 11 additions & 6 deletions

@@ -136,14 +136,15 @@ def ready(self):
 # Assumes layer is perfectly divisible into 1024 * 1024 blocks
 class Quant3Linear(nn.Module):

-    def __init__(self, infeatures, outfeatures):
+    def __init__(self, infeatures, outfeatures, faster=False):
         super().__init__()
         self.register_buffer('zeros', torch.zeros((outfeatures, 1)))
         self.register_buffer('scales', torch.zeros((outfeatures, 1)))
         self.register_buffer('bias', torch.zeros(outfeatures))
         self.register_buffer(
             'qweight', torch.zeros((infeatures // 32 * 3, outfeatures), dtype=torch.int)
         )
+        self.faster = faster

     def pack(self, linear, scales, zeros):
         self.zeros = zeros * scales
@@ -187,21 +188,25 @@ def forward(self, x):
             y = self.bias.clone()
             outshape[-1] = self.bias.numel()
             dtype = x.dtype
-            x = x.float()
-            quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.zeros)
+            if self.faster:
+                x = x.half()
+                quant_cuda.vecquant3matmul_faster(x, self.qweight, y, self.scales, self.zeros)
+            else:
+                x = x.float()
+                quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.zeros)
             y = y.to(dtype)
             return y.reshape(outshape)
         raise ValueError('Only supports a single token currently.')

-def make_quant3(module, names, name=''):
+def make_quant3(module, names, name='', faster=False):
     if isinstance(module, Quant3Linear):
         return
     for attr in dir(module):
         tmp = getattr(module, attr)
         name1 = name + '.' + attr if name != '' else attr
         if name1 in names:
             setattr(
-                module, attr, Quant3Linear(tmp.in_features, tmp.out_features)
+                module, attr, Quant3Linear(tmp.in_features, tmp.out_features, faster=faster)
             )
     for name1, child in module.named_children():
-        make_quant3(child, names, name + '.' + name1 if name != '' else name1)
+        make_quant3(child, names, name + '.' + name1 if name != '' else name1, faster=faster)
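
The dispatch above is the only behavioral change in forward(): with faster=True the activations are cast to half precision, while the result is still accumulated into the fp32 buffer cloned from bias and cast back to the input dtype at the end. A minimal usage sketch, assuming the quant_cuda extension is built and a CUDA device is available; the sizes are hypothetical and the layer keeps its zero-initialized buffers, so this only exercises the call path:

import torch
from quant import Quant3Linear

# `faster=True` is the new keyword introduced by this commit.
layer = Quant3Linear(4096, 4096, faster=True).cuda()

# forward() only supports a single token, so the last dim must hold all elements.
x = torch.randn(1, 1, 4096, device='cuda', dtype=torch.half)
with torch.no_grad():
    y = layer(x)          # casts x to half and calls quant_cuda.vecquant3matmul_faster
print(y.dtype, y.shape)   # comes back in x's dtype: torch.float16, shape (1, 1, 4096)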

quant_cuda.cpp

Lines changed: 14 additions & 0 deletions

@@ -5,6 +5,11 @@
 void vecquant3matmul_cuda(
   torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
   torch::Tensor scales, torch::Tensor zeros
+);
+
+void vecquant3matmul_faster_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
 );

 void vecquant3matmul(
@@ -15,6 +20,15 @@ void vecquant3matmul(
   vecquant3matmul_cuda(vec, mat, mul, scales, zeros);
 }

+void vecquant3matmul_faster(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant3matmul_faster_cuda(vec, mat, mul, scales, zeros);
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
+  m.def("vecquant3matmul_faster", &vecquant3matmul_faster, "Vector 3-bit Quantized Matrix Multiplication (CUDA), faster version");
 }
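
As before, the C++ wrapper only sets the CUDA device guard and forwards raw tensors to the launcher, so the new entry point can also be called directly from Python. A hedged sketch of such a call; the sizes are hypothetical, and the shapes and dtypes are inferred from Quant3Linear's buffers and the kernel signature:

import torch
import quant_cuda   # the extension built from quant_cuda.cpp / quant_cuda_kernel.cu

infeatures, outfeatures = 4096, 4096
vec    = torch.randn(infeatures, device='cuda', dtype=torch.half)                        # read by the kernel as half2 pairs
mat    = torch.zeros(infeatures // 32 * 3, outfeatures, device='cuda', dtype=torch.int)  # packed 3-bit weights
mul    = torch.zeros(outfeatures, device='cuda', dtype=torch.float)                      # fp32 result, accumulated via atomicAdd
scales = torch.ones(outfeatures, device='cuda', dtype=torch.float)
zeros  = torch.zeros(outfeatures, device='cuda', dtype=torch.float)

quant_cuda.vecquant3matmul_faster(vec, mat, mul, scales, zeros)
torch.cuda.synchronize()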

quant_cuda_kernel.cu

Lines changed: 119 additions & 3 deletions

@@ -2,6 +2,7 @@
 #include <torch/python.h>
 #include <cuda.h>
 #include <cuda_runtime.h>
+#include <cuda_fp16.h>

 template <typename scalar_t>
 __global__ void VecQuant3MatMulKernel(
@@ -14,8 +15,18 @@ __global__ void VecQuant3MatMulKernel(
   int width
 );

-const int BLOCKWIDTH = 1024;
-const int BLOCKHEIGHT = 96;
+__global__ void VecQuant3MatMulKernelFaster(
+  const half2* __restrict__ vec,
+  const int* __restrict__ mat,
+  float* __restrict__ mul,
+  const float* __restrict__ scales,
+  const float* __restrict__ zeros,
+  int height,
+  int width
+);
+
+const int BLOCKWIDTH = 256;
+const int BLOCKHEIGHT = 24;

 void vecquant3matmul_cuda(
   torch::Tensor vec,
@@ -29,7 +40,7 @@ void vecquant3matmul_cuda(

   dim3 blocks(
     (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT,
-    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
   );
   dim3 threads(BLOCKWIDTH);

@@ -44,6 +55,32 @@ void vecquant3matmul_cuda(
   );
 }

+void vecquant3matmul_faster_cuda(
+  torch::Tensor vec,
+  torch::Tensor mat,
+  torch::Tensor mul,
+  torch::Tensor scales,
+  torch::Tensor zeros
+) {
+  int height = mat.size(0);
+  int width = mat.size(1);
+
+  dim3 blocks(
+    (height + BLOCKHEIGHT - 1) / BLOCKHEIGHT,
+    (width + BLOCKWIDTH - 1) / BLOCKWIDTH
+  );
+  dim3 threads(BLOCKWIDTH);
+
+  VecQuant3MatMulKernelFaster<<<blocks, threads>>>(
+    (half2*) vec.data_ptr(),
+    mat.data_ptr<int>(),
+    mul.data_ptr<float>(),
+    scales.data_ptr<float>(),
+    zeros.data_ptr<float>(),
+    height, width
+  );
+}
+
 __device__ inline unsigned int as_unsigned(int i) {
   return *reinterpret_cast<unsigned int*>(&i);
 }
@@ -126,3 +163,82 @@ __global__ void VecQuant3MatMulKernel(

   atomicAdd(&mul[col], res);
 }
+
+__global__ void VecQuant3MatMulKernelFaster(
+  const half2* __restrict__ vec,
+  const int* __restrict__ mat,
+  float* __restrict__ mul,
+  const float* __restrict__ scales,
+  const float* __restrict__ zeros,
+  int height,
+  int width
+) {
+  const int blockwidth2 = BLOCKWIDTH / 2;
+
+  int row = BLOCKHEIGHT * blockIdx.x;
+  int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
+
+  __shared__ half2 blockvec[blockwidth2];
+  if (threadIdx.x < blockwidth2)
+    blockvec[threadIdx.x] = vec[(row / BLOCKHEIGHT) * blockwidth2 + threadIdx.x];
+
+  __shared__ half2 deq2[64][32];
+  int val = threadIdx.x / 32;
+  int off = threadIdx.x % 32;
+  for (; val < 64; val += BLOCKWIDTH / 32) {
+    deq2[val][off] = __halves2half2(
+      __int2half_rn(val & 0x7), __int2half_rn(val >> 3)
+    );
+  }
+
+  half2 scale = __float2half2_rn(scales[col]);
+  half2 zero = __float2half2_rn(-zeros[col]);
+
+  int i = width * row + col;
+  int k = 0;
+
+  float res = 0;
+  half2 res2;
+
+  unsigned int tmp1;
+  unsigned int tmp2;
+  unsigned int tmp;
+
+  __syncthreads();
+
+  while (k < blockwidth2) {
+    res2 = {};
+    tmp1 = as_unsigned(mat[i]);
+    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
+    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
+    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
+    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);
+    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
+    i += width;
+    tmp2 = as_unsigned(mat[i]);
+    tmp = (tmp1 >> 30) | ((tmp2 << 2) & 0x3c);
+    res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 5], res2);
+    tmp2 >>= 4;
+    k += 6;
+    res2 = __hfma2(__hfma2(deq2[(tmp2 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
+    res2 = __hfma2(__hfma2(deq2[(tmp2 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
+    res2 = __hfma2(__hfma2(deq2[(tmp2 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
+    res2 = __hfma2(__hfma2(deq2[(tmp2 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);
+    i += width;
+    tmp1 = as_unsigned(mat[i]);
+    tmp = (tmp2 >> 24) | ((tmp1 << 4) & 0x30);
+    res2 = __hfma2(__hfma2(deq2[tmp][off], scale, zero), blockvec[k + 4], res2);
+    tmp1 >>= 2;
+    k += 5;
+    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 0) & 0x3f][off], scale, zero), blockvec[k + 0], res2);
+    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 6) & 0x3f][off], scale, zero), blockvec[k + 1], res2);
+    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 12) & 0x3f][off], scale, zero), blockvec[k + 2], res2);
+    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 18) & 0x3f][off], scale, zero), blockvec[k + 3], res2);
+    res2 = __hfma2(__hfma2(deq2[(tmp1 >> 24) & 0x3f][off], scale, zero), blockvec[k + 4], res2);
+    i += width;
+    k += 5;
+    res += __half2float(res2.x) + __half2float(res2.y);
+  }
+
+  atomicAdd(&mul[col], res);
+}
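
The faster kernel keeps the existing packing: every BLOCKHEIGHT = 24 rows of int32 hold 24 * 32 / 3 = 256 three-bit weights, one per column of a BLOCKWIDTH = 256 thread block, and the shared deq2 table maps each 6-bit code to a half2 holding the two 3-bit values it encodes. The Python sketch below is an illustration only, not part of the repository; it re-derives that bit layout and checks that the kernel's shift/mask sequence recovers the packed values, including the two codes that straddle 32-bit word boundaries:

import random

def pack3(vals):
    # Pack 32 values in [0, 8) into three 32-bit words, value j at bits 3j..3j+2.
    assert len(vals) == 32 and all(0 <= v < 8 for v in vals)
    bits = 0
    for j, v in enumerate(vals):
        bits |= v << (3 * j)
    return [(bits >> (32 * w)) & 0xffffffff for w in range(3)]

def unpack3_pairs(words):
    # Recover the 16 six-bit pair codes using the same shift/mask sequence as the kernel.
    pairs = []
    tmp1 = words[0]
    for s in (0, 6, 12, 18, 24):                       # pairs 0-4
        pairs.append((tmp1 >> s) & 0x3f)
    tmp2 = words[1]
    pairs.append((tmp1 >> 30) | ((tmp2 << 2) & 0x3c))  # pair 5 straddles words 0/1
    tmp2 >>= 4
    for s in (0, 6, 12, 18):                           # pairs 6-9
        pairs.append((tmp2 >> s) & 0x3f)
    tmp1 = words[2]
    pairs.append((tmp2 >> 24) | ((tmp1 << 4) & 0x30))  # pair 10 straddles words 1/2
    tmp1 >>= 2
    for s in (0, 6, 12, 18, 24):                       # pairs 11-15
        pairs.append((tmp1 >> s) & 0x3f)
    return pairs

vals = [random.randrange(8) for _ in range(32)]
codes = unpack3_pairs(pack3(vals))
# deq2 maps a code to the half2 (code & 7, code >> 3); the flat list must round-trip.
recovered = [v for c in codes for v in (c & 0x7, c >> 3)]
assert recovered == vals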

test_kernel.py

Lines changed: 14 additions & 4 deletions

@@ -10,8 +10,8 @@

 DEV = torch.device('cuda:0')

-M = 12288
-N = 12288 * 4
+M = 12288 * 4
+N = 12288

 DTYPE = torch.half
 mat = torch.randn((M, N), device=DEV, dtype=DTYPE)
@@ -43,6 +43,14 @@
 torch.cuda.synchronize()
 print('3bit:', (time.time() - tick) / COUNT)

+COUNT = 1000
+import time
+tick = time.time()
+for _ in range(COUNT):
+    quant_cuda.vecquant3matmul_faster(vec, mat, mul, scales, zeros)
+torch.cuda.synchronize()
+print('3bit:', (time.time() - tick) / COUNT, '(faster)')
+
 print('Verifiying kernel correctness ...')

 M = 4 * 4096
@@ -66,5 +74,7 @@
 layer = layer.to(DEV)

 with torch.no_grad():
-    print('Simu:', qlayer(vec))
-    print('Kern:', layer.to(DEV)(vec))
+    print('Simu:', layer.to(DEV)(vec))
+    print('Kern:', qlayer(vec))
+    qlayer.faster = True
+    print('Kern:', qlayer(vec.half()), '(faster)')
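
The final block prints the simulated and kernel outputs for manual inspection. A hedged add-on that reports the error numerically instead, reusing layer, qlayer, vec and DEV exactly as test_kernel.py defines them (the printed maxima are a suggestion, not output the script currently produces):

with torch.no_grad():
    ref  = layer.to(DEV)(vec).float()   # full-precision reference ('Simu')
    out  = qlayer(vec).float()          # 3-bit kernel, fp32 path
    qlayer.faster = True
    fast = qlayer(vec.half()).float()   # 3-bit kernel, new half-precision path
    print('max abs err (std)   :', (out - ref).abs().max().item())
    print('max abs err (faster):', (fast - ref).abs().max().item())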
