Skip to content

Commit ab4174c

Browse files
committed
update.
1 parent ba0734f commit ab4174c

File tree

2 files changed

+125
-105
lines changed

2 files changed

+125
-105
lines changed

include/cell/copy/shared_to_register.hpp

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,12 @@ struct SharedToRegLoaderImpl<Shared, Reg_, kRowExec_, kColExec_,
3737
static constexpr int kRowExec = kRowExec_;
3838
static constexpr int kColExec = kColExec_;
3939

40-
DEVICE SharedToRegLoaderImpl()
41-
: base_tiles_(BaseTilesLayout{})
42-
, in_base_tile_(BaseTileSharedLayout{}) {}
40+
static constexpr int kSwizzledBlockRows = kRowExec * 16 / 8;
41+
static constexpr int kSwizzledBlockCols = kColExec * 16 / 64;
42+
43+
// DEVICE SharedToRegLoaderImpl()
44+
// : base_tiles_(BaseTilesLayout{})
45+
// , in_base_tile_(BaseTileSharedLayout{}) {}
4346

4447
DEVICE int2 get_base_tile_id(int offset) {
4548
// BaseTile is a 16 x 16 block.
@@ -68,26 +71,30 @@ struct SharedToRegLoaderImpl<Shared, Reg_, kRowExec_, kColExec_,
6871
DEVICE int get_swizzle_offset(int offset) {
6972
auto swizzled_tile_id = get_swizzled_tile_id(offset);
7073
auto in_swizzled_tile_id = get_in_swizzle_tile_id(offset);
71-
auto in_swizzle_offset =
72-
src_tile_(in_swizzled_tile_id.x, in_swizzled_tile_id.y);
73-
auto swizzled_offset = swizzled_tile_id.y * 64 +
74-
swizzled_tile_id.x * 8 * SharedCols +
75-
in_swizzle_offset;
74+
auto swizzled_offset =
75+
src_tile_(swizzled_tile_id.x, swizzled_tile_id.y) +
76+
in_src_tile_(in_swizzled_tile_id.x, in_swizzled_tile_id.y);
7677
return swizzled_offset;
7778
}
7879

7980
DEVICE void operator()(const DType* src, Reg& dst, int tile_offset) {
8081
int lane_row = this->lane_row_id();
8182
int lane_col = this->lane_col_id() * LoadMat::kNumPerAccess;
8283

83-
int lane_offset = in_base_tile_(lane_row, lane_col);
84+
// int lane_offset = in_base_tile_(lane_row, lane_col);
8485
int offset = 0;
8586

87+
if (thread0()) {
88+
printf("kRowExec: %d, kColExec: %d\n", kRowExec, kColExec);
89+
printf("kSwizzledBlockRows: %d, kSwizzledBlockCols: %d\n",
90+
kSwizzledBlockRows, kSwizzledBlockCols);
91+
}
92+
8693
#pragma unroll
8794
for (int i = 0; i < kRowExec; ++i) {
8895
#pragma unroll
8996
for (int j = 0; j < kColExec; ++j) {
90-
tile_offset = i * SharedCols * 16 + j * 16;
97+
tile_offset += i * SharedCols * 16 + j * 16;
9198
int thrd_offset =
9299
tile_offset + lane_row * SharedCols + lane_col;
93100
offset = get_swizzle_offset(thrd_offset);
@@ -117,14 +124,19 @@ struct SharedToRegLoaderImpl<Shared, Reg_, kRowExec_, kColExec_,
117124
}
118125

119126
private:
120-
using BaseTilesLayout =
121-
tl::MatrixLayout<kRowExec, kColExec, Shared::kRowStride,
122-
Shared::kColStride>;
123-
BaseTilesLayout base_tiles_;
127+
// using BaseTilesLayout =
128+
// tl::MatrixLayout<kRowExec, kColExec, Shared::kRowStride,
129+
// Shared::kColStride>;
130+
// BaseTilesLayout base_tiles_;
124131

125-
using BaseTileSharedLayout =
126-
tl::SharedLayoutWrapper<Shared, LoadMat::kAccessInBits>::Layout;
127-
BaseTileSharedLayout in_base_tile_;
132+
// using BaseTileSharedLayout =
133+
// tl::SharedLayoutWrapper<Shared, LoadMat::kAccessInBits>::Layout;
134+
// BaseTileSharedLayout in_base_tile_;
135+
136+
using SrcLayout =
137+
tl::MatrixLayout<kSwizzledBlockRows, kSwizzledBlockCols * 8,
138+
Shared::kRowStride, 64>;
139+
SrcLayout src_tile_;
128140

129141
using SwizzledBaseShape = traits::SwizzleBaseTileShape<DType>;
130142
static constexpr int kSwizzledRows = SwizzledBaseShape::kRows;
@@ -139,7 +151,7 @@ struct SharedToRegLoaderImpl<Shared, Reg_, kRowExec_, kColExec_,
139151

140152
using SharedLayout =
141153
std::conditional_t<Shared::kSwizzled, Swizzled, NonSwizzled>;
142-
SharedLayout src_tile_;
154+
SharedLayout in_src_tile_;
143155
};
144156

145157
/// @brief partial specialization for column-major shared memory tile.

tests/cpp/cell/test_s2r_copy.cu

Lines changed: 95 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ __global__ void run_test_load(Copy& copy) {
7070

7171
copy(s_tile, r_tile);
7272

73-
#if defined(DEBUG)
73+
// #if defined(DEBUG)
7474
if (thread0()) {
7575
r_tile.dump_value();
7676
}
77-
#endif
77+
// #endif
7878
}
7979

8080
template <typename Shared, typename Reg, typename Loader, typename Storer>
@@ -180,71 +180,79 @@ TEST(TestShared2Reg, operand_A) { // load mode for loading operand A in gemm
180180
cudaDeviceSynchronize();
181181
}
182182

183-
TEST(TestShared2Reg, operand_B) { // load mode for loading operand B in gemm
184-
using Element = __half;
183+
// TEST(TestShared2Reg, operand_B) { // load mode for loading operand B in gemm
184+
// using Element = __half;
185185

186-
using WarpLayout = tl::RowMajor<2, 2>;
187-
const int kThreads = tl::get_numel<WarpLayout> * 32;
186+
// using WarpLayout = tl::RowMajor<2, 2>;
187+
// const int kThreads = tl::get_numel<WarpLayout> * 32;
188188

189-
// a 32x64 row-major shared tile is equivalent to a 64x32 col-major tile
190-
using Shared = SharedTile<Element, tl::RowMajor<32, 64>>;
189+
// // a 32x64 row-major shared tile is equivalent to a 64x32 col-major tile
190+
// using Shared = SharedTile<Element, tl::RowMajor<32, 64>>;
191191

192-
// Each thread accesses 4x2 elements (the shape of `BaseHalfTileRowMajor`)
193-
// within a 16x16 `BaseTile`. These 4x2 elements are accessed 2x2 times
194-
// along each dimension, contributing to the final register tile handled by
195-
// a single thread.
196-
using Reg = RegTile<BaseTileColMajor<Element>, tl::ColMajor<2, 2>>;
197-
// In the `ColReuseCont` mode, warps in the same column repeatedly access
198-
// the same data.
199-
using Copy = SharedToRegLoader<Reg, WarpLayout, WarpReuse::kColReuseCont>;
200-
Copy copy;
192+
// // Each thread accesses 4x2 elements (the shape of
193+
// `BaseHalfTileRowMajor`)
194+
// // within a 16x16 `BaseTile`. These 4x2 elements are accessed 2x2 times
195+
// // along each dimension, contributing to the final register tile handled
196+
// by
197+
// // a single thread.
198+
// using Reg = RegTile<BaseTileColMajor<Element>, tl::ColMajor<2, 2>>;
199+
// // In the `ColReuseCont` mode, warps in the same column repeatedly access
200+
// // the same data.
201+
// using Copy = SharedToRegLoader<Reg, WarpLayout,
202+
// WarpReuse::kColReuseCont>; Copy copy;
201203

202-
dim3 dim_grid(1, 1, 1);
203-
dim3 dim_block(kThreads, 1, 1);
204-
int shm_size = Shared::kNumel * sizeof(Element);
204+
// dim3 dim_grid(1, 1, 1);
205+
// dim3 dim_block(kThreads, 1, 1);
206+
// int shm_size = Shared::kNumel * sizeof(Element);
205207

206-
run_test_load<Element, Shared, Reg, Copy>
207-
<<<dim_grid, dim_block, shm_size>>>(copy);
208-
cudaDeviceSynchronize();
209-
}
208+
// run_test_load<Element, Shared, Reg, Copy>
209+
// <<<dim_grid, dim_block, shm_size>>>(copy);
210+
// cudaDeviceSynchronize();
211+
// }
210212

211-
TEST(TestReg2Shared, operand_C_half) {
212-
using Element = __half;
213+
// TEST(TestReg2Shared, operand_C_half) {
214+
// using Element = __half;
213215

214-
using WarpLayout = tl::RowMajor<1, 1>;
215-
const int kThreads = tl::get_numel<WarpLayout> * 32;
216+
// using WarpLayout = tl::RowMajor<1, 1>;
217+
// const int kThreads = tl::get_numel<WarpLayout> * 32;
216218

217-
using Shared = SharedTile<Element, tl::RowMajor<16, 16>>;
218-
using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<1, 1>>;
219+
// // using Shared = SharedTile<Element, tl::RowMajor<16, 16>>;
220+
// using Shared = SharedTile<Element, tl::RowMajor<16, 64>>;
221+
// // using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<1, 1>>;
222+
// using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<1, 4>>;
219223

220-
using Loader = SharedToRegLoader<Reg, WarpLayout, WarpReuse::kCont>;
221-
Loader loader;
224+
// using Loader = SharedToRegLoader<Reg, WarpLayout, WarpReuse::kCont>;
225+
// Loader loader;
222226

223-
using Storer = RegToSharedStorer<Reg, WarpLayout>;
224-
Storer storer;
227+
// using Storer = RegToSharedStorer<Reg, WarpLayout>;
228+
// Storer storer;
225229

226-
dim3 dim_grid(1, 1, 1);
227-
dim3 dim_block(kThreads, 1, 1);
228-
int shm_size = Shared::kNumel * sizeof(Element);
230+
// dim3 dim_grid(1, 1, 1);
231+
// dim3 dim_block(kThreads, 1, 1);
232+
// int shm_size = Shared::kNumel * sizeof(Element);
229233

230-
run_test_store<Shared, Reg, Loader, Storer>
231-
<<<dim_grid, dim_block, shm_size>>>(loader, storer);
232-
cudaDeviceSynchronize();
233-
}
234+
// run_test_store<Shared, Reg, Loader, Storer>
235+
// <<<dim_grid, dim_block, shm_size>>>(loader, storer);
236+
// cudaDeviceSynchronize();
237+
// }
234238

235239
TEST(TestShared2Reg, operand_A_swizzle) {
236240
using Element = __half;
237241

238242
using WarpLayout = tl::RowMajor<1, 1>;
239243
const int kThreads = tl::get_numel<WarpLayout> * 32;
240244

241-
const int kRows = 64;
242-
const int kCols = 32;
245+
// const int kRows = 64;
246+
// const int kCols = 32;
247+
248+
const int kRows = 16;
249+
const int kCols = 64;
243250

244251
using SharedLayout = tl::RowMajor<kRows, kCols>;
245252
const bool kUseSwizzledLayout = true;
246253
using Shared = SharedTile<Element, SharedLayout, kUseSwizzledLayout>;
247-
using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<2, 2>>;
254+
// using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<2, 2>>;
255+
using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<1, 4>>;
248256

249257
using Copy = SharedToRegLoader<Reg, WarpLayout, WarpReuse::kRowReuseCont>;
250258
Copy copy;
@@ -258,48 +266,48 @@ TEST(TestShared2Reg, operand_A_swizzle) {
258266
cudaDeviceSynchronize();
259267
}
260268

261-
TEST(TestReg2Shared, operand_C_float) {
262-
using Element = __half;
263-
using AccType = float;
264-
265-
const int kRowRepeats = 4;
266-
const int kColRepeats = 8;
267-
const int kRows = 16 * kRowRepeats;
268-
const int kCols = 16 * kColRepeats;
269-
270-
const int kWarpPerRow = 2;
271-
const int kWarpPerCol = 2;
272-
using WarpLayout = tl::RowMajor<kWarpPerRow, kWarpPerCol>;
273-
const int kThreads = tl::get_numel<WarpLayout> * 32;
274-
275-
using SharedHalf = SharedTile<Element, tl::RowMajor<kRows, kCols>>;
276-
using RegHalf = RegTile<
277-
BaseTileRowMajor<Element>,
278-
tl::RowMajor<kRowRepeats / kWarpPerRow, kColRepeats / kWarpPerCol>>;
279-
280-
using SharedFloat = SharedTile<AccType, tl::RowMajor<kRows, kCols>>;
281-
using RegFloat = RegTile<
282-
BaseTileRowMajor<AccType>,
283-
tl::RowMajor<kRowRepeats / kWarpPerRow, kColRepeats / kWarpPerCol>>;
284-
285-
using ConvertHalf = compute::RegTileConvert<RegHalf, RegFloat>;
286-
ConvertHalf convert;
287-
288-
using Loader = SharedToRegLoader<RegHalf, WarpLayout, WarpReuse::kCont>;
289-
Loader loader;
290-
291-
using Storer = RegToSharedStorer<RegFloat, WarpLayout>;
292-
Storer storer;
293-
294-
dim3 dim_grid(1, 1, 1);
295-
dim3 dim_block(kThreads, 1, 1);
296-
int shm_size = SharedHalf::kNumel * sizeof(Element) +
297-
SharedFloat::kNumel * sizeof(AccType);
298-
299-
run_test_store_float<SharedHalf, RegHalf, SharedFloat, RegFloat,
300-
ConvertHalf, Loader, Storer>
301-
<<<dim_grid, dim_block, shm_size>>>(convert, loader, storer);
302-
cudaDeviceSynchronize();
303-
}
269+
// TEST(TestReg2Shared, operand_C_float) {
270+
// using Element = __half;
271+
// using AccType = float;
272+
273+
// const int kRowRepeats = 4;
274+
// const int kColRepeats = 8;
275+
// const int kRows = 16 * kRowRepeats;
276+
// const int kCols = 16 * kColRepeats;
277+
278+
// const int kWarpPerRow = 2;
279+
// const int kWarpPerCol = 2;
280+
// using WarpLayout = tl::RowMajor<kWarpPerRow, kWarpPerCol>;
281+
// const int kThreads = tl::get_numel<WarpLayout> * 32;
282+
283+
// using SharedHalf = SharedTile<Element, tl::RowMajor<kRows, kCols>>;
284+
// using RegHalf = RegTile<
285+
// BaseTileRowMajor<Element>,
286+
// tl::RowMajor<kRowRepeats / kWarpPerRow, kColRepeats / kWarpPerCol>>;
287+
288+
// using SharedFloat = SharedTile<AccType, tl::RowMajor<kRows, kCols>>;
289+
// using RegFloat = RegTile<
290+
// BaseTileRowMajor<AccType>,
291+
// tl::RowMajor<kRowRepeats / kWarpPerRow, kColRepeats / kWarpPerCol>>;
292+
293+
// using ConvertHalf = compute::RegTileConvert<RegHalf, RegFloat>;
294+
// ConvertHalf convert;
295+
296+
// using Loader = SharedToRegLoader<RegHalf, WarpLayout, WarpReuse::kCont>;
297+
// Loader loader;
298+
299+
// using Storer = RegToSharedStorer<RegFloat, WarpLayout>;
300+
// Storer storer;
301+
302+
// dim3 dim_grid(1, 1, 1);
303+
// dim3 dim_block(kThreads, 1, 1);
304+
// int shm_size = SharedHalf::kNumel * sizeof(Element) +
305+
// SharedFloat::kNumel * sizeof(AccType);
306+
307+
// run_test_store_float<SharedHalf, RegHalf, SharedFloat, RegFloat,
308+
// ConvertHalf, Loader, Storer>
309+
// <<<dim_grid, dim_block, shm_size>>>(convert, loader, storer);
310+
// cudaDeviceSynchronize();
311+
// }
304312

305313
} // namespace tilefusion::testing

0 commit comments

Comments
 (0)