-
Notifications
You must be signed in to change notification settings - Fork 0
/
simpleTensorCoreGEMM.cu
310 lines (244 loc) · 11.5 KB
/
simpleTensorCoreGEMM.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
// Learn how to use CUDA's WMMA API to perform Tensor Core matrix multiplication.
/* Copyright (c) 1993-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of NVIDIA CORPORATION nor the names of its
* contributors may be used to endorse or promote products derived
* from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
* EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
* PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
* CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
* EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
* PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
* OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
#include <stdio.h>
#include <curand.h>
#include <cublas_v2.h>
// Define some error checking macros.
// cudaErrCheck(expr): evaluates a CUDA runtime call exactly once and, if it did
// not return cudaSuccess, prints the error string and call site to stderr.
// Execution continues after the report (the program is not aborted).
// The do { } while (0) wrapper makes the macro a single statement, so it is
// safe to use in an unbraced if/else.
#define cudaErrCheck(stat) do { cudaErrCheck_((stat), __FILE__, __LINE__); } while (0)
void cudaErrCheck_(cudaError_t stat, const char *file, int line) {
    if (stat != cudaSuccess) {
        fprintf(stderr, "CUDA Error: %s %s %d\n", cudaGetErrorString(stat), file, line);
    }
}
// cublasErrCheck(expr): evaluates a cuBLAS call exactly once and, if it did not
// return CUBLAS_STATUS_SUCCESS, prints the numeric status and call site to
// stderr. Execution continues after the report.
// The do { } while (0) wrapper makes the macro a single statement, so it is
// safe to use in an unbraced if/else.
#define cublasErrCheck(stat) do { cublasErrCheck_((stat), __FILE__, __LINE__); } while (0)
void cublasErrCheck_(cublasStatus_t stat, const char *file, int line) {
    if (stat != CUBLAS_STATUS_SUCCESS) {
        fprintf(stderr, "cuBLAS Error: %d %s %d\n", stat, file, line);
    }
}
// curandErrCheck(expr): evaluates a cuRAND call exactly once and, if it did not
// return CURAND_STATUS_SUCCESS, prints the numeric status and call site to
// stderr. Execution continues after the report.
// The do { } while (0) wrapper makes the macro a single statement, so it is
// safe to use in an unbraced if/else.
#define curandErrCheck(stat) do { curandErrCheck_((stat), __FILE__, __LINE__); } while (0)
void curandErrCheck_(curandStatus_t stat, const char *file, int line) {
    if (stat != CURAND_STATUS_SUCCESS) {
        fprintf(stderr, "cuRand Error: %d %s %d\n", stat, file, line);
    }
}
#include <mma.h> // define tensorcores' API
using namespace nvcuda;
// Problem size (C is MATRIX_M x MATRIX_N, the inner dimension is MATRIX_K).
// Each must be a multiple of the matching WMMA tile dimension below, because
// the WMMA kernel only loads and stores full tiles; the static_asserts enforce
// this at compile time.
#define MATRIX_M 16384
#define MATRIX_N 16384
#define MATRIX_K 16384
// WMMA tile shape used for every fragment in this demo. 16x16x16 is one of the
// shapes WMMA supports for half inputs with float accumulation (32x8x16 and
// 8x32x16 also exist but are not used here).
const int WMMA_M = 16;
const int WMMA_N = 16;
const int WMMA_K = 16;
static_assert(MATRIX_M % WMMA_M == 0, "MATRIX_M must be a multiple of WMMA_M");
static_assert(MATRIX_N % WMMA_N == 0, "MATRIX_N must be a multiple of WMMA_N");
static_assert(MATRIX_K % WMMA_K == 0, "MATRIX_K must be a multiple of WMMA_K");
// Performs an MxNxK GEMM (C = alpha*A*B + beta*C) with the WMMA API, assuming:
// 1) Matrices are packed column-major in memory (lda = M, ldb = K, ldc = M).
// 2) M, N and K are multiples of 16 (the WMMA tile size used here).
// 3) Neither A nor B is transposed.
//
// Launch layout: each warp computes one 16x16 tile of C. The warp's tile row is
// (global x thread index) / warpSize and its tile column is the global y thread
// index, so blockDim.x must be a multiple of warpSize.
// If beta == 0, C is not read (same convention as cuBLAS), so its prior
// contents (including NaNs) do not affect the result.
// Requires compute capability 7.0+ (WMMA with half inputs).
//
// Note: This is NOT a high performance example but is for demonstration purposes only.
// For high performance code please use the GEMM provided in cuBLAS.
__global__ void wmma_example(half *a, half *b, float *c, int M, int N, int K, float alpha, float beta) {
    // Leading dimensions. Packed with no transpositions.
    const int lda = M;
    const int ldb = K;
    const int ldc = M;

    // Tile using a 2D grid. With blockDim.x a multiple of warpSize, all 32 lanes
    // of a warp compute the same (warpM, warpN), so every branch below is
    // warp-uniform, as the *_sync WMMA calls require.
    const int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / warpSize;
    const int warpN = (blockIdx.y * blockDim.y + threadIdx.y);

    const int cRow = warpM * WMMA_M;
    const int cCol = warpN * WMMA_N;
    // Warps whose output tile lies outside C have nothing to do.
    if (cRow >= M || cCol >= N) return;

    // Declare the fragments
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> acc_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, float> c_frag;

    wmma::fill_fragment(acc_frag, 0.0f);

    // Loop over k: acc += A[cRow:cRow+16, k:k+16] * B[k:k+16, cCol:cCol+16].
    // Column offsets are computed in size_t so (column * leading dimension)
    // cannot overflow int for large matrices.
    for (int k = 0; k < K; k += WMMA_K) {
        const half *aTile = a + cRow + (size_t)k * lda;
        const half *bTile = b + k + (size_t)cCol * ldb;
        wmma::load_matrix_sync(a_frag, aTile, lda);
        wmma::load_matrix_sync(b_frag, bTile, ldb);
        wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag);
    }

    float *cTile = c + cRow + (size_t)cCol * ldc;
    // acc_frag and c_frag have the same fragment type, so element i of one
    // corresponds to element i of the other and can be combined element-wise.
    if (beta != 0.0f) {
        // Load the current value of C, scale it by beta, and add alpha * (A*B).
        wmma::load_matrix_sync(c_frag, cTile, ldc, wmma::mem_col_major);
#pragma unroll
        for (int i = 0; i < c_frag.num_elements; i++) {
            c_frag.x[i] = alpha * acc_frag.x[i] + beta * c_frag.x[i];
        }
    } else {
        // beta == 0: C is write-only.
#pragma unroll
        for (int i = 0; i < c_frag.num_elements; i++) {
            c_frag.x[i] = alpha * acc_frag.x[i];
        }
    }

    // Store the output
    wmma::store_matrix_sync(cTile, c_frag, ldc, wmma::mem_col_major);
}
// Narrows the first n floats of `in` to half precision, writing them to `out`.
// Expects a 1-D launch with at least n threads in total. Each thread handles
// one element, and threads whose index is n or greater do nothing.
__global__ void convertFp32ToFp16 (half *out, float *in, int n) {
    const int element = threadIdx.x + blockIdx.x * blockDim.x;
    if (element >= n) return;   // tail threads of the last block
    out[element] = __float2half(in[element]);
}
// Counts the elements of `test` that differ from `ref` by more than
// atol + rtol * |ref|, printing (test, ref) for the first few mismatches.
// The check is written as !(diff <= bound) so that a NaN in either buffer is
// counted as a mismatch instead of silently passing.
static size_t countMismatches(const float *test, const float *ref, size_t n, float rtol, float atol) {
    const size_t maxPrinted = 9;
    size_t mismatches = 0;
    for (size_t i = 0; i < n; i++) {
        float diff = fabs(test[i] - ref[i]);
        float bound = atol + rtol * fabs(ref[i]);
        if (!(diff <= bound)) {
            mismatches++;
            if (mismatches <= maxPrinted) printf("%f %f\n", test[i], ref[i]);
        }
    }
    return mismatches;
}

// Runs the naive WMMA GEMM kernel and cuBLAS (tensor-op math) on the same
// random fp16 A and B with fp32 C. It then checks that the two results agree
// and, if they do, prints each one's elapsed time.
int main(int argc, char* argv[]) {
    float *a_fp32;
    float *b_fp32;
    half *a_fp16;
    half *b_fp16;
    float *c;
    float *c_cublas;
    float *c_wmma;
    float *c_host_cublas;
    float *c_host_wmma;
    curandGenerator_t gen;
    cublasHandle_t cublasHandle;
    cudaEvent_t startWMMA;
    cudaEvent_t stopWMMA;
    cudaEvent_t startcublas;
    cudaEvent_t stopcublas;

    cudaErrCheck(cudaEventCreate(&startWMMA));
    cudaErrCheck(cudaEventCreate(&stopWMMA));
    cudaErrCheck(cudaEventCreate(&startcublas));
    cudaErrCheck(cudaEventCreate(&stopcublas));

    cublasErrCheck(cublasCreate(&cublasHandle));
    // Use tensor cores
    cublasErrCheck(cublasSetMathMode(cublasHandle, CUBLAS_TENSOR_OP_MATH));

    cudaErrCheck(cudaMalloc((void**)&a_fp32, MATRIX_M * MATRIX_K * sizeof(float)));
    cudaErrCheck(cudaMalloc((void**)&b_fp32, MATRIX_K * MATRIX_N * sizeof(float)));
    cudaErrCheck(cudaMalloc((void**)&a_fp16, MATRIX_M * MATRIX_K * sizeof(half)));
    cudaErrCheck(cudaMalloc((void**)&b_fp16, MATRIX_K * MATRIX_N * sizeof(half)));
    cudaErrCheck(cudaMalloc((void**)&c, MATRIX_M * MATRIX_N * sizeof(float)));
    cudaErrCheck(cudaMalloc((void**)&c_cublas, MATRIX_M * MATRIX_N * sizeof(float)));
    cudaErrCheck(cudaMalloc((void**)&c_wmma, MATRIX_M * MATRIX_N * sizeof(float)));

    // Each host buffer holds MATRIX_M * MATRIX_N floats (1 GiB at the default
    // sizes), so allocation failure is realistic; stop instead of writing
    // through a NULL pointer later.
    c_host_cublas = (float*)malloc(MATRIX_M * MATRIX_N * sizeof(float));
    c_host_wmma = (float*)malloc(MATRIX_M * MATRIX_N * sizeof(float));
    if (c_host_cublas == NULL || c_host_wmma == NULL) {
        fprintf(stderr, "Host allocation failed %s %d\n", __FILE__, __LINE__);
        free(c_host_cublas);
        free(c_host_wmma);
        return 1;
    }

    curandErrCheck(curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT));
    curandErrCheck(curandSetPseudoRandomGeneratorSeed(gen, 1337ULL));
    curandErrCheck(curandGenerateUniform(gen, a_fp32, MATRIX_M * MATRIX_K));
    curandErrCheck(curandGenerateUniform(gen, b_fp32, MATRIX_K * MATRIX_N));

    // curand doesn't currently support fp16 so we generate in fp32 and convert to fp16.
    // Kernel launches return no status; launch errors are read via cudaGetLastError().
    convertFp32ToFp16 <<< (MATRIX_M * MATRIX_K + 255) / 256, 256 >>> (a_fp16, a_fp32, MATRIX_M * MATRIX_K);
    cudaErrCheck(cudaGetLastError());
    convertFp32ToFp16 <<< (MATRIX_K * MATRIX_N + 255) / 256, 256 >>> (b_fp16, b_fp32, MATRIX_K * MATRIX_N);
    cudaErrCheck(cudaGetLastError());

    curandErrCheck(curandGenerateUniform(gen, c, MATRIX_M * MATRIX_N));
    curandErrCheck(curandDestroyGenerator(gen));

    cudaErrCheck(cudaMemcpy(c_cublas, c, MATRIX_M * MATRIX_N * sizeof(float), cudaMemcpyDeviceToDevice));
    cudaErrCheck(cudaMemcpy(c_wmma, c, MATRIX_M * MATRIX_N * sizeof(float), cudaMemcpyDeviceToDevice));

    float alpha = 1.0f;
    float beta = 0.0f;

    printf("\nM = %d, N = %d, K = %d. alpha = %f, beta = %f\n\n", MATRIX_M, MATRIX_N, MATRIX_K, alpha, beta);

    // First: using WMMA
    // (Named grid/block rather than gridDim/blockDim to avoid reusing the
    // CUDA built-in variable names in host code.)
    dim3 grid;
    dim3 block;
    // block.x must be a multiple of warpSize
    // 128x4 means we have 16 warps and a block computes a 64x64 output tile
    block.x = 128;
    block.y = 4;
    grid.x = (MATRIX_M + (WMMA_M * block.x / 32 - 1)) / (WMMA_M * block.x / 32);
    grid.y = (MATRIX_N + WMMA_N * block.y - 1) / (WMMA_N * block.y);

    printf("Running with wmma...\n");
    // NOTE: unlike the cuBLAS path below there is no warm-up run here, so this
    // timing may include first-launch overhead.
    cudaErrCheck(cudaEventRecord(startWMMA));
    wmma_example <<< grid, block >>> (a_fp16, b_fp16, c_wmma, MATRIX_M, MATRIX_N, MATRIX_K, alpha, beta);
    cudaErrCheck(cudaGetLastError());   // launch-configuration errors
    cudaErrCheck(cudaEventRecord(stopWMMA));
    cudaErrCheck(cudaEventSynchronize(stopWMMA));   // surfaces in-kernel faults

    // Now using cuBLAS
    printf("Running with cuBLAS...\n");
    // Warm up cuBLAS run starts
    cublasErrCheck(cublasGemmEx(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N,
                                MATRIX_M, MATRIX_N, MATRIX_K,
                                &alpha,
                                a_fp16, CUDA_R_16F, MATRIX_M,
                                b_fp16, CUDA_R_16F, MATRIX_K,
                                &beta,
                                c_cublas, CUDA_R_32F, MATRIX_M,
                                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    // Warm up cuBLAS run ends
    // reset the c_cublas buffer
    cudaErrCheck(cudaMemcpy(c_cublas, c, MATRIX_M * MATRIX_N * sizeof(float), cudaMemcpyDeviceToDevice));

    cudaErrCheck(cudaEventRecord(startcublas));
    cublasErrCheck(cublasGemmEx(cublasHandle, CUBLAS_OP_N, CUBLAS_OP_N,
                                MATRIX_M, MATRIX_N, MATRIX_K,
                                &alpha,
                                a_fp16, CUDA_R_16F, MATRIX_M,
                                b_fp16, CUDA_R_16F, MATRIX_K,
                                &beta,
                                c_cublas, CUDA_R_32F, MATRIX_M,
                                CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
    cudaErrCheck(cudaEventRecord(stopcublas));
    cudaErrCheck(cudaEventSynchronize(stopcublas));

    // Error checking
    printf("\nChecking results...\n");
    cudaErrCheck(cudaMemcpy(c_host_wmma, c_wmma, MATRIX_M * MATRIX_N * sizeof(float), cudaMemcpyDeviceToHost));
    cudaErrCheck(cudaMemcpy(c_host_cublas, c_cublas, MATRIX_M * MATRIX_N * sizeof(float), cudaMemcpyDeviceToHost));

    // 0.01% relative tolerance plus 1e-5 absolute tolerance.
    size_t errors = countMismatches(c_host_wmma, c_host_cublas,
                                    (size_t)MATRIX_M * MATRIX_N, 1e-4f, 1e-5f);
    if (errors > 0) {
        printf("WMMA does not agree with cuBLAS! %zu errors!\n", errors);
    }
    else {
        printf("Results verified: cublas and WMMA agree.\n\n");
        float wmmaTime;
        float cublasTime;
        cudaErrCheck(cudaEventElapsedTime(&wmmaTime, startWMMA, stopWMMA));
        cudaErrCheck(cudaEventElapsedTime(&cublasTime, startcublas, stopcublas));
        printf("wmma took %fms\n", wmmaTime);
        printf("cublas took %fms\n", cublasTime);
        printf("\nFor a faster code using wmma you should check out the cudaTensorCoreGemm sample in the CUDA Toolkit.\nThis code was written as a demo only!\n\n");
    }

    cudaErrCheck(cudaEventDestroy(startWMMA));
    cudaErrCheck(cudaEventDestroy(stopWMMA));
    cudaErrCheck(cudaEventDestroy(startcublas));
    cudaErrCheck(cudaEventDestroy(stopcublas));

    // Release the cuBLAS handle created above (previously leaked).
    cublasErrCheck(cublasDestroy(cublasHandle));

    cudaErrCheck(cudaFree(a_fp32));
    cudaErrCheck(cudaFree(b_fp32));
    cudaErrCheck(cudaFree(a_fp16));
    cudaErrCheck(cudaFree(b_fp16));
    cudaErrCheck(cudaFree(c));
    cudaErrCheck(cudaFree(c_cublas));
    cudaErrCheck(cudaFree(c_wmma));

    free(c_host_cublas);
    free(c_host_wmma);

    cudaErrCheck(cudaDeviceReset());
    return 0;
}