@@ -54,20 +54,32 @@ struct GlobalToSharedLoaderImpl<Global_, Shared_, BaseShape_, kRowExec_,
54
54
static constexpr int kColExec = kColExec_ ;
55
55
56
56
DEVICE void operator ()(const DType* src, DType* dst) {
57
+ // TODO(KuangjuX): When the `WarpRow` is greater than 1, a swizzle block
58
+ // might be split by two warps, and a solution is needed to address this
59
+ // situation.
57
60
int row = lane_row_id ();
58
61
int col = lane_col_id () * kNumPerAccess ;
59
62
60
- // / the pointer offset inside a warp tile.
61
- int src_lane_offset = src_layout_ (row, col);
62
- int dst_lane_offset = dst_layout_ (row, col);
63
-
64
63
int src_offset = 0 , dst_offset = 0 ;
65
64
#pragma unroll
66
65
for (int i = 0 ; i < kRowExec ; ++i) {
67
66
#pragma unroll
68
67
for (int j = 0 ; j < kColExec ; ++j) {
68
+ int tile_i =
69
+ (i * BaseShape::kRows + row) / SwizzledBaseShape::kRows ;
70
+ int tile_j =
71
+ (j * BaseShape::kCols + col) / SwizzledBaseShape::kCols ;
72
+ int tile_row =
73
+ (i * BaseShape::kRows + row) % SwizzledBaseShape::kRows ;
74
+ int tile_col =
75
+ (j * BaseShape::kCols + col) % SwizzledBaseShape::kCols ;
76
+
77
+ /// the pointer offset inside a warp tile.
78
+ int src_lane_offset = src_tile_ (row, col);
79
+ int dst_tile_offset = dst_tile_ (tile_row, tile_col);
80
+
69
81
src_offset = src_base_tiles_ (i, j) + src_lane_offset;
70
- dst_offset = dst_base_tiles_ (i, j ) + dst_lane_offset ;
82
+ dst_offset = dst_base_tiles_ (tile_i, tile_j ) + dst_tile_offset ;
71
83
72
84
copy (src + src_offset, dst + dst_offset);
73
85
}
@@ -78,33 +90,42 @@ struct GlobalToSharedLoaderImpl<Global_, Shared_, BaseShape_, kRowExec_,
78
90
static constexpr int kNumPerAccess =
79
91
traits::AccessBase<DType>::kNumPerAccess ;
80
92
93
+ using SwizzledBaseShape = traits::SwizzleBaseTileShape<DType>;
94
+ static constexpr int kSwizzledRows = SwizzledBaseShape::kRows ;
95
+ static constexpr int kSwizzledCols = SwizzledBaseShape::kCols ;
96
+
97
+ static constexpr int kSwizzledRowExec =
98
+ kRowExec / (kSwizzledRows / BaseShape::kRows );
99
+ static constexpr int kSwizzledColExec =
100
+ kColExec / (kSwizzledCols / BaseShape::kCols );
101
+
81
102
using SrcBaseTilesLayout =
82
103
tl::MatrixLayout<kRowExec , kColExec ,
83
104
BaseShape::kRows * Global::kRowStride ,
84
105
BaseShape::kCols >;
85
106
SrcBaseTilesLayout src_base_tiles_;
86
107
87
- // a BaseTile is contiguously stored in shared memory
88
- using DstBaseTilesLayout =
89
- tl::MatrixLayout<kRowExec , kColExec ,
90
- BaseShape::kRows * Shared::kRowStride ,
91
- BaseShape::kNumel >;
92
- DstBaseTilesLayout dst_base_tiles_;
108
+ using DstSwizzledLayout =
109
+ tl::MatrixLayout<kSwizzledRowExec , kSwizzledColExec ,
110
+ kSwizzledRows * Shared::kRowStride , kSwizzledCols >;
111
+ DstSwizzledLayout dst_base_tiles_;
93
112
94
113
// Given a thread index, the GlobalLayout and SharedLayout below return the
95
114
// data offset from which the thread should load from the global memory tile
96
115
// and where to store it in the shared memory tile, respectively.
97
116
using GlobalLayout = tl::MatrixLayout<BaseShape::kRows , BaseShape::kCols ,
98
117
Global::kRowStride , 1 >;
99
- GlobalLayout src_layout_;
100
118
101
- using NonSwizzled = tl::RowMajor<BaseShape::kRows , BaseShape::kCols >;
102
- using Swizzled =
103
- tl::SwizzledRowMajor<traits::AccessBase<DType>::kAccessInBits ,
104
- BaseShape>;
119
+ // `src_tile_` is a basetile handled by a single warp.
120
+ GlobalLayout src_tile_;
121
+
122
+ using NonSwizzled =
123
+ tl::MatrixLayout<kSwizzledRows , kSwizzledCols , Shared::kRowStride , 1 >;
124
+ using Swizzled = SwizzledLayout<NonSwizzled, 3 , 3 , 3 >;
125
+
105
126
using SharedLayout =
106
127
std::conditional_t <Shared::kSwizzled , Swizzled, NonSwizzled>;
107
- SharedLayout dst_layout_ ;
128
+ SharedLayout dst_tile_ ;
108
129
109
130
DEVICE void copy (const DType* src, DType* dst) {
110
131
// a single memory access transfers 16 bytes
@@ -233,16 +254,24 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape, kRowExec_,
233
254
int row = lane_row_id ();
234
255
int col = lane_col_id () * kNumPerAccess ;
235
256
236
- // / the pointer offset inside a warp tile.
237
- int src_lane_offset = src_tile_ (row, col);
238
- int dst_lane_offset = dst_tile_ (row, col);
239
-
240
257
int src_offset = 0 , dst_offset = 0 ;
241
258
#pragma unroll
242
259
for (int i = 0 ; i < kRowExec ; ++i) {
243
260
#pragma unroll
244
261
for (int j = 0 ; j < kColExec ; ++j) {
245
- src_offset = src_base_tiles_ (i, j) + src_lane_offset;
262
+ int tile_i =
263
+ (i * BaseShape::kRows + row) / SwizzledBaseShape::kRows ;
264
+ int tile_j =
265
+ (j * BaseShape::kCols + col) / SwizzledBaseShape::kCols ;
266
+ int tile_row =
267
+ (i * BaseShape::kRows + row) % SwizzledBaseShape::kRows ;
268
+ int tile_col =
269
+ (j * BaseShape::kCols + col) % SwizzledBaseShape::kCols ;
270
+
271
+ int src_tile_offset = src_tile_ (tile_row, tile_col);
272
+ int dst_lane_offset = dst_tile_ (row, col);
273
+
274
+ src_offset = src_base_tiles_ (tile_i, tile_j) + src_tile_offset;
246
275
dst_offset = dst_base_tiles_ (i, j) + dst_lane_offset;
247
276
248
277
copy (src + src_offset, dst + dst_offset);
@@ -251,12 +280,19 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape, kRowExec_,
251
280
}
252
281
253
282
private:
254
- // a SharedTile is contiguously stored
255
- using SrcBaseTilesLayout =
256
- tl::MatrixLayout<kRowExec , kColExec ,
257
- BaseShape::kRows * Shared::kRowStride ,
258
- BaseShape::kNumel >;
259
- SrcBaseTilesLayout src_base_tiles_;
283
+ using SwizzledBaseShape = traits::SwizzleBaseTileShape<DType>;
284
+ static constexpr int kSwizzledRows = SwizzledBaseShape::kRows ;
285
+ static constexpr int kSwizzledCols = SwizzledBaseShape::kCols ;
286
+
287
+ static constexpr int kSwizzledRowExec =
288
+ kRowExec / (kSwizzledRows / BaseShape::kRows );
289
+ static constexpr int kSwizzledColExec =
290
+ kColExec / (kSwizzledCols / BaseShape::kCols );
291
+
292
+ using SrcSwizzledLayout =
293
+ tl::MatrixLayout<kSwizzledRowExec , kSwizzledColExec ,
294
+ kSwizzledRows * Shared::kRowStride , kSwizzledCols >;
295
+ SrcSwizzledLayout src_base_tiles_;
260
296
261
297
using DstBaseTilesLayout =
262
298
tl::MatrixLayout<kRowExec , kColExec ,
@@ -273,15 +309,15 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape, kRowExec_,
273
309
static constexpr int kNumPerAccess =
274
310
traits::AccessBase<DType>::kNumPerAccess ;
275
311
276
- using NonSwizzled = tl::RowMajor<BaseShape::kRows , BaseShape::kCols >;
277
- using Swizzled = tl::SwizzledRowMajor<kAccessInBits , BaseShape>;
312
+ using NonSwizzled =
313
+ tl::MatrixLayout<kSwizzledRows , kSwizzledCols , Shared::kRowStride , 1 >;
314
+ using Swizzled = SwizzledLayout<NonSwizzled, 3 , 3 , 3 >;
278
315
using SharedLayout =
279
316
std::conditional_t <Shared::kSwizzled , Swizzled, NonSwizzled>;
280
317
SharedLayout src_tile_;
281
318
282
- using GlobalLayout =
283
- tl::MatrixLayout<BaseShape::kRows , BaseShape::kCols , Global::kRowStride ,
284
- Global::kColStride >;
319
+ using GlobalLayout = tl::MatrixLayout<BaseShape::kRows , BaseShape::kCols ,
320
+ Global::kRowStride , 1 >;
285
321
GlobalLayout dst_tile_;
286
322
287
323
/// @brief returns the lane col of the current thread within a warp.
@@ -364,8 +400,14 @@ struct GlobalToSharedLoader {
364
400
// warp-level tile shape instead of using a fixed 16x16 `BaseShape`. using
365
401
// WarpShape =
366
402
// warp::WarpTileShape<DType, typename Shared::Layout, Shared::kType>;
367
- using WarpShape =
368
- warp::WarpTileShape<DType, tl::RowMajor<16 , 16 >, Shared::kType >;
403
+ // using WarpShape =
404
+ // warp::WarpTileShape<DType, tl::RowMajor<16, 16>, Shared::kType>;
405
+
406
+ // KuangjuX: Use `4x64` in RowMajor and `64x4` in ColMajor.
407
+ static constexpr bool kRowMajor = Shared::kType == tl::Layout::kRowMajor ;
408
+ using BaseTile =
409
+ std::conditional_t <kRowMajor , tl::RowMajor<4 , 64 >, tl::ColMajor<64 , 4 >>;
410
+ using WarpShape = warp::WarpTileShape<DType, BaseTile, Shared::kType >;
369
411
370
412
static_assert (Shared::kRows % WarpShape::kRows == 0 ,
371
413
" Shared::kRows must be divisible by WarpShape::kRows." );
@@ -394,6 +436,7 @@ struct GlobalToSharedLoader {
394
436
const DType* src_ptr = src.data ();
395
437
DType* dst_ptr = dst.mutable_data ();
396
438
439
+ // get warp offset for global and shared memory
397
440
int offset_src = global_offset_.template get_warp_offset <Global>();
398
441
int offset_dst = shared_offset_.get_warp_offset ();
399
442
@@ -423,8 +466,15 @@ struct SharedToGlobalStorer {
423
466
424
467
// FIXME(ying): uncomment the following lines to automatically infer the
425
468
// warp-level tile shape instead of using a fixed 16x16 `BaseShape`.
426
- using BaseShape =
427
- warp::WarpTileShape<DType, tl::RowMajor<16 , 16 >, Shared::kType >;
469
+ // using BaseShape =
470
+ // warp::WarpTileShape<DType, tl::RowMajor<16, 16>, Shared::kType>;
471
+
472
+ // KuangjuX: Use `4x64` in RowMajor and `64x4` in ColMajor.
473
+
474
+ static constexpr bool kRowMajor = Shared::kType == tl::Layout::kRowMajor ;
475
+ using BaseTile =
476
+ std::conditional_t <kRowMajor , tl::RowMajor<4 , 64 >, tl::ColMajor<64 , 4 >>;
477
+ using BaseShape = warp::WarpTileShape<DType, BaseTile, Shared::kType >;
428
478
429
479
static_assert (Shared::kRows % BaseShape::kRows == 0 ,
430
480
" Shared::kRows must be divisible by BaseShape::kRows." );
@@ -433,9 +483,9 @@ struct SharedToGlobalStorer {
433
483
434
484
static const WarpReuse kMode = WarpReuse::kCont ; // warp reuse mode
435
485
486
+ using GlobalOffset = warp::GlobalOffsetHelper<WarpLayout, kMode >;
436
487
using SharedOffset =
437
488
warp::SharedOffsetHelper<WarpLayout, BaseShape, Shared, kMode >;
438
- using GlobalOffset = warp::GlobalOffsetHelper<WarpLayout, kMode >;
439
489
440
490
using ExecCounter = warp::ExecCounter<BaseShape, Shared, WarpLayout, kMode >;
441
491
0 commit comments