@@ -70,11 +70,11 @@ __global__ void run_test_load(Copy& copy) {
70
70
71
71
copy (s_tile, r_tile);
72
72
73
- #if defined(DEBUG)
73
+ // #if defined(DEBUG)
74
74
if (thread0 ()) {
75
75
r_tile.dump_value ();
76
76
}
77
- #endif
77
+ // #endif
78
78
}
79
79
80
80
template <typename Shared, typename Reg, typename Loader, typename Storer>
@@ -180,71 +180,79 @@ TEST(TestShared2Reg, operand_A) { // load mode for loading operand A in gemm
180
180
cudaDeviceSynchronize ();
181
181
}
182
182
183
- TEST (TestShared2Reg, operand_B) { // load mode for loading operand B in gemm
184
- using Element = __half;
183
+ // TEST(TestShared2Reg, operand_B) { // load mode for loading operand B in gemm
184
+ // using Element = __half;
185
185
186
- using WarpLayout = tl::RowMajor<2 , 2 >;
187
- const int kThreads = tl::get_numel<WarpLayout> * 32 ;
186
+ // using WarpLayout = tl::RowMajor<2, 2>;
187
+ // const int kThreads = tl::get_numel<WarpLayout> * 32;
188
188
189
- // a 32x64 row-major shared tile is equivalent to a 64x32 col-major tile
190
- using Shared = SharedTile<Element, tl::RowMajor<32 , 64 >>;
189
+ // // a 32x64 row-major shared tile is equivalent to a 64x32 col-major tile
190
+ // using Shared = SharedTile<Element, tl::RowMajor<32, 64>>;
191
191
192
- // Each thread accesses 4x2 elements (the shape of `BaseHalfTileRowMajor`)
193
- // within a 16x16 `BaseTile`. These 4x2 elements are accessed 2x2 times
194
- // along each dimension, contributing to the final register tile handled by
195
- // a single thread.
196
- using Reg = RegTile<BaseTileColMajor<Element>, tl::ColMajor<2 , 2 >>;
197
- // In the `ColReuseCont` mode, warps in the same column repeatedly access
198
- // the same data.
199
- using Copy = SharedToRegLoader<Reg, WarpLayout, WarpReuse::kColReuseCont >;
200
- Copy copy;
192
+ // // Each thread accesses 4x2 elements (the shape of
193
+ // `BaseHalfTileRowMajor`)
194
+ // // within a 16x16 `BaseTile`. These 4x2 elements are accessed 2x2 times
195
+ // // along each dimension, contributing to the final register tile handled
196
+ // by
197
+ // // a single thread.
198
+ // using Reg = RegTile<BaseTileColMajor<Element>, tl::ColMajor<2, 2>>;
199
+ // // In the `ColReuseCont` mode, warps in the same column repeatedly access
200
+ // // the same data.
201
+ // using Copy = SharedToRegLoader<Reg, WarpLayout,
202
+ // WarpReuse::kColReuseCont>; Copy copy;
201
203
202
- dim3 dim_grid (1 , 1 , 1 );
203
- dim3 dim_block (kThreads , 1 , 1 );
204
- int shm_size = Shared::kNumel * sizeof (Element);
204
+ // dim3 dim_grid(1, 1, 1);
205
+ // dim3 dim_block(kThreads, 1, 1);
206
+ // int shm_size = Shared::kNumel * sizeof(Element);
205
207
206
- run_test_load<Element, Shared, Reg, Copy>
207
- <<<dim_grid, dim_block, shm_size>>> (copy);
208
- cudaDeviceSynchronize ();
209
- }
208
+ // run_test_load<Element, Shared, Reg, Copy>
209
+ // <<<dim_grid, dim_block, shm_size>>>(copy);
210
+ // cudaDeviceSynchronize();
211
+ // }
210
212
211
- TEST (TestReg2Shared, operand_C_half) {
212
- using Element = __half;
213
+ // TEST(TestReg2Shared, operand_C_half) {
214
+ // using Element = __half;
213
215
214
- using WarpLayout = tl::RowMajor<1 , 1 >;
215
- const int kThreads = tl::get_numel<WarpLayout> * 32 ;
216
+ // using WarpLayout = tl::RowMajor<1, 1>;
217
+ // const int kThreads = tl::get_numel<WarpLayout> * 32;
216
218
217
- using Shared = SharedTile<Element, tl::RowMajor<16 , 16 >>;
218
- using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<1 , 1 >>;
219
+ // // using Shared = SharedTile<Element, tl::RowMajor<16, 16>>;
220
+ // using Shared = SharedTile<Element, tl::RowMajor<16, 64>>;
221
+ // // using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<1, 1>>;
222
+ // using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<1, 4>>;
219
223
220
- using Loader = SharedToRegLoader<Reg, WarpLayout, WarpReuse::kCont >;
221
- Loader loader;
224
+ // using Loader = SharedToRegLoader<Reg, WarpLayout, WarpReuse::kCont>;
225
+ // Loader loader;
222
226
223
- using Storer = RegToSharedStorer<Reg, WarpLayout>;
224
- Storer storer;
227
+ // using Storer = RegToSharedStorer<Reg, WarpLayout>;
228
+ // Storer storer;
225
229
226
- dim3 dim_grid (1 , 1 , 1 );
227
- dim3 dim_block (kThreads , 1 , 1 );
228
- int shm_size = Shared::kNumel * sizeof (Element);
230
+ // dim3 dim_grid(1, 1, 1);
231
+ // dim3 dim_block(kThreads, 1, 1);
232
+ // int shm_size = Shared::kNumel * sizeof(Element);
229
233
230
- run_test_store<Shared, Reg, Loader, Storer>
231
- <<<dim_grid, dim_block, shm_size>>> (loader, storer);
232
- cudaDeviceSynchronize ();
233
- }
234
+ // run_test_store<Shared, Reg, Loader, Storer>
235
+ // <<<dim_grid, dim_block, shm_size>>>(loader, storer);
236
+ // cudaDeviceSynchronize();
237
+ // }
234
238
235
239
TEST (TestShared2Reg, operand_A_swizzle) {
236
240
using Element = __half;
237
241
238
242
using WarpLayout = tl::RowMajor<1 , 1 >;
239
243
const int kThreads = tl::get_numel<WarpLayout> * 32 ;
240
244
241
- const int kRows = 64 ;
242
- const int kCols = 32 ;
245
+ // const int kRows = 64;
246
+ // const int kCols = 32;
247
+
248
+ const int kRows = 16 ;
249
+ const int kCols = 64 ;
243
250
244
251
using SharedLayout = tl::RowMajor<kRows , kCols >;
245
252
const bool kUseSwizzledLayout = true ;
246
253
using Shared = SharedTile<Element, SharedLayout, kUseSwizzledLayout >;
247
- using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<2 , 2 >>;
254
+ // using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<2, 2>>;
255
+ using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<1 , 4 >>;
248
256
249
257
using Copy = SharedToRegLoader<Reg, WarpLayout, WarpReuse::kRowReuseCont >;
250
258
Copy copy;
@@ -258,48 +266,48 @@ TEST(TestShared2Reg, operand_A_swizzle) {
258
266
cudaDeviceSynchronize ();
259
267
}
260
268
261
- TEST (TestReg2Shared, operand_C_float) {
262
- using Element = __half;
263
- using AccType = float ;
264
-
265
- const int kRowRepeats = 4 ;
266
- const int kColRepeats = 8 ;
267
- const int kRows = 16 * kRowRepeats ;
268
- const int kCols = 16 * kColRepeats ;
269
-
270
- const int kWarpPerRow = 2 ;
271
- const int kWarpPerCol = 2 ;
272
- using WarpLayout = tl::RowMajor<kWarpPerRow , kWarpPerCol >;
273
- const int kThreads = tl::get_numel<WarpLayout> * 32 ;
274
-
275
- using SharedHalf = SharedTile<Element, tl::RowMajor<kRows , kCols >>;
276
- using RegHalf = RegTile<
277
- BaseTileRowMajor<Element>,
278
- tl::RowMajor<kRowRepeats / kWarpPerRow , kColRepeats / kWarpPerCol >>;
279
-
280
- using SharedFloat = SharedTile<AccType, tl::RowMajor<kRows , kCols >>;
281
- using RegFloat = RegTile<
282
- BaseTileRowMajor<AccType>,
283
- tl::RowMajor<kRowRepeats / kWarpPerRow , kColRepeats / kWarpPerCol >>;
284
-
285
- using ConvertHalf = compute::RegTileConvert<RegHalf, RegFloat>;
286
- ConvertHalf convert;
287
-
288
- using Loader = SharedToRegLoader<RegHalf, WarpLayout, WarpReuse::kCont >;
289
- Loader loader;
290
-
291
- using Storer = RegToSharedStorer<RegFloat, WarpLayout>;
292
- Storer storer;
293
-
294
- dim3 dim_grid (1 , 1 , 1 );
295
- dim3 dim_block (kThreads , 1 , 1 );
296
- int shm_size = SharedHalf::kNumel * sizeof (Element) +
297
- SharedFloat::kNumel * sizeof (AccType);
298
-
299
- run_test_store_float<SharedHalf, RegHalf, SharedFloat, RegFloat,
300
- ConvertHalf, Loader, Storer>
301
- <<<dim_grid, dim_block, shm_size>>> (convert, loader, storer);
302
- cudaDeviceSynchronize ();
303
- }
269
+ // TEST(TestReg2Shared, operand_C_float) {
270
+ // using Element = __half;
271
+ // using AccType = float;
272
+
273
+ // const int kRowRepeats = 4;
274
+ // const int kColRepeats = 8;
275
+ // const int kRows = 16 * kRowRepeats;
276
+ // const int kCols = 16 * kColRepeats;
277
+
278
+ // const int kWarpPerRow = 2;
279
+ // const int kWarpPerCol = 2;
280
+ // using WarpLayout = tl::RowMajor<kWarpPerRow, kWarpPerCol>;
281
+ // const int kThreads = tl::get_numel<WarpLayout> * 32;
282
+
283
+ // using SharedHalf = SharedTile<Element, tl::RowMajor<kRows, kCols>>;
284
+ // using RegHalf = RegTile<
285
+ // BaseTileRowMajor<Element>,
286
+ // tl::RowMajor<kRowRepeats / kWarpPerRow, kColRepeats / kWarpPerCol>>;
287
+
288
+ // using SharedFloat = SharedTile<AccType, tl::RowMajor<kRows, kCols>>;
289
+ // using RegFloat = RegTile<
290
+ // BaseTileRowMajor<AccType>,
291
+ // tl::RowMajor<kRowRepeats / kWarpPerRow, kColRepeats / kWarpPerCol>>;
292
+
293
+ // using ConvertHalf = compute::RegTileConvert<RegHalf, RegFloat>;
294
+ // ConvertHalf convert;
295
+
296
+ // using Loader = SharedToRegLoader<RegHalf, WarpLayout, WarpReuse::kCont>;
297
+ // Loader loader;
298
+
299
+ // using Storer = RegToSharedStorer<RegFloat, WarpLayout>;
300
+ // Storer storer;
301
+
302
+ // dim3 dim_grid(1, 1, 1);
303
+ // dim3 dim_block(kThreads, 1, 1);
304
+ // int shm_size = SharedHalf::kNumel * sizeof(Element) +
305
+ // SharedFloat::kNumel * sizeof(AccType);
306
+
307
+ // run_test_store_float<SharedHalf, RegHalf, SharedFloat, RegFloat,
308
+ // ConvertHalf, Loader, Storer>
309
+ // <<<dim_grid, dim_block, shm_size>>>(convert, loader, storer);
310
+ // cudaDeviceSynchronize();
311
+ // }
304
312
305
313
} // namespace tilefusion::testing
0 commit comments