feat(cell): Add floating-point support for GMEM-to-SMEM RowMajor loaders/storers. (microsoft#46)

This PR adds floating-point support for GMEM-to-SMEM RowMajor loaders/storers.

* Add a float copy shape.

* Pass the float non-swizzled test.

* Pass more float non-swizzled test cases.

* Address review comments.
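For context, the new shapes in this diff follow from two constants: a swizzled base-tile row spans 128 bytes of shared memory, and each thread moves 16 bytes per vectorized access. The sketch below only restates that arithmetic with the values introduced in include/traits/base.hpp; it is illustrative and not code from this repository.

#include <cuda_fp16.h>  // __half

// Illustrative only: reproduces the per-dtype constants added in this PR.
template <typename DType>
struct CopyShapeSketch {
    static constexpr int kBytesPerSwizzleRow = 128;  // 32 banks x 4 bytes
    static constexpr int kCols = kBytesPerSwizzleRow / sizeof(DType);
    static constexpr int kElemsPerAccess = 16 / sizeof(DType);  // one 128-bit access
};

static_assert(CopyShapeSketch<__half>::kCols == 64, "half swizzle row: 64 elements");
static_assert(CopyShapeSketch<float>::kCols == 32, "float swizzle row: 32 elements");
static_assert(CopyShapeSketch<__half>::kElemsPerAccess == 8, "8 halves per 128-bit access");
static_assert(CopyShapeSketch<float>::kElemsPerAccess == 4, "4 floats per 128-bit access");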
KuangjuX authored Jan 22, 2025
1 parent f154e62 commit 29ca160
Showing 5 changed files with 89 additions and 90 deletions.
63 changes: 21 additions & 42 deletions include/cell/copy/global_to_shared.hpp
@@ -14,7 +14,7 @@ namespace tl = tile_layout;
/**
* @brief Load a warp tile from global memory to shared memory.
*
* This function loads a warp tile whose shape is specified by `WarpShape`
* This function loads a warp tile whose shape is specified by `BaseShape`
* from global memory to shared memory.
*
* @tparam Global_ The type of the global memory pointer.
@@ -93,6 +93,9 @@ struct GlobalToSharedLoaderImpl<Global_, Shared_, BaseShape_, kRowExec_,
using SwizzledBaseShape = traits::SwizzleBaseTileShape<DType>;
static constexpr int kSwizzledRows = SwizzledBaseShape::kRows;
static constexpr int kSwizzledCols = SwizzledBaseShape::kCols;
static constexpr int B = SwizzledBaseShape::B;
static constexpr int M = SwizzledBaseShape::M;
static constexpr int S = SwizzledBaseShape::S;

static constexpr int kSwizzledRowExec =
kRowExec / (kSwizzledRows / BaseShape::kRows);
@@ -121,7 +124,7 @@ struct GlobalToSharedLoaderImpl<Global_, Shared_, BaseShape_, kRowExec_,

using NonSwizzled =
tl::MatrixLayout<kSwizzledRows, kSwizzledCols, Shared::kRowStride, 1>;
using Swizzled = SwizzledLayout<NonSwizzled, 3, 3, 3>;
using Swizzled = SwizzledLayout<NonSwizzled, B, M, S>;

using SharedLayout =
std::conditional_t<Shared::kSwizzled, Swizzled, NonSwizzled>;
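Swapping the hard-coded <3, 3, 3> for the per-dtype B, M, S is what lets float share this path with half. Assuming SwizzledLayout follows the usual CUTLASS-style Swizzle<B, M, S> convention (this diff does not spell that out), the permutation applied to a linear element offset is roughly the following; the helper name is hypothetical.

// Hypothetical sketch of a CUTLASS-style Swizzle<B, M, S>; TileFusion's
// SwizzledLayout may differ in detail.
template <int B, int M, int S>
constexpr int swizzle_offset(int offset) {
    // 2^M contiguous elements form one 16-byte atom, 2^S atoms form a
    // 128-byte row, and the low B row bits are XOR-ed into the atom index.
    constexpr int mask = ((1 << B) - 1) << (M + S);
    return offset ^ ((offset & mask) >> S);
}

// half uses <3, 3, 3> (8-element atoms); float uses <3, 2, 3> (4-element atoms),
// so both permute 16-byte chunks within a 128-byte row.
static_assert(swizzle_offset<3, 3, 3>(0) == 0, "row 0 is untouched");
static_assert(swizzle_offset<3, 2, 3>(1 << 5) == ((1 << 5) | (1 << 2)),
              "row 1 XORs bit 0 of the row index into the atom index");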
@@ -283,6 +286,9 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape, kRowExec_,
using SwizzledBaseShape = traits::SwizzleBaseTileShape<DType>;
static constexpr int kSwizzledRows = SwizzledBaseShape::kRows;
static constexpr int kSwizzledCols = SwizzledBaseShape::kCols;
static constexpr int B = SwizzledBaseShape::B;
static constexpr int M = SwizzledBaseShape::M;
static constexpr int S = SwizzledBaseShape::S;

static constexpr int kSwizzledRowExec =
kRowExec / (kSwizzledRows / BaseShape::kRows);
@@ -311,7 +317,7 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, BaseShape, kRowExec_,

using NonSwizzled =
tl::MatrixLayout<kSwizzledRows, kSwizzledCols, Shared::kRowStride, 1>;
using Swizzled = SwizzledLayout<NonSwizzled, 3, 3, 3>;
using Swizzled = SwizzledLayout<NonSwizzled, B, M, S>;
using SharedLayout =
std::conditional_t<Shared::kSwizzled, Swizzled, NonSwizzled>;
SharedLayout src_tile_;
@@ -393,32 +399,19 @@ struct GlobalToSharedLoader {
using DType = Shared::DType;
using WarpLayout = WarpLayout_;

// This implementation uses a fixed 16x16 `BaseShape` as the atomic data
// tile accessed by threads in a single warp that issues a single load/store
// instruction.
// FIXME(ying): uncomment the following lines to automatically infer the
// warp-level tile shape instead of using a fixed 16x16 `BaseShape`. using
// WarpShape =
// warp::WarpTileShape<DType, typename Shared::Layout, Shared::kType>;
// using WarpShape =
// warp::WarpTileShape<DType, tl::RowMajor<16, 16>, Shared::kType>;

// KuangjuX: Use `4x64` in RowMajor and `64x4` in ColMajor.
static constexpr bool kRowMajor = Shared::kType == tl::Layout::kRowMajor;
using BaseTile =
std::conditional_t<kRowMajor, tl::RowMajor<4, 64>, tl::ColMajor<64, 4>>;
using WarpShape = warp::WarpTileShape<DType, BaseTile, Shared::kType>;

static_assert(Shared::kRows % WarpShape::kRows == 0,
"Shared::kRows must be divisible by WarpShape::kRows.");
static_assert(Shared::kCols % WarpShape::kCols == 0,
"Shared::kCols must be divisible by WarpShape::kCols.");
using BaseShape =
warp::WarpBaseTileShape<DType, typename Shared::Layout, Shared::kType>;

static_assert(Shared::kRows % BaseShape ::kRows == 0,
"Shared::kRows must be divisible by BaseShape::kRows.");
static_assert(Shared::kCols % BaseShape::kCols == 0,
"Shared::kCols must be divisible by BaseShape::kCols.");

static const WarpReuse kMode = WarpReuse::kCont; // warp reuse mode
using ExecCounter = warp::ExecCounter<WarpShape, Shared, WarpLayout, kMode>;
using ExecCounter = warp::ExecCounter<BaseShape, Shared, WarpLayout, kMode>;
using GlobalOffset = warp::GlobalOffsetHelper<WarpLayout, kMode>;
using SharedOffset =
warp::SharedOffsetHelper<WarpLayout, WarpShape, Shared, kMode>;
warp::SharedOffsetHelper<WarpLayout, BaseShape, Shared, kMode>;

static constexpr int kRowExec = ExecCounter::kRowExec;
static constexpr int kColExec = ExecCounter::kColExec;
@@ -441,7 +434,7 @@ struct GlobalToSharedLoader {
int offset_dst = shared_offset_.get_warp_offset();

// Load a single warp tile from global memory to shared memory
using Loader = GlobalToSharedLoaderImpl<Global, Shared, WarpShape,
using Loader = GlobalToSharedLoaderImpl<Global, Shared, BaseShape,
kRowExec, kColExec>;

Loader loader;
@@ -459,22 +452,8 @@ struct SharedToGlobalStorer {
using DType = Shared::DType;
using WarpLayout = WarpLayout_;

// FIXME(ying): automatically infer the warp-level tile shape instead
// of using a fixed `BaseShape`.
// using WarpShape =
// warp::WarpTileShape<DType, typename Shared::Layout, Shared::kType>;

// FIXME(ying): uncomment the following lines to automatically infer the
// warp-level tile shape instead of using a fixed 16x16 `BaseShape`.
// using BaseShape =
// warp::WarpTileShape<DType, tl::RowMajor<16, 16>, Shared::kType>;

// KuangjuX: Use `4x64` in RowMajor and `64x4` in ColMajor.

static constexpr bool kRowMajor = Shared::kType == tl::Layout::kRowMajor;
using BaseTile =
std::conditional_t<kRowMajor, tl::RowMajor<4, 64>, tl::ColMajor<64, 4>>;
using BaseShape = warp::WarpTileShape<DType, BaseTile, Shared::kType>;
using BaseShape =
warp::WarpBaseTileShape<DType, typename Shared::Layout, Shared::kType>;

static_assert(Shared::kRows % BaseShape::kRows == 0,
"Shared::kRows must be divisible by BaseShape::kRows.");
6 changes: 3 additions & 3 deletions include/cell/copy/warp.hpp
@@ -179,10 +179,10 @@ struct ExecCounter {
/// the entire tile. The final warp tile shape is multiple of this atomic
/// shape.
template <typename DType, typename TileLayout, const tl::Layout kType>
struct WarpTileShape;
struct WarpBaseTileShape;

template <typename DType, typename TileLayout>
struct WarpTileShape<DType, TileLayout, tl::Layout::kRowMajor> {
struct WarpBaseTileShape<DType, TileLayout, tl::Layout::kRowMajor> {
using AccessInfo = traits::AccessBase<DType>;

// In a row-major layout, columns are the contiguous dimension in memory. We
@@ -222,7 +222,7 @@ struct WarpTileShape<DType, TileLayout, tl::Layout::kRowMajor> {
};

template <typename DType, typename TileLayout>
struct WarpTileShape<DType, TileLayout, tl::Layout::kColMajor> {
struct WarpBaseTileShape<DType, TileLayout, tl::Layout::kColMajor> {
using AccessInfo = traits::AccessBase<DType>;

// In a column-major layout, columns are the contiguous dimension in memory.
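The renamed WarpBaseTileShape trait above infers the atomic warp tile from the element type and the tile layout; the expected results for both half and float are pinned down by the test file at the end of this diff. Below is a minimal re-derivation of the row-major case, under the assumption that each of the 32 threads issues one 16-byte access and at most eight threads sit along the contiguous dimension; the real trait may be organized differently, and the column-major case simply swaps rows and columns.

#include <cuda_fp16.h>  // __half

// Illustrative re-derivation of the expected row-major warp base tiles; not
// the library's implementation.
template <typename DType, int kTileRows, int kTileCols>
struct RowMajorBaseTileSketch {
    static constexpr int kElemsPerAccess = 16 / sizeof(DType);  // 128-bit access
    // Threads along the contiguous (column) dimension, capped at 8 so one
    // warp row never exceeds 128 bytes.
    static constexpr int kColThreads =
        kTileCols / kElemsPerAccess < 8 ? kTileCols / kElemsPerAccess : 8;
    static constexpr int kRowThreads = 32 / kColThreads;  // one warp = 32 threads
    static constexpr int kRows = kRowThreads;
    static constexpr int kCols = kColThreads * kElemsPerAccess;
};

// These match the EXPECT_EQ values in test_atomic_warp_tile_shape.cu below.
static_assert(RowMajorBaseTileSketch<__half, 128, 8>::kRows == 32 &&
              RowMajorBaseTileSketch<__half, 128, 8>::kCols == 8, "");
static_assert(RowMajorBaseTileSketch<__half, 128, 128>::kRows == 4 &&
              RowMajorBaseTileSketch<__half, 128, 128>::kCols == 64, "");
static_assert(RowMajorBaseTileSketch<float, 128, 128>::kRows == 4 &&
              RowMajorBaseTileSketch<float, 128, 128>::kCols == 32, "");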
29 changes: 8 additions & 21 deletions include/traits/base.hpp
@@ -50,27 +50,6 @@ struct BaseTileShape {
static constexpr int kNumel = kRows * kCols;
};

/**
* @brief Single Warp load shape from global memory to shared memory.
*/
template <typename Element>
requires BaseType<Element>
struct GMemCopyShape;

template <>
struct GMemCopyShape<__half> {
static constexpr int kRows = 64;
static constexpr int kCols = 4;
static constexpr int kNumel = kRows * kCols;
};

template <>
struct GMemCopyShape<float> {
static constexpr int kRows = 32;
static constexpr int kCols = 4;
static constexpr int kNumel = kRows * kCols;
};

/**
* @brief The base tile shape for Swizzle<3, 3, 3>.
*/
@@ -85,6 +64,10 @@ struct SwizzleBaseTileShape<__half> {
static constexpr int kRows = 8;
static constexpr int kCols = 64;
static constexpr int kNumel = kRows * kCols;

static constexpr int B = 3;
static constexpr int M = 3;
static constexpr int S = 3;
};

template <>
Expand All @@ -94,6 +77,10 @@ struct SwizzleBaseTileShape<float> {
static constexpr int kRows = 8;
static constexpr int kCols = 32;
static constexpr int kNumel = kRows * kCols;

static constexpr int B = 3;
static constexpr int M = 2;
static constexpr int S = 3;
};

} // namespace tilefusion::traits
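One way to read the constants added above (an interpretation, not something the header states): 2^M elements form a 16-byte chunk, 2^S chunks form one 128-byte row of kCols elements, and 2^B rows take part in the permutation, so only M changes with the element width. The checks below are illustrative, not library code.

#include <cuda_fp16.h>  // __half

// Consistency checks for the interpretation above; mirrors SwizzleBaseTileShape.
template <typename DType, int kRows, int kCols, int B, int M, int S>
struct SwizzleParamCheck {
    static_assert((1 << M) * sizeof(DType) == 16, "2^M elements = 16-byte chunk");
    static_assert((1 << (M + S)) == kCols, "2^(M+S) elements = one 128-byte row");
    static_assert((1 << B) == kRows, "2^B rows are permuted");
    static constexpr bool kOk = true;
};

static_assert(SwizzleParamCheck<__half, 8, 64, 3, 3, 3>::kOk, "");
static_assert(SwizzleParamCheck<float, 8, 32, 3, 2, 3>::kOk, "");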
48 changes: 32 additions & 16 deletions tests/cpp/cell/test_atomic_warp_tile_shape.cu
@@ -15,7 +15,8 @@ TEST(InferAtomicWarpTile, test1_half_row_major) {

{ // atomic warp shape: 32x8, thread layout: 32x1
using Layout = tl::RowMajor<128, 8>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;

EXPECT_EQ(WarpTile::kRows, 32);
EXPECT_EQ(WarpTile::kCols, 8);
@@ -26,7 +27,8 @@ TEST(InferAtomicWarpTile, test1_half_row_major) {

{ // atomic warp shape: 16x16, thread layout: 16x2
using Layout = tl::RowMajor<64, 16>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;

EXPECT_EQ(WarpTile::kRows, 16);
EXPECT_EQ(WarpTile::kCols, 16);
@@ -37,7 +39,8 @@ TEST(InferAtomicWarpTile, test1_half_row_major) {

{ // atomic warp shape: 8x32, thread layout: 8x4
using Layout = tl::RowMajor<16, 32>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;

EXPECT_EQ(WarpTile::kRows, 8);
EXPECT_EQ(WarpTile::kCols, 32);
@@ -48,7 +51,8 @@ TEST(InferAtomicWarpTile, test1_half_row_major) {

{ // atomic warp shape: 4x64, thread layout: 4x8
using Layout = tl::RowMajor<128, 128>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;

EXPECT_EQ(WarpTile::kRows, 4);
EXPECT_EQ(WarpTile::kCols, 64);
@@ -63,7 +67,8 @@ TEST(InferAtomicWarpTile, test2_half_column_major) {

{ // atomic warp shape: 8x32, thread layout: 1x32
using Layout = tl::ColMajor<8, 128>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;

EXPECT_EQ(WarpTile::kRows, 8);
EXPECT_EQ(WarpTile::kCols, 32);
@@ -74,7 +79,8 @@ TEST(InferAtomicWarpTile, test2_half_column_major) {

{ // atomic warp shape: 16x16, thread layout: 2x16
using Layout = tl::ColMajor<16, 64>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;

EXPECT_EQ(WarpTile::kRows, 16);
EXPECT_EQ(WarpTile::kCols, 16);
@@ -85,7 +91,8 @@ TEST(InferAtomicWarpTile, test2_half_column_major) {

{ // atomic warp shape: 32x8, thread layout: 4x8
using Layout = tl::ColMajor<32, 16>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;

EXPECT_EQ(WarpTile::kRows, 32);
EXPECT_EQ(WarpTile::kCols, 8);
@@ -96,7 +103,8 @@ TEST(InferAtomicWarpTile, test2_half_column_major) {

{ // atomic warp shape: 64x4, thread layout: 8x4
using Layout = tl::ColMajor<128, 128>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;

EXPECT_EQ(WarpTile::kRows, 64);
EXPECT_EQ(WarpTile::kCols, 4);
@@ -111,7 +119,8 @@ TEST(InferAtomicWarpTile, test3_float_row_major) {

{ // atomic warp shape: 32x4, thread layout: 32x1
using Layout = tl::RowMajor<128, 4>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;

EXPECT_EQ(WarpTile::kRows, 32);
EXPECT_EQ(WarpTile::kCols, 4);
@@ -122,7 +131,8 @@ TEST(InferAtomicWarpTile, test3_float_row_major) {

{ // atomic warp shape: 16x8, thread layout: 16x2
using Layout = tl::RowMajor<64, 8>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;

EXPECT_EQ(WarpTile::kRows, 16);
EXPECT_EQ(WarpTile::kCols, 8);
@@ -133,7 +143,8 @@ TEST(InferAtomicWarpTile, test3_float_row_major) {

{ // atomic warp shape: 8x16, thread layout: 8x4
using Layout = tl::RowMajor<16, 16>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;

EXPECT_EQ(WarpTile::kRows, 8);
EXPECT_EQ(WarpTile::kCols, 16);
@@ -144,7 +155,8 @@ TEST(InferAtomicWarpTile, test3_float_row_major) {

{ // atomic warp shape: 4x32, thread layout: 4x8
using Layout = tl::RowMajor<128, 128>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kRowMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kRowMajor>;

EXPECT_EQ(WarpTile::kRows, 4);
EXPECT_EQ(WarpTile::kCols, 32);
@@ -159,7 +171,8 @@ TEST(InferAtomicWarpTile, test4_float_column_major) {

{ // atomic warp shape: 4x32, thread layout: 1x32
using Layout = tl::ColMajor<4, 128>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;

EXPECT_EQ(WarpTile::kRows, 4);
EXPECT_EQ(WarpTile::kCols, 32);
@@ -170,7 +183,8 @@ TEST(InferAtomicWarpTile, test4_float_column_major) {

{ // atomic warp shape: 8x16, thread layout: 2x16
using Layout = tl::ColMajor<8, 64>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;

EXPECT_EQ(WarpTile::kRows, 8);
EXPECT_EQ(WarpTile::kCols, 16);
@@ -181,7 +195,8 @@ TEST(InferAtomicWarpTile, test4_float_column_major) {

{ // atomic warp shape: 16x8, thread layout: 4x8
using Layout = tl::ColMajor<16, 32>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;

EXPECT_EQ(WarpTile::kRows, 16);
EXPECT_EQ(WarpTile::kCols, 8);
@@ -192,7 +207,8 @@ TEST(InferAtomicWarpTile, test4_float_column_major) {

{ // atomic warp shape: 4x32, thread layout: 8x4
using Layout = tl::ColMajor<128, 128>;
using WarpTile = WarpTileShape<DType, Layout, tl::Layout::kColMajor>;
using WarpTile =
WarpBaseTileShape<DType, Layout, tl::Layout::kColMajor>;

EXPECT_EQ(WarpTile::kRows, 32);
EXPECT_EQ(WarpTile::kCols, 4);