Skip to content

Commit

Permalink
fix shared tile iterator.
Browse files Browse the repository at this point in the history
  • Loading branch information
KuangjuX committed Feb 5, 2025
1 parent f49b09b commit 262211b
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 15 deletions.
3 changes: 0 additions & 3 deletions include/types/global.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ struct GlobalTile {
static constexpr int kRows = tl::num_rows<Layout>;
static constexpr int kCols = tl::num_cols<Layout>;

static_assert(kCols % 64 == 0,
"The number of columns must be a multiple of 64.");

static constexpr int kRowStride = tl::row_stride<Layout>;
static constexpr int kColStride = tl::col_stride<Layout>;

Expand Down
3 changes: 0 additions & 3 deletions include/types/shared.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,6 @@ class SharedTile {
static constexpr int kRows = tl::num_rows<Layout>;
static constexpr int kCols = tl::num_cols<Layout>;

static_assert(kCols % 64 == 0,
"The number of columns must be a multiple of 64.");

static constexpr int kRowStride = tl::row_stride<Layout>;
static constexpr int kColStride = tl::col_stride<Layout>;

Expand Down
18 changes: 9 additions & 9 deletions tests/cpp/cell/test_swizzled_copy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,21 +70,21 @@ __global__ void swizzled_copy(const Element* data, G2S1& g2s,
s2r(s_swizzled_tiles(i), r_tile_swizzled);
__syncthreads();

// #ifdef DEBUG
#ifdef DEBUG
if (thread(0)) {
// printf("\niteration [%d, %d]\n", k, i);
// s_tiles(i).dump_value();
printf("\niteration [%d, %d]\n", k, i);
s_tiles(i).dump_value();

// printf("\ns_swizzled_tiles:\n");
// s_swizzled_tiles(i).dump_value();
printf("\ns_swizzled_tiles:\n");
s_swizzled_tiles(i).dump_value();

printf("r_tile:\n");
r_tile.dump_value();

// printf("\nr_tile_swizzled:\n");
// r_tile_swizzled.dump_value();
printf("\nr_tile_swizzled:\n");
r_tile_swizzled.dump_value();
}
// #endif
#endif
check_results<Reg, Element>(r_tile, r_tile_swizzled, Reg::kRows,
Reg::kCols);
}
Expand Down Expand Up @@ -145,7 +145,7 @@ void run_test_rowmajor() {
using Element = __half;
thrust::host_vector<Element> hA(numel);
for (int i = 0; i < hA.size(); ++i) {
hA[i] = static_cast<Element>(i % 2048);
hA[i] = static_cast<Element>(i);
}
thrust::device_vector<Element> dA = hA;

Expand Down

0 comments on commit 262211b

Please sign in to comment.