Skip to content

Commit 5bb8f7b

Browse files
authored
Refactor(cell): GMEM to SMEM Loader/Storer with SwizzleLayout in RowMajor (microsoft#44)
* Add Swizzle<3,3,3> Tile Shape. * reimplement a naive global to shared loader. * Update Global to Shared Load. * fix g2s load. * Refactor Shared to Global Storer. * hotfix for SharedOffsetHelper. * pass more tests. * pass more testcases. * chore: undo test_utils. * Add float swizzle basetile.
1 parent 953b6e0 commit 5bb8f7b

File tree

8 files changed

+316
-147
lines changed

8 files changed

+316
-147
lines changed

.vscode/settings.json

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,66 @@
1313
"limits": "cpp",
1414
"semaphore": "cpp",
1515
"regex": "cpp",
16-
"tuple": "cpp"
16+
"tuple": "cpp",
17+
"cctype": "cpp",
18+
"clocale": "cpp",
19+
"cmath": "cpp",
20+
"csignal": "cpp",
21+
"cstdarg": "cpp",
22+
"cstddef": "cpp",
23+
"cstdio": "cpp",
24+
"cstdlib": "cpp",
25+
"cstring": "cpp",
26+
"ctime": "cpp",
27+
"cwchar": "cpp",
28+
"cwctype": "cpp",
29+
"any": "cpp",
30+
"atomic": "cpp",
31+
"bit": "cpp",
32+
"cfenv": "cpp",
33+
"cinttypes": "cpp",
34+
"codecvt": "cpp",
35+
"compare": "cpp",
36+
"complex": "cpp",
37+
"concepts": "cpp",
38+
"condition_variable": "cpp",
39+
"cstdint": "cpp",
40+
"deque": "cpp",
41+
"forward_list": "cpp",
42+
"list": "cpp",
43+
"map": "cpp",
44+
"set": "cpp",
45+
"unordered_map": "cpp",
46+
"unordered_set": "cpp",
47+
"vector": "cpp",
48+
"exception": "cpp",
49+
"algorithm": "cpp",
50+
"functional": "cpp",
51+
"iterator": "cpp",
52+
"memory": "cpp",
53+
"memory_resource": "cpp",
54+
"numeric": "cpp",
55+
"optional": "cpp",
56+
"ratio": "cpp",
57+
"system_error": "cpp",
58+
"type_traits": "cpp",
59+
"fstream": "cpp",
60+
"iomanip": "cpp",
61+
"iosfwd": "cpp",
62+
"iostream": "cpp",
63+
"istream": "cpp",
64+
"mutex": "cpp",
65+
"new": "cpp",
66+
"numbers": "cpp",
67+
"ostream": "cpp",
68+
"sstream": "cpp",
69+
"stdexcept": "cpp",
70+
"stop_token": "cpp",
71+
"streambuf": "cpp",
72+
"thread": "cpp",
73+
"typeindex": "cpp",
74+
"typeinfo": "cpp",
75+
"variant": "cpp"
1776
},
1877
"gotoSymbolStack.currentStackPosition": 0,
1978
"gotoSymbolStack.maxStackPosition": 0,

include/cell/copy/global_to_shared.hpp

Lines changed: 88 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -54,20 +54,32 @@ struct GlobalToSharedLoaderImpl<Global_, Shared_, BaseShape_, kRowExec_,
5454
static constexpr int kColExec = kColExec_;
5555

