Make swizzle function adaptive to floating-point numbers.
lcy-seso committed Jan 24, 2025
1 parent e5424b3 commit 8102e29
Showing 3 changed files with 55 additions and 46 deletions.
23 changes: 14 additions & 9 deletions include/cell/copy/global_to_shared.hpp
@@ -11,6 +11,10 @@ namespace tilefusion::cell::copy {
using namespace atom;
namespace tl = tile_layout;

namespace {
constexpr size_t Log2(size_t n) { return ((n < 2) ? 0 : 1 + Log2(n / 2)); }
} // namespace

/**
* @brief Load a warp tile from global memory to shared memory.
*
@@ -104,7 +108,11 @@ struct GlobalToSharedLoaderImpl<Global_, Shared_, BaseShape_, kRowExec_,
GlobalLayout src_in_base_tile_;

using NonSwizzled = tl::RowMajor<BaseShape::kRows, BaseShape::kCols>;
using Swizzled = SwizzledLayout<NonSwizzled, 3, 3, 3>;

static constexpr int kM = Log2(kNumPerAccess);
static constexpr int kS = Log2(traits::AccessBase<DType>::kMemTransWidth /
traits::AccessBase<DType>::kAccessInBits);
using Swizzled = SwizzledLayout<NonSwizzled, 3, kM, kS>;

using SharedLayout =
std::conditional_t<Shared::kSwizzled, Swizzled, NonSwizzled>;
@@ -274,15 +282,12 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape, kRowExec_,
BaseShape::kCols>;
DstBaseTilesLayout dst_base_tiles_;

// NOTE: DO NOT modify `kNumPerAccess` and `kAccessInBits` here.
// `kAccessInBits` in the storer is for tensor core's output where only two
// numbers are contiguous in memory. This ensures the parameters remain
// consistent with those used in `SharedLayoutWrapper` within the
// register-to-shared storer.
static constexpr int kAccessInBits = 2 * int(sizeof(DType) * 8);

using NonSwizzled = tl::RowMajor<BaseShape::kRows, BaseShape::kCols>;
using Swizzled = SwizzledLayout<NonSwizzled, 3, 3, 3>;

static constexpr int kM = Log2(kNumPerAccess);
static constexpr int kS = Log2(traits::AccessBase<DType>::kMemTransWidth /
traits::AccessBase<DType>::kAccessInBits);
using Swizzled = SwizzledLayout<NonSwizzled, 3, kM, kS>;

using SharedLayout =
std::conditional_t<Shared::kSwizzled, Swizzled, NonSwizzled>;
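For context, here is a minimal standalone sketch (not part of the library) of what the adaptive parameters above evaluate to for the loader and storer, assuming the AccessBase constants from include/traits/base.hpp below: 128-bit vectorized accesses and a 1024-bit memory-transaction width.

// Sketch only: the compile-time swizzle parameters for half vs. float.
constexpr size_t Log2(size_t n) { return (n < 2) ? 0 : 1 + Log2(n / 2); }

constexpr int kAccessInBits = 128;    // assumed, mirrors traits::AccessBase
constexpr int kMemTransWidth = 1024;  // assumed shared memory cache line width, in bits

// half:  kNumPerAccess = 128 / 16 = 8 -> kM = 3, giving SwizzledLayout<..., 3, 3, 3>
// float: kNumPerAccess = 128 / 32 = 4 -> kM = 2, giving SwizzledLayout<..., 3, 2, 3>
static_assert(Log2(128 / 16) == 3 && Log2(128 / 32) == 2, "kM per element type");
static_assert(Log2(kMemTransWidth / kAccessInBits) == 3, "kS is 3 in both cases");

The previously hard-coded SwizzledLayout<NonSwizzled, 3, 3, 3> matched only the 16-bit case; the Log2-based form reproduces those values for half and adapts kM to the element width.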
4 changes: 2 additions & 2 deletions include/traits/base.hpp
@@ -28,10 +28,10 @@ struct AccessBase {
static constexpr int kAccessInBits = 128;
static constexpr int kAccessInBytes = kAccessInBits / 8;

static constexpr int kElementBits = cutlass::sizeof_bits<Element>::value;
static constexpr int kElementBits = sizeof(Element) * 8;
static constexpr int kNumPerAccess = kAccessInBits / kElementBits;

// the width of memory transaction
// The width of a memory transaction, i.e., the shared memory cache line width.
static constexpr int kMemTransWidth = 1024; // 1024 bits, 128 bytes

// The ideal number of columns for a single warp to load.
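As a quick sanity check, the trait arithmetic above works out as follows for the element types exercised by the tests (a standalone sketch, not library code; __half is assumed to come from <cuda_fp16.h>):

#include <cuda_fp16.h>

// Sketch only: element width and per-access vectorization derived from sizeof(Element).
template <typename Element>
struct AccessSketch {
    static constexpr int kAccessInBits = 128;
    static constexpr int kElementBits = sizeof(Element) * 8;
    static constexpr int kNumPerAccess = kAccessInBits / kElementBits;
};

static_assert(AccessSketch<__half>::kNumPerAccess == 8, "half: 16-bit elements, 8 per access");
static_assert(AccessSketch<float>::kNumPerAccess == 4, "float: 32-bit elements, 4 per access");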
74 changes: 39 additions & 35 deletions tests/cpp/cell/test_g2s_load.cu
@@ -181,41 +181,45 @@ TEST(GlobalToSharedLoad, test_row_major_half) {
}
}

// TEST(GlobalToSharedLoad, test_row_major_float) {
// {
// const bool kSwizzled = false;

// run_test_row_major<float, tl::RowMajor<1, 1>, 8, 32, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<1, 1>, 16, 16, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<1, 1>, 16, 64, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<1, 2>, 32, 64, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<1, 4>, 32, 128, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<1, 4>, 16, 128, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<4, 1>, 192, 32, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<4, 1>, 64, 32, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<2, 2>, 32, 64, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<2, 2>, 64, 128, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<2, 4>, 96, 128, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<2, 4>, 32, 128, kSwizzled>();
// }

// {
// const bool kSwizzled = true;

// run_test_row_major<float, tl::RowMajor<1, 1>, 8, 32, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<1, 1>, 16, 16, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<1, 1>, 16, 64, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<1, 2>, 32, 64, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<1, 4>, 32, 128, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<1, 4>, 16, 128, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<4, 1>, 192, 32, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<4, 1>, 64, 32, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<2, 2>, 32, 64, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<2, 2>, 64, 128, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<2, 4>, 96, 128, kSwizzled>();
// run_test_row_major<float, tl::RowMajor<2, 4>, 32, 128, kSwizzled>();
// }
// }
TEST(GlobalToSharedLoad, test_row_major_float) {
{
const bool kSwizzled = false;

run_test_row_major<float, tl::RowMajor<1, 1>, 8, 32, kSwizzled>();
run_test_row_major<float, tl::RowMajor<1, 1>, 16, 16, kSwizzled>();
run_test_row_major<float, tl::RowMajor<1, 1>, 16, 64, kSwizzled>();

run_test_row_major<float, tl::RowMajor<1, 2>, 32, 64, kSwizzled>();
run_test_row_major<float, tl::RowMajor<1, 4>, 32, 128, kSwizzled>();
run_test_row_major<float, tl::RowMajor<1, 4>, 16, 128, kSwizzled>();
run_test_row_major<float, tl::RowMajor<4, 1>, 192, 32, kSwizzled>();
run_test_row_major<float, tl::RowMajor<4, 1>, 64, 32, kSwizzled>();

run_test_row_major<float, tl::RowMajor<2, 2>, 32, 64, kSwizzled>();
run_test_row_major<float, tl::RowMajor<2, 2>, 64, 128, kSwizzled>();
run_test_row_major<float, tl::RowMajor<2, 4>, 96, 128, kSwizzled>();
run_test_row_major<float, tl::RowMajor<2, 4>, 32, 128, kSwizzled>();
}

{
const bool kSwizzled = true;

run_test_row_major<float, tl::RowMajor<1, 1>, 8, 32, kSwizzled>();
run_test_row_major<float, tl::RowMajor<1, 1>, 16, 16, kSwizzled>();
run_test_row_major<float, tl::RowMajor<1, 1>, 16, 64, kSwizzled>();

run_test_row_major<float, tl::RowMajor<1, 2>, 32, 64, kSwizzled>();
run_test_row_major<float, tl::RowMajor<1, 4>, 32, 128, kSwizzled>();
run_test_row_major<float, tl::RowMajor<1, 4>, 16, 128, kSwizzled>();
run_test_row_major<float, tl::RowMajor<4, 1>, 192, 32, kSwizzled>();
run_test_row_major<float, tl::RowMajor<4, 1>, 64, 32, kSwizzled>();

run_test_row_major<float, tl::RowMajor<2, 2>, 32, 64, kSwizzled>();
run_test_row_major<float, tl::RowMajor<2, 2>, 64, 128, kSwizzled>();
run_test_row_major<float, tl::RowMajor<2, 4>, 96, 128, kSwizzled>();
run_test_row_major<float, tl::RowMajor<2, 4>, 32, 128, kSwizzled>();
}
}

TEST(GlobalToSharedLoad, test_col_major_load) {
// FIXME(ying): temporarily disable the test to refactor the copy.
