Skip to content

Commit 61ba350

Browse files
committed
fix s2r loader.
1 parent ab4174c commit 61ba350

File tree

2 files changed

+56
-60
lines changed

2 files changed

+56
-60
lines changed

include/cell/copy/shared_to_register.hpp

Lines changed: 8 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -84,24 +84,13 @@ struct SharedToRegLoaderImpl<Shared, Reg_, kRowExec_, kColExec_,
8484
// int lane_offset = in_base_tile_(lane_row, lane_col);
8585
int offset = 0;
8686

87-
if (thread0()) {
88-
printf("kRowExec: %d, kColExec: %d\n", kRowExec, kColExec);
89-
printf("kSwizzledBlockRows: %d, kSwizzledBlockCols: %d\n",
90-
kSwizzledBlockRows, kSwizzledBlockCols);
91-
}
92-
9387
#pragma unroll
9488
for (int i = 0; i < kRowExec; ++i) {
9589
#pragma unroll
9690
for (int j = 0; j < kColExec; ++j) {
97-
tile_offset += i * SharedCols * 16 + j * 16;
98-
int thrd_offset =
99-
tile_offset + lane_row * SharedCols + lane_col;
91+
int thrd_offset = tile_offset + i * SharedCols * 16 + j * 16 +
92+
lane_row * SharedCols + lane_col;
10093
offset = get_swizzle_offset(thrd_offset);
101-
// auto base_tile_id = get_base_tile_id(tile_offset);
102-
// auto swizzled_tile_id = get_swizzled_tile_id(tile_offset);
103-
// auto in_swizzled_tile_id =
104-
// get_in_swizzle_tile_id(tile_offset);
10594

10695
// if (thread0()) {
10796
// printf("i: %d, j: %d\n", i, j);
@@ -133,9 +122,8 @@ struct SharedToRegLoaderImpl<Shared, Reg_, kRowExec_, kColExec_,
133122
// tl::SharedLayoutWrapper<Shared, LoadMat::kAccessInBits>::Layout;
134123
// BaseTileSharedLayout in_base_tile_;
135124

136-
using SrcLayout =
137-
tl::MatrixLayout<kSwizzledBlockRows, kSwizzledBlockCols * 8,
138-
Shared::kRowStride, 64>;
125+
using SrcLayout = tl::MatrixLayout<kSwizzledBlockRows, kSwizzledBlockCols,
126+
Shared::kRowStride * 8, 64>;
139127
SrcLayout src_tile_;
140128

141129
using SwizzledBaseShape = traits::SwizzleBaseTileShape<DType>;
@@ -240,8 +228,11 @@ struct RegToSharedStorerImpl<Reg_, Shared_, kRowExec_, kColExec_,
240228
private:
241229
using BaseShape = BaseTileShape<DType>;
242230

231+
// static constexpr int kRowStride = BaseShape::kRows * Shared::kRowStride;
232+
// static constexpr int kColStride = BaseShape::kNumel;
233+
243234
static constexpr int kRowStride = BaseShape::kRows * Shared::kRowStride;
244-
static constexpr int kColStride = BaseShape::kNumel;
235+
static constexpr int kColStride = BaseShape::kCols;
245236
};
246237

247238
template <typename Reg_, typename Shared_, const int kRowExec_,

tests/cpp/cell/test_s2r_copy.cu

Lines changed: 48 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -70,11 +70,16 @@ __global__ void run_test_load(Copy& copy) {
7070

7171
copy(s_tile, r_tile);
7272

73-
// #if defined(DEBUG)
73+
#if defined(DEBUG)
7474
if (thread0()) {
7575
r_tile.dump_value();
7676
}
77-
// #endif
77+
#endif
78+
79+
if (threadIdx.x == 4) {
80+
printf("threadIdx.x: %d\n", threadIdx.x);
81+
r_tile.dump_value();
82+
}
7883
}
7984

8085
template <typename Shared, typename Reg, typename Loader, typename Storer>
@@ -210,62 +215,62 @@ TEST(TestShared2Reg, operand_A) { // load mode for loading operand A in gemm
210215
// cudaDeviceSynchronize();
211216
// }
212217

213-
// TEST(TestReg2Shared, operand_C_half) {
214-
// using Element = __half;
215-
216-
// using WarpLayout = tl::RowMajor<1, 1>;
217-
// const int kThreads = tl::get_numel<WarpLayout> * 32;
218-
219-
// // using Shared = SharedTile<Element, tl::RowMajor<16, 16>>;
220-
// using Shared = SharedTile<Element, tl::RowMajor<16, 64>>;
221-
// // using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<1, 1>>;
222-
// using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<1, 4>>;
223-
224-
// using Loader = SharedToRegLoader<Reg, WarpLayout, WarpReuse::kCont>;
225-
// Loader loader;
226-
227-
// using Storer = RegToSharedStorer<Reg, WarpLayout>;
228-
// Storer storer;
229-
230-
// dim3 dim_grid(1, 1, 1);
231-
// dim3 dim_block(kThreads, 1, 1);
232-
// int shm_size = Shared::kNumel * sizeof(Element);
233-
234-
// run_test_store<Shared, Reg, Loader, Storer>
235-
// <<<dim_grid, dim_block, shm_size>>>(loader, storer);
236-
// cudaDeviceSynchronize();
237-
// }
238-
239-
TEST(TestShared2Reg, operand_A_swizzle) {
218+
TEST(TestReg2Shared, operand_C_half) {
240219
using Element = __half;
241220

242221
using WarpLayout = tl::RowMajor<1, 1>;
243222
const int kThreads = tl::get_numel<WarpLayout> * 32;
244223

245-
// const int kRows = 64;
246-
// const int kCols = 32;
247-
248-
const int kRows = 16;
249-
const int kCols = 64;
250-
251-
using SharedLayout = tl::RowMajor<kRows, kCols>;
252-
const bool kUseSwizzledLayout = true;
253-
using Shared = SharedTile<Element, SharedLayout, kUseSwizzledLayout>;
254-
// using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<2, 2>>;
224+
// using Shared = SharedTile<Element, tl::RowMajor<16, 16>>;
225+
using Shared = SharedTile<Element, tl::RowMajor<16, 64>>;
226+
// using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<1, 1>>;
255227
using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<1, 4>>;
256228

257-
using Copy = SharedToRegLoader<Reg, WarpLayout, WarpReuse::kRowReuseCont>;
258-
Copy copy;
229+
using Loader = SharedToRegLoader<Reg, WarpLayout, WarpReuse::kCont>;
230+
Loader loader;
231+
232+
using Storer = RegToSharedStorer<Reg, WarpLayout>;
233+
Storer storer;
259234

260235
dim3 dim_grid(1, 1, 1);
261236
dim3 dim_block(kThreads, 1, 1);
262237
int shm_size = Shared::kNumel * sizeof(Element);
263238

264-
run_test_load<Element, Shared, Reg, Copy>
265-
<<<dim_grid, dim_block, shm_size>>>(copy);
239+
run_test_store<Shared, Reg, Loader, Storer>
240+
<<<dim_grid, dim_block, shm_size>>>(loader, storer);
266241
cudaDeviceSynchronize();
267242
}
268243

244+
// TEST(TestShared2Reg, operand_A_swizzle) {
245+
// using Element = __half;
246+
247+
// using WarpLayout = tl::RowMajor<1, 1>;
248+
// const int kThreads = tl::get_numel<WarpLayout> * 32;
249+
250+
// // const int kRows = 64;
251+
// // const int kCols = 32;
252+
253+
// const int kRows = 16;
254+
// const int kCols = 64;
255+
256+
// using SharedLayout = tl::RowMajor<kRows, kCols>;
257+
// const bool kUseSwizzledLayout = true;
258+
// using Shared = SharedTile<Element, SharedLayout, kUseSwizzledLayout>;
259+
// // using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<2, 2>>;
260+
// using Reg = RegTile<BaseTileRowMajor<Element>, tl::RowMajor<1, 4>>;
261+
262+
// using Copy = SharedToRegLoader<Reg, WarpLayout,
263+
// WarpReuse::kRowReuseCont>; Copy copy;
264+
265+
// dim3 dim_grid(1, 1, 1);
266+
// dim3 dim_block(kThreads, 1, 1);
267+
// int shm_size = Shared::kNumel * sizeof(Element);
268+
269+
// run_test_load<Element, Shared, Reg, Copy>
270+
// <<<dim_grid, dim_block, shm_size>>>(copy);
271+
// cudaDeviceSynchronize();
272+
// }
273+
269274
// TEST(TestReg2Shared, operand_C_float) {
270275
// using Element = __half;
271276
// using AccType = float;

0 commit comments

Comments
 (0)