5656
DEVICE void operator()(const DType* src, DType* dst) {
57+
// TODO(KuangjuX): When the `WarpRow` is greater than 1, a swizzle block
58+
// might be split across two warps, and a solution is needed to address this
59+
// situation.
5760
int row = lane_row_id();
5861
int col = lane_col_id() * kNumPerAccess;
5962

60-
/// the pointer offset inside a warp tile.
61-
int src_lane_offset = src_layout_(row, col);
62-
int dst_lane_offset = dst_layout_(row, col);
63-
6463
int src_offset = 0, dst_offset = 0;
6564
#pragma unroll
6665
for (int i = 0; i < kRowExec; ++i) {
6766
#pragma unroll
6867
for (int j = 0; j < kColExec; ++j) {
68+
int tile_i =
69+
(i * BaseShape::kRows + row) / SwizzledBaseShape::kRows;
70+
int tile_j =
71+
(j * BaseShape::kCols + col) / SwizzledBaseShape::kCols;
72+
int tile_row =
73+
(i * BaseShape::kRows + row) % SwizzledBaseShape::kRows;
74+
int tile_col =
75+
(j * BaseShape::kCols + col) % SwizzledBaseShape::kCols;
76+
77+
/// the pointer offset inside a warp tile.
78+
int src_lane_offset = src_tile_(row, col);
79+
int dst_tile_offset = dst_tile_(tile_row, tile_col);
80+
6981
src_offset = src_base_tiles_(i, j) + src_lane_offset;
70-
dst_offset = dst_base_tiles_(i, j) + dst_lane_offset;
82+
dst_offset = dst_base_tiles_(tile_i, tile_j) + dst_tile_offset;
7183

7284
copy(src + src_offset, dst + dst_offset);
7385
}
@@ -78,33 +90,42 @@ struct GlobalToSharedLoaderImpl<Global_, Shared_, BaseShape_, kRowExec_,
7890
static constexpr int kNumPerAccess =
7991
traits::AccessBase<DType>::kNumPerAccess;
8092

93+
using SwizzledBaseShape = traits::SwizzleBaseTileShape<DType>;
94+
static constexpr int kSwizzledRows = SwizzledBaseShape::kRows;
95+
static constexpr int kSwizzledCols = SwizzledBaseShape::kCols;
96+
97+
static constexpr int kSwizzledRowExec =
98+
kRowExec / (kSwizzledRows / BaseShape::kRows);
99+
static constexpr int kSwizzledColExec =
100+
kColExec / (kSwizzledCols / BaseShape::kCols);
101+
81102
using SrcBaseTilesLayout =
82103
tl::MatrixLayout<kRowExec, kColExec,
83104
BaseShape::kRows * Global::kRowStride,
84105
BaseShape::kCols>;
85106
SrcBaseTilesLayout src_base_tiles_;
86107

87-
// a BaseTile is contiguously stored in shared memory
88-
using DstBaseTilesLayout =
89-
tl::MatrixLayout<kRowExec, kColExec,
90-
BaseShape::kRows * Shared::kRowStride,
91-
BaseShape::kNumel>;
92-
DstBaseTilesLayout dst_base_tiles_;
108+
using DstSwizzledLayout =
109+
tl::MatrixLayout<kSwizzledRowExec, kSwizzledColExec,
110+
kSwizzledRows * Shared::kRowStride, kSwizzledCols>;
111+
DstSwizzledLayout dst_base_tiles_;
93112

94113
// Given a thread index, the GlobalLayout and SharedLayout below return the
95114
// data offset from which the thread should load from the global memory tile
96115
// and where to store it in the shared memory tile, respectively.
97116
using GlobalLayout = tl::MatrixLayout<BaseShape::kRows, BaseShape::kCols,
98117
Global::kRowStride, 1>;
99-
GlobalLayout src_layout_;
100118

101-
using NonSwizzled = tl::RowMajor<BaseShape::kRows, BaseShape::kCols>;
102-
using Swizzled =
103-
tl::SwizzledRowMajor<traits::AccessBase<DType>::kAccessInBits,
104-
BaseShape>;
119+
// `src_tile_` is a basetile handled by a single warp.
120+
GlobalLayout src_tile_;
121+
122+
using NonSwizzled =
123+
tl::MatrixLayout<kSwizzledRows, kSwizzledCols, Shared::kRowStride, 1>;
124+
using Swizzled = SwizzledLayout<NonSwizzled, 3, 3, 3>;
125+
105126
using SharedLayout =
106127
std::conditional_t<Shared::kSwizzled, Swizzled, NonSwizzled>;
107-
SharedLayout dst_layout_;
128+
SharedLayout dst_tile_;
108129

109130
DEVICE void copy(const DType* src, DType* dst) {
110131
// a single memory access accesses 16 bytes
@@ -233,16 +254,24 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape, kRowExec_,
233254
int row = lane_row_id();
234255
int col = lane_col_id() * kNumPerAccess;
235256

236-
/// the pointer offset inside a warp tile.
237-
int src_lane_offset = src_tile_(row, col);
238-
int dst_lane_offset = dst_tile_(row, col);
239-
240257
int src_offset = 0, dst_offset = 0;
241258
#pragma unroll
242259
for (int i = 0; i < kRowExec; ++i) {
243260
#pragma unroll
244261
for (int j = 0; j < kColExec; ++j) {
245-
src_offset = src_base_tiles_(i, j) + src_lane_offset;
262+
int tile_i =
263+
(i * BaseShape::kRows + row) / SwizzledBaseShape::kRows;
264+
int tile_j =
265+
(j * BaseShape::kCols + col) / SwizzledBaseShape::kCols;
266+
int tile_row =
267+
(i * BaseShape::kRows + row) % SwizzledBaseShape::kRows;
268+
int tile_col =
269+
(j * BaseShape::kCols + col) % SwizzledBaseShape::kCols;
270+
271+
int src_tile_offset = src_tile_(tile_row, tile_col);
272+
int dst_lane_offset = dst_tile_(row, col);
273+
274+
src_offset = src_base_tiles_(tile_i, tile_j) + src_tile_offset;
246275
dst_offset = dst_base_tiles_(i, j) + dst_lane_offset;
247276

248277
copy(src + src_offset, dst + dst_offset);
@@ -251,12 +280,19 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape, kRowExec_,
251280
}
252281

253282
private:
254-
// a SharedTile is contiguously stored
255-
using SrcBaseTilesLayout =
256-
tl::MatrixLayout<kRowExec, kColExec,
257-
BaseShape::kRows * Shared::kRowStride,
258-
BaseShape::kNumel>;
259-
SrcBaseTilesLayout src_base_tiles_;
283+
using SwizzledBaseShape = traits::SwizzleBaseTileShape<DType>;
284+
static constexpr int kSwizzledRows = SwizzledBaseShape::kRows;
285+
static constexpr int kSwizzledCols = SwizzledBaseShape::kCols;
286+
287+
static constexpr int kSwizzledRowExec =
288+
kRowExec / (kSwizzledRows / BaseShape::kRows);
289+
static constexpr int kSwizzledColExec =
290+
kColExec / (kSwizzledCols / BaseShape::kCols);
291+
292+
using SrcSwizzledLayout =
293+
tl::MatrixLayout<kSwizzledRowExec, kSwizzledColExec,
294+
kSwizzledRows * Shared::kRowStride, kSwizzledCols>;
295+
SrcSwizzledLayout src_base_tiles_;
260296

261297
using DstBaseTilesLayout =
262298
tl::MatrixLayout<kRowExec, kColExec,
@@ -273,15 +309,15 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape, kRowExec_,
273309
static constexpr int kNumPerAccess =
274310
traits::AccessBase<DType>::kNumPerAccess;
275311

276-
using NonSwizzled = tl::RowMajor<BaseShape::kRows, BaseShape::kCols>;
277-
using Swizzled = tl::SwizzledRowMajor<kAccessInBits, BaseShape>;
312+
using NonSwizzled =
313+
tl::MatrixLayout<kSwizzledRows, kSwizzledCols, Shared::kRowStride, 1>;
314+
using Swizzled = SwizzledLayout<NonSwizzled, 3, 3, 3>;
278315
using SharedLayout =
279316
std::conditional_t<Shared::kSwizzled, Swizzled, NonSwizzled>;
280317
SharedLayout src_tile_;
281318

282-
using GlobalLayout =
283-
tl::MatrixLayout<BaseShape::kRows, BaseShape::kCols, Global::kRowStride,
284-
Global::kColStride>;
319+
using GlobalLayout = tl::MatrixLayout<BaseShape::kRows, BaseShape::kCols,
320+
Global::kRowStride, 1>;
285321
GlobalLayout dst_tile_;
286322

287323
/// @brief returns the lane col of the current thread within a warp.
@@ -364,8 +400,14 @@ struct GlobalToSharedLoader {
364400
// warp-level tile shape instead of using a fixed 16x16 `BaseShape`.
365401
// using WarpShape =
366402
// warp::WarpTileShape<DType, typename Shared::Layout, Shared::kType>;
367-
using WarpShape =
368-
warp::WarpTileShape<DType, tl::RowMajor<16, 16>, Shared::kType>;
403+
// using WarpShape =
404+
// warp::WarpTileShape<DType, tl::RowMajor<16, 16>, Shared::kType>;
405+
406+
// KuangjuX: Use `4x64` in RowMajor and `64x4` in ColMajor.
407+
static constexpr bool kRowMajor = Shared::kType == tl::Layout::kRowMajor;
408+
using BaseTile =
409+
std::conditional_t<kRowMajor, tl::RowMajor<4, 64>, tl::ColMajor<64, 4>>;
410+
using WarpShape = warp::WarpTileShape<DType, BaseTile, Shared::kType>;
369411

370412
static_assert(Shared::kRows % WarpShape::kRows == 0,
371413
"Shared::kRows must be divisible by WarpShape::kRows.");
@@ -394,6 +436,7 @@ struct GlobalToSharedLoader {
394436
const DType* src_ptr = src.data();
395437
DType* dst_ptr = dst.mutable_data();
396438

439+
// get warp offset for global and shared memory
397440
int offset_src = global_offset_.template get_warp_offset<Global>();
398441
int offset_dst = shared_offset_.get_warp_offset();
399442

@@ -423,8 +466,15 @@ struct SharedToGlobalStorer {
423466

424467
// FIXME(ying): uncomment the following lines to automatically infer the
425468
// warp-level tile shape instead of using a fixed 16x16 `BaseShape`.
426-
using BaseShape =
427-
warp::WarpTileShape<DType, tl::RowMajor<16, 16>, Shared::kType>;
469+
// using BaseShape =
470+
// warp::WarpTileShape<DType, tl::RowMajor<16, 16>, Shared::kType>;
471+
472+
// KuangjuX: Use `4x64` in RowMajor and `64x4` in ColMajor.
473+
474+
static constexpr bool kRowMajor = Shared::kType == tl::Layout::kRowMajor;
475+
using BaseTile =
476+
std::conditional_t<kRowMajor, tl::RowMajor<4, 64>, tl::ColMajor<64, 4>>;
477+
using BaseShape = warp::WarpTileShape<DType, BaseTile, Shared::kType>;
428478

429479
static_assert(Shared::kRows % BaseShape::kRows == 0,
430480
"Shared::kRows must be divisible by BaseShape::kRows.");
@@ -433,9 +483,9 @@ struct SharedToGlobalStorer {
433483

434484
static const WarpReuse kMode = WarpReuse::kCont; // warp reuse mode
435485

486+
using GlobalOffset = warp::GlobalOffsetHelper<WarpLayout, kMode>;
436487
using SharedOffset =
437488
warp::SharedOffsetHelper<WarpLayout, BaseShape, Shared, kMode>;
438-
using GlobalOffset = warp::GlobalOffsetHelper<WarpLayout, kMode>;
439489

440490
using ExecCounter = warp::ExecCounter<BaseShape, Shared, WarpLayout, kMode>;
441491

include/cell/copy/warp.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -329,9 +329,10 @@ template <typename WarpLayout_, typename WarpShape_, typename Shared_,
329329
struct SharedOffsetHelper<WarpLayout_, WarpShape_, Shared_, kMode_,
330330
tl::Layout::kRowMajor, false> {
331331
DEVICE int get_warp_offset() {
332-
int tile_id = warp_row_id<WarpLayout>() * kRowStride +
333-
warp_col_id<WarpLayout>() * kColStride;
334-
return tile_id * WarpShape::kNumel;
332+
// TODO(KuangjuX): hotfix this.
333+
return warp_row_id<WarpLayout>() * kRowStride * WarpShape::kRows *
334+
Shared::kCols +
335+
warp_col_id<WarpLayout>() * kColStride * WarpShape::kCols;
335336
}
336337

337338
private:
@@ -343,8 +344,8 @@ struct SharedOffsetHelper<WarpLayout_, WarpShape_, Shared_, kMode_,
343344
constexpr static int kTilePerRow = Shared::kRows / WarpShape::kRows;
344345
constexpr static int kTilePerCol = Shared::kCols / WarpShape::kCols;
345346

346-
constexpr static int kRowStride =
347-
kTilePerRow / tl::num_rows<WarpLayout> * kTilePerCol;
347+
// TODO(KuangjuX): hotfix this.
348+
constexpr static int kRowStride = kTilePerRow / tl::num_rows<WarpLayout>;
348349
constexpr static int kColStride = kTilePerCol / tl::num_cols<WarpLayout>;
349350
};
350351

@@ -355,6 +356,7 @@ struct SharedOffsetHelper<WarpLayout_, WarpShape_, Shared_, kMode_,
355356
DEVICE int get_warp_offset() {
356357
int tile_id = warp_row_id<WarpLayout>() * kRowStride +
357358
warp_col_id<WarpLayout>() * kColStride;
359+
358360
return tile_id * WarpShape::kNumel;
359361
}
360362

include/traits/base.hpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,51 @@ struct BaseTileShape {
4949
static constexpr int kCols = kTileSize;
5050
static constexpr int kNumel = kRows * kCols;
5151
};
52+
53+
/**
54+
* @brief The tile shape that a single warp loads from global memory to shared memory.
55+
*/
56+
template <typename Element>
57+
requires BaseType<Element>
58+
struct GMemCopyShape;
59+
60+
template <>
61+
struct GMemCopyShape<__half> {
62+
static constexpr int kRows = 64;
63+
static constexpr int kCols = 4;
64+
static constexpr int kNumel = kRows * kCols;
65+
};
66+
67+
template <>
68+
struct GMemCopyShape<float> {
69+
static constexpr int kRows = 32;
70+
static constexpr int kCols = 4;
71+
static constexpr int kNumel = kRows * kCols;
72+
};
73+
74+
/**
75+
* @brief The base tile shape for Swizzle<3, 3, 3>.
76+
*/
77+
template <typename Element>
78+
requires BaseType<Element>
79+
struct SwizzleBaseTileShape;
80+
81+
template <>
82+
struct SwizzleBaseTileShape<__half> {
83+
using DType = __half;
84+
85+
static constexpr int kRows = 8;
86+
static constexpr int kCols = 64;
87+
static constexpr int kNumel = kRows * kCols;
88+
};
89+
90+
template <>
91+
struct SwizzleBaseTileShape<float> {
92+
using DType = float;
93+
94+
static constexpr int kRows = 8;
95+
static constexpr int kCols = 32;
96+
static constexpr int kNumel = kRows * kCols;
97+
};
98+
5299
} // namespace tilefusion::traits

0 commit comments

Comments
 (0)