#include <torch/extension.h>
#include <ATen/ATen.h>
#include <iostream>
#include <time.h>
#include <sys/time.h>
#include <vector>
#include <immintrin.h>
#include <cmath>
#include <omp.h>   // omp_get_thread_num() is used by the Part 3 OpenMP loop
// Uncomment for ISPC
//#include "module_ispc.h"
//using namespace ispc;
// ------------------------------------ //
// WARM-UP: ACCESSING TENSORS //
// ------------------------------------ //
// Step #1: Understand Read/Write Accessors for a 2D Tensor
inline float twoDimRead(std::vector<float> &tensor, int &x, int &y, const int &sizeX) {
// Note that sizeX is the size of a Row, not the number of rows
return tensor[x * (sizeX)+ y];
}
inline void twoDimWrite(std::vector<float> &tensor, int &x, int &y, const int &sizeX, float &val) {
tensor[x * (sizeX) + y] = val;
}
// Step #2: Implement Read/Write Accessors for a 4D Tensor
inline float fourDimRead(std::vector<float> &tensor, int &x, int &y, int &z, int &b,
const int &sizeX, const int &sizeY, const int &sizeZ) {
return tensor[x * (sizeX * sizeY * sizeZ) + y * (sizeY * sizeZ) + z * (sizeZ) + b];
}
inline void fourDimWrite(std::vector<float> &tensor, int &x, int &y, int &z, int &b,
const int &sizeX, const int &sizeY, const int &sizeZ, float &val) {
tensor[x * (sizeX * sizeY * sizeZ) + y * (sizeY * sizeZ) + z * (sizeZ) + b] = val;
}
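// For reference: with the row-major layout used above, element (b, h, i, j) of a
// (B, H, N, d) tensor lives at flat offset b*(H*N*d) + h*(N*d) + i*d + j. A call
// such as fourDimRead(Q, b, h, i, j, H, N, d) in the code below computes exactly
// that offset; the outermost dimension size B never appears in the formula.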
// DO NOT EDIT THIS FUNCTION //
std::vector<float> formatTensor(torch::Tensor tensor) {
tensor = tensor.flatten();
tensor = tensor.contiguous();
std::vector<float> vec(tensor.data_ptr<float>(), tensor.data_ptr<float>() + tensor.numel());
return vec;
}
/* Programming Your Attention Modules.
*
* You are given Q, K, and V Tensors as inputs that are formatted as vectors. We have also created O and QK^t Tensors
* that are formatted as vectors. After you have implemented your accessors in the Warm-Up you should be able to
* read/write to these tensors via the read/write functions above.
*
* You are also given 4 integers as parameters: B, H, N, d:
*
 * B (Batch Size) - The number of samples for your attention layer. Think of it this way - if I asked my DNN
 * a question and it output 5 different answers, it had a batch size of 5. These samples are independent of each
 * other and thus can be parallelized.
 *
 * H (Number of Heads) - Each head runs on its own set of Q, K, V matrices. This effectively allows every head
 * to run the same attention algorithm, but with its own learned parameters, so each head can form its own
 * definition of what relevance means when looking at a token. These heads can operate independently of one
 * another and thus can be parallelized.
 *
 * N (Sequence Length) - The number of tokens. You may think of this as the number of words in a sample.
 *
 * d (Embedding Dimensionality) - The number of features each token encodes per attention head. Let's
 * say I encoded a word using the following features: (length, number of vowels, has a capital letter). The
 * embedded dimensionality would be 3.
 * */
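/* For a concrete sense of scale (example values only): with B = 2, H = 4, N = 8, and d = 3,
 * each of Q, K, and V holds 2 * 4 * 8 * 3 = 192 floats, and the per-(batch, head)
 * score matrix QK^t holds N * N = 64 floats. */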
// ---------------------------------------------------------- //
// PART 1: NAIVE ATTENTION //
// ---------------------------------------------------------- //
torch::Tensor myNaiveAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor QK_tTensor,
int B, int H, int N, int d){
// Q, K, V are passed in with Shape: (B, H, N, d)
//QK^t Intermediate Tensor has Shape (N, N)
//Make O Tensor with Shape (B, H, N, d)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
//Format O, Q, K, and V tensors into 4D vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
//Format QK_t Tensor into a 2D vector.
std::vector<float> QK_t = formatTensor(QK_tTensor);
/* Here is an example of how to read/write 0's to Q (B, H, N, d) using the 4D accessors
//loop over Batch Size
for (int b = 0; b < B; b++) {
//loop over Heads
for (int h = 0; h < H; h++) {
//loop over Sequence Length
for (int i = 0; i < N; i++) {
//loop over Embedding Dimensionality
for (int j = 0; j < d; j++) {
float val = fourDimRead(Q, b, h, i, j, H, N, d);
val = 0.0;
fourDimWrite(Q, b, h, i, j, H, N, d, val);
}
}
}
}
*/
/* Here is an example of how to read/write 0's to QK_t (N, N) using the 2D accessors
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
float val = twoDimRead(QK_t, i, j, N);
val = 0.0;
twoDimWrite(QK_t, i, j, N, val);
}
}
*/
// -------- YOUR CODE HERE -------- //
for (int b = 0; b < B; b++) {
for (int h = 0; h < H; h++) {
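            // 1) Compute the attention scores S = Q * K^t for this (b, h) slice.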
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
float val = 0.0;
for (int k = 0; k < d; k++) {
float q_val = fourDimRead(Q, b, h, i, k, H, N, d);
float k_val = fourDimRead(K, b, h, j, k, H, N, d);
val += q_val * k_val;
}
twoDimWrite(QK_t, i, j, N, val);
}
}
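            // 2) Apply a row-wise softmax to the N x N score matrix.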
for (int i = 0; i < N; i++) {
float sum = 0.0;
for (int j = 0; j < N; j++) {
float val = twoDimRead(QK_t, i, j, N);
sum += exp(val);
}
for (int j = 0; j < N; j++) {
float val = twoDimRead(QK_t, i, j, N);
float res = (float)(exp(val)/sum);
twoDimWrite(QK_t, i, j, N, res);
}
}
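            // 3) Multiply the attention weights by V and accumulate into O.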
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
float attention_score = twoDimRead(QK_t, i, j, N);
for (int k = 0; k < d; k++) {
float val = attention_score * fourDimRead(V, b, h, j, k, H, N, d);
float O_old_val = fourDimRead(O, b, h, i, k, H, N, d);
float O_new_val = val + O_old_val;
fourDimWrite(O, b, h, i, k, H, N, d, O_new_val);
}
}
}
}
}
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
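// A minimal sketch (not called anywhere in this file) of a numerically more stable
// row-wise softmax: subtracting the row maximum before exponentiating avoids
// overflow for large scores and is mathematically equivalent. The implementations
// in this file exponentiate the raw scores directly, so this helper is illustration only.
static inline void softmaxRowStableSketch(std::vector<float> &QK_t, int row, int N) {
    // Find the largest score in this row (same flat layout as twoDimRead(QK_t, row, j, N)).
    float maxVal = QK_t[row * N];
    for (int j = 1; j < N; j++) {
        if (QK_t[row * N + j] > maxVal) maxVal = QK_t[row * N + j];
    }
    // Exponentiate relative to the maximum and accumulate the normalizer.
    float sum = 0.0;
    for (int j = 0; j < N; j++) {
        sum += exp(QK_t[row * N + j] - maxVal);
    }
    // Normalize the row in place.
    for (int j = 0; j < N; j++) {
        QK_t[row * N + j] = (float)(exp(QK_t[row * N + j] - maxVal) / sum);
    }
}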
// ---------------------------------------------------------- //
// PART 2: BLOCKED MATRIX MULTIPLY AND UNFUSED SOFTMAX //
// ---------------------------------------------------------- //
torch::Tensor myUnfusedAttentionBlocked(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor QK_tTensor,
int B, int H, int N, int d){
// Q, K, V are passed in with Shape: (B, H, N, d)
//QK^t Intermediate Tensor has Shape (N, N)
//Make O Tensor with Shape (B, H, N, d)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
//Format O, Q, K, and V tensors into 4D vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
//Format QK_t Tensor into a 2D vector.
std::vector<float> QK_t = formatTensor(QK_tTensor);
// -------- YOUR CODE HERE -------- //
for (int b = 0; b < B; b++) {
for (int h = 0; h < H; h++) {
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
float val = 0.0;
for (int k = 0; k < d; k++) {
float q_val = fourDimRead(Q, b, h, i, k, H, N, d);
float k_val = fourDimRead(K, b, h, j, k, H, N, d);
val += q_val * k_val;
}
twoDimWrite(QK_t, i, j, N, val);
}
}
for (int i = 0; i < N; i++) {
float sum = 0.0;
for (int j = 0; j < N; j++) {
float val = twoDimRead(QK_t, i, j, N);
sum += exp(val);
}
for (int j = 0; j < N; j++) {
float val = twoDimRead(QK_t, i, j, N);
float res = (float)(exp(val)/sum);
twoDimWrite(QK_t, i, j, N, res);
}
}
for (int i = 0; i < N; i++) {
for (int j = 0; j < N; j++) {
float attention_score = twoDimRead(QK_t, i, j, N);
for (int k = 0; k < d; k++) {
float val = attention_score * fourDimRead(V, b, h, j, k, H, N, d);
float O_old_val = fourDimRead(O, b, h, i, k, H, N, d);
float O_new_val = val + O_old_val;
fourDimWrite(O, b, h, i, k, H, N, d, O_new_val);
}
}
}
}
}
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
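// A minimal sketch (not called anywhere in this file) of how the Q * K^t loop in
// Part 2 could be tiled: each TILE x TILE block of QK_t is finished before moving
// on, so the rows of Q and K it touches stay resident in cache. TILE is a
// hypothetical block size to be tuned, and the bounds below handle N not being a
// multiple of it. This is illustration only, not the code path used above.
static inline void blockedQKtSketch(std::vector<float> &Q, std::vector<float> &K,
                                    std::vector<float> &QK_t, int b, int h,
                                    int H, int N, int d, int TILE) {
    for (int ii = 0; ii < N; ii += TILE) {
        for (int jj = 0; jj < N; jj += TILE) {
            int iEnd = (ii + TILE < N) ? (ii + TILE) : N;
            int jEnd = (jj + TILE < N) ? (jj + TILE) : N;
            for (int i = ii; i < iEnd; i++) {
                for (int j = jj; j < jEnd; j++) {
                    float val = 0.0;
                    for (int k = 0; k < d; k++) {
                        val += fourDimRead(Q, b, h, i, k, H, N, d) *
                               fourDimRead(K, b, h, j, k, H, N, d);
                    }
                    twoDimWrite(QK_t, i, j, N, val);
                }
            }
        }
    }
}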
// ---------------------------------------------------------- //
// PART 3: FUSED ATTENTION //
// ---------------------------------------------------------- //
torch::Tensor myFusedAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor, torch::Tensor temp,
int B, int H, int N, int d){
// Q, K, V are passed in with Shape: (B, H, N, d)
//Make O Tensor with Shape (B, H, N, d)
//and O Row Tensor with Shape (N)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
at::Tensor ORowTensor = at::zeros({N}, at::kFloat);
//Format O, Q, K, and V tensors into 4D vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
//Format ORow Tensor into a 1D vector
// You can simply access this as ORow[i]
std::vector<float> ORow = formatTensor(ORowTensor);
// -------- YOUR CODE HERE -------- //
// We give you a template of the first three loops for your convenience
//loop over batch
#pragma omp parallel for collapse(3)
for (int b = 0; b < B; b++){
//loop over heads
for (int h = 0; h < H; h++){
for (int i = 0; i < N ; i++){
// ORowTensor is created inside the loop so each OpenMP thread gets its own local copy.
at::Tensor ORowTensor = temp.index({torch::indexing::Slice(omp_get_thread_num(), torch::indexing::None)});
std::vector<float> ORow = formatTensor(ORowTensor);
//YOUR CODE HERE
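                // Compute row i of Q * K^t into the thread-local buffer ORow.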
for (int j = 0; j < N; j++) {
float val = 0.0;
for (int k = 0; k < d; k++) {
float q_val = fourDimRead(Q, b, h, i, k, H, N, d);
float k_val = fourDimRead(K, b, h, j, k, H, N, d);
val += q_val * k_val;
}
ORow[j] = val;
}
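                // Softmax the row in place: one pass for the normalizer, one to normalize.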
float sum = 0.0;
for (int j = 0; j < N; j++) {
float val = ORow[j];
sum += exp(val);
}
for (int j = 0; j < N; j++) {
float val = ORow[j];
float res = (float)(exp(val)/sum);
ORow[j] = res;
}
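                // Multiply the attention row by V and accumulate into row i of O.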
for (int j = 0; j < N; j++) {
float attention_score = ORow[j];
for (int k = 0; k < d; k++) {
float val = attention_score * fourDimRead(V, b, h, j, k, H, N, d);
float O_old_val = fourDimRead(O, b, h, i, k, H, N, d);
float O_new_val = val + O_old_val;
fourDimWrite(O, b, h, i, k, H, N, d, O_new_val);
}
}
}
}
}
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
// ---------------------------------------------------------- //
// PART 4: FLASH ATTENTION //
// ---------------------------------------------------------- //
torch::Tensor myFlashAttention(torch::Tensor QTensor, torch::Tensor KTensor, torch::Tensor VTensor,
torch::Tensor QiTensor, torch::Tensor KjTensor, torch::Tensor VjTensor,
torch::Tensor SijTensor, torch::Tensor PijTensor, torch::Tensor PVTensor,
torch::Tensor OiTensor, torch::Tensor LTensor, torch::Tensor LiTensor,
torch::Tensor LijTensor, torch::Tensor LnewTensor, int Bc, int Br,
int B, int H, int N, int d) {
// Q, K, V are passed in with Shape: (B, H, N, d)
// Sij, Pij are passed in with Shape: (Br, Bc)
// Kj, Vj are passed in with Shape: (Bc, d)
// Qi, Oi, and PV are passed in with Shape: (Br, d)
// L is passed in with Shape: (N)
// Li, Lij, and Lnew are passed in with shape (Br)
//Make O Tensor with Shape (B, H, N, d)
at::Tensor OTensor = at::zeros({B, H, N, d}, at::kFloat);
//Format All Tensors into Vectors
std::vector<float> O = formatTensor(OTensor);
std::vector<float> Q = formatTensor(QTensor);
std::vector<float> K = formatTensor(KTensor);
std::vector<float> V = formatTensor(VTensor);
std::vector<float> Sij = formatTensor(SijTensor);
std::vector<float> Pij = formatTensor(PijTensor);
std::vector<float> Kj = formatTensor(KjTensor);
std::vector<float> Vj = formatTensor(VjTensor);
std::vector<float> Qi = formatTensor(QiTensor);
std::vector<float> Oi = formatTensor(OiTensor);
std::vector<float> l = formatTensor(LTensor);
std::vector<float> PV = formatTensor(PVTensor);
std::vector<float> li = formatTensor(LiTensor);
std::vector<float> lij = formatTensor(LijTensor);
std::vector<float> lnew = formatTensor(LnewTensor);
// -------- YOUR CODE HERE -------- //
    // The tile loops below step directly by Br and Bc without bound-checking,
    // so this implementation assumes Br and Bc evenly divide N.
for (int b = 0; b < B; b++) {
        for (int h = 0; h < H; h++) {
            // Reset the running row normalizers for this (b, h) slice; l carries
            // the partial softmax denominators across column blocks for each row.
            for (int k = 0; k < N; k++) {
                l[k] = 0.0;
            }
for (int i = 0; i < N; i += Br) {
for (int k = 0; k < Br; k++) {
for (int j = 0; j < d; j++) {
int idx = i + k;
float q_val = fourDimRead(Q, b, h, idx, j, H, N, d);
twoDimWrite(Qi, k, j, d, q_val);
float o_val = fourDimRead(O, b, h, idx, j, H, N, d);
twoDimWrite(Oi, k, j, d, o_val);
}
}
for (int r = 0; r < N; r += Bc) {
for (int k = 0; k < Bc; k++) {
for (int j = 0; j < d; j++) {
int idx = r + k;
float k_val = fourDimRead(K, b, h, idx, j, H, N, d);
twoDimWrite(Kj, k, j, d, k_val);
float v_val = fourDimRead(V, b, h, idx, j, H, N, d);
twoDimWrite(Vj, k, j, d, v_val);
}
}
for (int k = 0; k < Br; k++) {
li[k] = l[i + k];
}
for (int k = 0; k < Br; k++) {
for (int j = 0; j < Bc; j++) {
float val = 0.0;
for (int p = 0; p < d; p++) {
float q_val = twoDimRead(Qi, k, p, d);
float k_val = twoDimRead(Kj, j, p, d);
val += q_val * k_val;
}
twoDimWrite(Sij, k, j, Bc, val);
}
}
for (int k = 0; k < Br; k++) {
for (int j = 0; j < Bc; j++) {
float val = exp(twoDimRead(Sij, k, j, Bc));
twoDimWrite(Pij, k, j, Bc, val);
}
}
                // Row sums of the exponentiated scores (Pij) for this tile.
                for (int k = 0; k < Br; k++) {
                    float val = 0.0;
                    for (int j = 0; j < Bc; j++) {
                        val += twoDimRead(Pij, k, j, Bc);
                    }
                    lij[k] = val;
                }
for (int k = 0; k < Br; k++) {
lnew[k] = lij[k] + li[k];
}
                // Rescale the running output with the old normalizer and fold in
                // this tile's contribution: Oi = (li * Oi + Pij * Vj) / lnew.
                for (int k = 0; k < Br; k++) {
                    for (int j = 0; j < d; j++) {
                        float pv = 0.0;
                        for (int c = 0; c < Bc; c++) {
                            pv += twoDimRead(Pij, k, c, Bc) * twoDimRead(Vj, c, j, d);
                        }
                        float val = (float)((li[k] * twoDimRead(Oi, k, j, d) + pv) / lnew[k]);
                        twoDimWrite(Oi, k, j, d, val);
                    }
                }
for (int k = 0; k < Br; k++) {
l[i + k] = lnew[k];
}
                // Write the updated output tile back into O; Oi already holds the
                // running (normalized) result for these rows, so no accumulation is needed.
                for (int k = 0; k < Br; k++) {
                    for (int j = 0; j < d; j++) {
                        float val = twoDimRead(Oi, k, j, d);
                        int idx = i + k;
                        fourDimWrite(O, b, h, idx, j, H, N, d, val);
                    }
                }
}
}
}
}
// DO NOT EDIT THIS RETURN STATEMENT //
// It formats your C++ Vector O back into a Tensor of Shape (B, H, N, d) and returns it //
return torch::from_blob(O.data(), {B, H, N, d}, torch::TensorOptions().dtype(torch::kFloat32)).clone();
}
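// A minimal sketch (not called anywhere in this file) of a tile load that bound-checks
// against N, for the case where Br does not evenly divide the sequence length.
// myFlashAttention above assumes exact divisibility, as noted there; the value
// returned here would then serve as the effective tile height for the subsequent
// block computations. Illustration only.
static inline int loadQiTileSketch(std::vector<float> &Q, std::vector<float> &Qi,
                                   int b, int h, int rowStart, int Br,
                                   int H, int N, int d) {
    int rowsLoaded = 0;
    for (int k = 0; k < Br && rowStart + k < N; k++) {
        int idx = rowStart + k;
        for (int j = 0; j < d; j++) {
            float q_val = fourDimRead(Q, b, h, idx, j, H, N, d);
            twoDimWrite(Qi, k, j, d, q_val);
        }
        rowsLoaded++;
    }
    return rowsLoaded;
}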
/* DO NOT EDIT THESE BINDINGS */
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("myNaiveAttention", &myNaiveAttention, "Naive Attention");
m.def("myUnfusedAttentionBlocked", &myUnfusedAttentionBlocked, " Blocked Unfused Attention");
m.def("myFusedAttention", &myFusedAttention, "Fused Attention");
m.def("myFlashAttention", &myFlashAttention, "Flash Attention");
m.def("twoDimRead", &twoDimRead, "twoDimRead");
m.def("fourDimRead", &fourDimRead, "fourDimRead");
}