// kernel.cuh: WMMA HGEMM, transpose, warp/block reductions, and attention softmax kernels
#pragma once

#include <cuda.h>
#include <cuda_fp16.h> // half arithmetic operators (also pulled in indirectly by <mma.h>)
#include <mma.h>
using namespace nvcuda;

// Row-major offset of element (row, col) in a matrix with leading dimension ld.
#define OFFSET(row, col, ld) ((row) * (ld) + (col))
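
// Tensor-core HGEMM: each block computes a BM x BN (128 x 256) tile of C = A * B,
// streaming BK = 32 wide K-slices of A and B through double-buffered shared memory
// with cp.async. Each of the 8 warps owns a 64 x 64 sub-tile of the block's output,
// held as a 4 x 4 grid of 16x16x16 WMMA accumulator fragments. APAD/BPAD pad the
// shared-memory rows to dodge bank conflicts. Assumes M, N, K are multiples of
// BM, BN, BK respectively.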
__global__ void GEMM_sharedmem_wmma(
const half * __restrict__ a, const half * __restrict__ b, half * __restrict__ c,
const int M, const int N, const int K) {
const int BM = 128;
const int BN = 256;
const int BK = 32;
int bx = blockIdx.x;
int by = blockIdx.y;
int tid = threadIdx.x;
int wid = tid >> 5;
const int APAD = 8;
const int BPAD = 8;
extern __shared__ half smem[];
half *s_a = smem;
half *s_b = smem + 2 * BM * (BK + APAD);
int s_a_db_offset = BM * (BK + APAD);
int s_b_db_offset = BK * (BN + BPAD);
wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> frag_a[2][4];
wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> frag_b[2][4];
wmma::fragment<wmma::accumulator, 16, 16, 16, half> frag_c[4][4];
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::fill_fragment(frag_c[i][j], (half)0.0f);
}
}
int load_a_smem_m = (tid >> 2) << 1;
int load_a_smem_k = (tid & 3) << 3;
int load_b_smem_k = (tid >> 5) << 2;
int load_b_smem_n = (tid & 31) << 3;
int s_a_base_addr = (int)__cvta_generic_to_shared(s_a); // 32-bit shared-window address for cp.async
int s_b_base_addr = (int)__cvta_generic_to_shared(s_b);
int load_a_smem_addr_0 = s_a_base_addr + OFFSET(load_a_smem_m, load_a_smem_k, BK + APAD) * sizeof(half);
int load_a_smem_addr_1 = load_a_smem_addr_0 + (BK + APAD) * sizeof(half);
int load_b_smem_addr_0 = s_b_base_addr + OFFSET(load_b_smem_k, load_b_smem_n, BN + BPAD) * sizeof(half);
int load_b_smem_addr_1 = load_b_smem_addr_0 + (BN + BPAD) * sizeof(half);
int load_b_smem_addr_2 = load_b_smem_addr_0 + 2 * (BN + BPAD) * sizeof(half);
int load_b_smem_addr_3 = load_b_smem_addr_0 + 3 * (BN + BPAD) * sizeof(half);
int load_a_gmem_m = by * BM + load_a_smem_m;
int load_b_gmem_n = bx * BN + load_b_smem_n;
int load_a_gmem_addr = OFFSET(load_a_gmem_m, load_a_smem_k, K);
int load_b_gmem_addr = OFFSET(load_b_smem_k, load_b_gmem_n, N);
int comp_c_frag_m = wid & 1;
int comp_c_frag_n = wid >> 1;
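// Prologue: asynchronously stage the first A/B K-slice into shared buffer 0.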
{
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_a_smem_addr_0), "l"(&a[load_a_gmem_addr ]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_a_smem_addr_1), "l"(&a[load_a_gmem_addr + K]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_0), "l"(&b[load_b_gmem_addr ]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_1), "l"(&b[load_b_gmem_addr + N]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_2), "l"(&b[load_b_gmem_addr + 2 * N]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_3), "l"(&b[load_b_gmem_addr + 3 * N]));
asm ("cp.async.commit_group;\n" ::);
asm ("cp.async.wait_group 0;\n" ::);
__syncthreads();
}
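// Main pipeline: issue cp.async loads of tile bk into the "next" buffer while the
// tensor cores consume tile bk-1 from the "current" buffer.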
for (int bk = 1; bk < K / BK; bk++) {
int smem_sel = (bk & 1) ^ 1;
int smem_sel_next = ((bk - 1) & 1) ^ 1;
load_a_gmem_addr += BK;
load_b_gmem_addr += BK * N;
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_a_smem_addr_0 + smem_sel_next * s_a_db_offset * (int)sizeof(half)), "l"(&a[load_a_gmem_addr ]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_a_smem_addr_1 + smem_sel_next * s_a_db_offset * (int)sizeof(half)), "l"(&a[load_a_gmem_addr + K]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_0 + smem_sel_next * s_b_db_offset * (int)sizeof(half)), "l"(&b[load_b_gmem_addr ]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_1 + smem_sel_next * s_b_db_offset * (int)sizeof(half)), "l"(&b[load_b_gmem_addr + N]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_2 + smem_sel_next * s_b_db_offset * (int)sizeof(half)), "l"(&b[load_b_gmem_addr + 2 * N]));
asm ("cp.async.ca.shared.global [%0], [%1], 16;\n" :
: "r"(load_b_smem_addr_3 + smem_sel_next * s_b_db_offset * (int)sizeof(half)), "l"(&b[load_b_gmem_addr + 3 * N]));
wmma::load_matrix_sync(frag_a[0][0], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 ) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][1], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 16) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][2], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 32) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][3], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 48) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[1][0], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 ) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][1], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 16) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][2], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 32) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][3], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 48) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_b[0][0], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 ], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][1], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 + 16], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][2], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 + 32], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][3], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 + 48], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][0], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 ], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][1], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 + 16], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][2], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 + 32], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][3], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 + 48], BN + BPAD);
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::mma_sync(frag_c[i][j], frag_a[0][i], frag_b[0][j], frag_c[i][j]);
wmma::mma_sync(frag_c[i][j], frag_a[1][i], frag_b[1][j], frag_c[i][j]);
}
}
asm ("cp.async.commit_group;\n" ::);
asm ("cp.async.wait_group 0;\n" ::);
__syncthreads();
}
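// Epilogue: the final K-slice was loaded by the last iteration but not yet consumed.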
int smem_sel = ((K / BK) & 1) ^ 1;
wmma::load_matrix_sync(frag_a[0][0], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 ) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][1], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 16) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][2], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 32) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[0][3], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 48) * (BK + APAD) + 0], BK + APAD);
wmma::load_matrix_sync(frag_a[1][0], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 ) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][1], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 16) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][2], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 32) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_a[1][3], &s_a[smem_sel * s_a_db_offset + (comp_c_frag_m * 64 + 48) * (BK + APAD) + 16], BK + APAD);
wmma::load_matrix_sync(frag_b[0][0], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 ], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][1], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 + 16], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][2], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 + 32], BN + BPAD);
wmma::load_matrix_sync(frag_b[0][3], &s_b[smem_sel * s_b_db_offset + comp_c_frag_n * 64 + 48], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][0], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 ], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][1], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 + 16], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][2], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 + 32], BN + BPAD);
wmma::load_matrix_sync(frag_b[1][3], &s_b[smem_sel * s_b_db_offset + 16 * (BN + BPAD) + comp_c_frag_n * 64 + 48], BN + BPAD);
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::mma_sync(frag_c[i][j], frag_a[0][i], frag_b[0][j], frag_c[i][j]);
wmma::mma_sync(frag_c[i][j], frag_a[1][i], frag_b[1][j], frag_c[i][j]);
}
}
int store_c_gmem_m = by * BM + comp_c_frag_m * 64;
int store_c_gmem_n = bx * BN + comp_c_frag_n * 64;
int store_c_gmem_addr = OFFSET(store_c_gmem_m, store_c_gmem_n, N);
#pragma unroll
for (int i = 0; i < 4; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
wmma::store_matrix_sync(&c[store_c_gmem_addr + i * 16 * N + j * 16], frag_c[i][j], N, wmma::mem_row_major);
}
}
}
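
// A minimal host-side launch sketch (an assumption, not part of the original file):
// the thread indexing above implies 8 warps (256 threads) per block and one
// 128 x 256 output tile per block. The double-buffered tiles need 54272 bytes of
// dynamic shared memory, which exceeds the 48 KB default, so the kernel must opt
// in to a larger carveout (cp.async requires sm_80+ anyway).
inline void launch_GEMM_sharedmem_wmma(half *a, half *b, half *c,
                                       int M, int N, int K) {
    const int BM = 128, BN = 256, BK = 32, APAD = 8, BPAD = 8;
    size_t smem = 2 * (BM * (BK + APAD) + BK * (BN + BPAD)) * sizeof(half);
    cudaFuncSetAttribute(GEMM_sharedmem_wmma,
                         cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem);
    dim3 grid(N / BN, M / BM);   // assumes N % BN == 0 and M % BM == 0
    GEMM_sharedmem_wmma<<<grid, 256, smem>>>(a, b, c, M, N, K);
}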
// Naive grid-stride transpose: b (N x M) = a^T, where a is M x N row-major.
__global__ void transpose(half *a, half *b, int M, int N) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
for(int i = tid; i < M * N; i += blockDim.x * gridDim.x) {
int x = i / N;
int y = i % N;
b[y * M + x] = a[x * N + y];
}
// No __syncthreads() needed: the kernel uses no shared memory.
}
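
// Usage sketch: any 1-D launch works since the loop above is grid-stride, e.g.
//   transpose<<<(M * N + 255) / 256, 256>>>(d_a, d_b, M, N);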
#define FINAL_MASK 0xffffffff
template <typename T>
__inline__ __device__
T warpReduceSum(T val)
{
for(int mask = 16; mask > 0; mask >>= 1)
val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
return val;
}
template <typename T>
__inline__ __device__
T blockReduceSum(T val)
{
static __shared__ T shared[32]; // one partial sum per warp (assumes blockDim.x <= 1024)
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if(lane == 0)
shared[wid] = val;
__syncthreads();
// Round the warp count up so a trailing partial warp is not dropped.
val = (threadIdx.x < ((blockDim.x + 31) >> 5)) ? shared[lane] : (T)0.0f;
val = warpReduceSum<T>(val);
return val;
}
__device__ half max(half x, half y) {
return (x < y) ? y : x; // larger of the two
}
template <typename T>
__inline__ __device__
T warpReduceMax(T val)
{
for(int mask = 16; mask > 0; mask >>= 1)
val = max(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
return val;
}
/* Calculate the maximum of all elements in a block */
template <typename T>
__inline__ __device__
T blockReduceMax(T val)
{
static __shared__ T shared[32]; // one partial max per warp (assumes blockDim.x <= 1024)
int lane = threadIdx.x & 0x1f; // in-warp index
int wid = threadIdx.x >> 5; // warp index
val = warpReduceMax(val); // max within each warp
if(lane == 0) // lane 0 records its warp's max
shared[wid] = val;
__syncthreads();
// Round the warp count up; -6e4 stands in for -infinity in half precision.
val = (threadIdx.x < ((blockDim.x + 31) >> 5)) ? shared[lane] : (T)(-6e4f);
val = warpReduceMax(val);
return val;
}
__global__ void softmax_kernel(half* qk_buf_, const half* attr_mask,
const int batch_size, const int head_num,
const int seq_len, const half scaler)
{
int batch_id = blockIdx.x / head_num / seq_len;
int seq_id = blockIdx.x % seq_len;
int qk_offset = blockIdx.x * seq_len;
int mask_offset = batch_id * seq_len * seq_len + seq_id * seq_len;
__shared__ half s_sum, s_max;
// Each thread owns one column of this row; threads past seq_len carry identity values.
half qk = threadIdx.x < seq_len ? (half)qk_buf_[threadIdx.x + qk_offset] : (half)0.0f;
half mask_val = threadIdx.x < seq_len ? (half)attr_mask[threadIdx.x + mask_offset] : (half)0.0f;
mask_val = ((half)1.0f - mask_val) * (half)-10000.0f; // masked positions get a large negative bias
half tmp = threadIdx.x < seq_len ? (half)(qk * (half)scaler + mask_val) : (half)-6e4;
half max_val = blockReduceMax<half>(tmp);
if(threadIdx.x == 0)
s_max = max_val;
__syncthreads();
half qk_tmp = threadIdx.x < seq_len ? (half)__expf((float)(tmp - s_max)) : (half)0.0f;
half sum_val = blockReduceSum<half>(qk_tmp);
if(threadIdx.x == 0)
{
s_sum = sum_val;
}
__syncthreads();
if(threadIdx.x < seq_len)
qk_buf_[threadIdx.x + qk_offset] = (half)(qk_tmp / s_sum);
}
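
// Hedged usage sketch (not part of the original file): one block per row of the
// (batch_size * head_num * seq_len) x seq_len score matrix; blockDim.x must be a
// warp multiple >= seq_len (and <= 1024) for the block reductions above to cover
// the whole row.
inline void launch_softmax(half *qk_buf, const half *attr_mask, int batch_size,
                           int head_num, int seq_len, half scaler) {
    dim3 grid(batch_size * head_num * seq_len);
    dim3 block((seq_len + 31) / 32 * 32);   // round up to a multiple of 32
    softmax_kernel<<<grid, block>>>(qk_buf, attr_mask, batch_size, head_num,
                                    seq_len, scaler);
}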
// Naive reference GEMM: one thread per element of C (M x N), with A (M x K)
// and B (K x N) stored row-major.
__global__ void gemm_baseline(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M, size_t N, size_t K) {
int row = blockIdx.x * blockDim.x + threadIdx.x;
int col = blockIdx.y * blockDim.y + threadIdx.y;
if(row < (int)M && col < (int)N) {
half sum = (half)0.0f;
for(int i = 0; i < K; i++) {
sum += A[row * K + i] * B[i * N + col];
}
C[row * N + col] = sum;
}
}
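
// Example launch (a sketch): one thread per element of C.
//   dim3 block(16, 16);
//   dim3 grid((M + 15) / 16, (N + 15) / 16);   // x covers rows (M), y covers columns (N)
//   gemm_baseline<<<grid, block>>>(d_A, d_B, d_C, M, N, K);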
__global__ void addmask_kernel(half* qk_buf_, const half* attr_mask,
const int batch_size, const int head_num,
const int seq_len, const half scaler)
{
int batch_id = blockIdx.x / head_num / seq_len;
int seq_id = blockIdx.x % seq_len;
int qk_offset = blockIdx.x * seq_len;
int mask_offset = batch_id * seq_len * seq_len + seq_id * seq_len;
half qk = threadIdx.x < seq_len ? (half)qk_buf_[threadIdx.x + qk_offset] : (half)0.0f;
half mask_val = threadIdx.x < seq_len ? (half)attr_mask[threadIdx.x + mask_offset] : (half)0.0f;
mask_val = ((half)1.0f - mask_val) * (half)-10000.0f;
// Store only within the row: threads past seq_len would otherwise write into the next row.
if(threadIdx.x < seq_len)
qk_buf_[threadIdx.x + qk_offset] = (half)(qk * (half)scaler + mask_val);
}
// Naive per-thread softmax over one row of qk_buf_: every thread serially scans
// the whole row for its max and exp-sum, then normalizes only its own element.
// (Assumes mask and scaler were already applied, e.g. by addmask_kernel above.)
__global__ void softcal_kernel(half* qk_buf_, const half* attr_mask,
const int batch_size, const int head_num,
const int seq_len, const half scaler)
{
int qk_offset = blockIdx.x * seq_len;
half s_max = (half)-6e4;
for(int i = 0; i < seq_len; i++)
s_max = max(s_max, qk_buf_[i + qk_offset]);
half s_sum = (half)0.0f;
for(int i = 0; i < seq_len; i++)
s_sum += (half)__expf((float)(qk_buf_[i + qk_offset] - s_max));
__syncthreads(); // every thread must finish reading the row before anyone overwrites it
if(threadIdx.x < seq_len)
qk_buf_[threadIdx.x + qk_offset] = (half)__expf((float)(qk_buf_[threadIdx.x + qk_offset] - s_max)) / s_sum;
}
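
// Note: softcal_kernel recomputes the row max and sum in every thread, so it is a
// simple O(seq_len)-per-thread reference for the reduction-based softmax_kernel
// above, presumably meant to run after addmask_kernel has applied mask and scale.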