Skip to content

Commit

Permalink
fix(cell): Move the definition of BaseTile into types to facilitate f…
Browse files Browse the repository at this point in the history
…urther refactoring. (microsoft#53)

This pull request doesn't introduce any new functionality; rather, it
makes the following adjustments:

1. It moves the definition of BaseTile to the types file to facilitate
upcoming refactoring.
2. It initializes the document for the shared memory layout.
  • Loading branch information
lcy-seso authored Feb 13, 2025
1 parent c75a343 commit 280ecd8
Show file tree
Hide file tree
Showing 11 changed files with 294 additions and 201 deletions.
27 changes: 27 additions & 0 deletions docs/tiles_in_shared_memory.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
## Data Layout for Efficient Shared Memory Access

### A Base Tile

A `BaseTile` is a two-dimensional collection of data accessed cooperatively by threads within a single warp, with each thread issuing a single data access instruction.

Let’s consider some specific examples. Suppose each thread accesses 128-bit data in a single access, and the threads are arranged within the warp in a row-major fashion, where threads along the rows have consecutive thread indices.

If the data is in ***half-precision*** floating-point format:

- When the threads in a warp are arranged in a $4 \times 8$ configuration, the `BaseTile` has dimensions of $4 \times 64$.
- When the threads in a warp are arranged in an $8 \times 4$ configuration, the `BaseTile` has dimensions of $8 \times 32$.
- When the threads in a warp are arranged in a $16 \times 2$ configuration, the `BaseTile` has dimensions of $16 \times 16$.

Now, suppose the data is in ***single-precision*** floating-point format:

- When the threads in a warp are arranged in a $4 \times 8$ configuration, the `BaseTile` has dimensions of $4 \times 32$.
- When the threads in a warp are arranged in an $8 \times 4$ configuration, the `BaseTile` has dimensions of $8 \times 16$.
- When the threads in a warp are arranged in a $16 \times 2$ configuration, the `BaseTile` has dimensions of $16 \times 8$.

A keen observer may notice that the largest dimension of a `BaseTile` never exceeds 1024 bits. This is not coincidental; it is a result of several hardware parameters related to global and shared memory access. Global memory traffic is routed through the data caches (the L1 and/or L2 caches). An L1 cache line is 1024 bits, which also corresponds to the maximum memory transaction size. Additionally, shared memory consists of 32 banks, each with a width of 4 bytes, collectively amounting to 1024 bits. This alignment enhances the efficiency of data transfer between global and shared memory.

### Storing Tiles in Shared Memory

To ensure an efficient access pattern, we need to impose a constraint by assuming that each thread accesses 128-bit data, which is the maximum width of a vectorized access instruction. Consequently, the entire warp accesses $4 \times 128$ bytes of data. It is known that 128 bytes is the largest transaction size. When more than 128 bytes of data per warp are loaded or stored, the GPU does not issue a single transaction but divides the data access into four transactions. Furthermore, bank conflicts occur per transaction.

Our objective is to avoid bank conflicts when loading data tiles from or storing data tiles to shared memory.
5 changes: 2 additions & 3 deletions include/cell/copy/global_to_shared.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "cell/copy/mod.hpp"
Expand Down Expand Up @@ -504,7 +503,7 @@ struct GlobalToSharedLoader {

using WarpShape = TileShape<Shared::kRows / WarpLayout::kRows,
Shared::kCols / WarpLayout::kCols>;
using BaseShape = warp::WarpBaseTileShape<DType, WarpShape, Shared::kType>;
using BaseShape = WarpBaseTileShape<DType, WarpShape, Shared::kType>;

static_assert(Shared::kRows % BaseShape ::kRows == 0,
"Shared::kRows must be divisible by BaseShape::kRows.");
Expand Down Expand Up @@ -558,7 +557,7 @@ struct SharedToGlobalStorer {

using WarpShape = TileShape<Shared::kRows / WarpLayout::kRows,
Shared::kCols / WarpLayout::kCols>;
using BaseShape = warp::WarpBaseTileShape<DType, WarpShape, Shared::kType>;
using BaseShape = WarpBaseTileShape<DType, WarpShape, Shared::kType>;

static_assert(Shared::kRows % BaseShape::kRows == 0,
"Shared::kRows must be divisible by BaseShape::kRows.");
Expand Down
182 changes: 54 additions & 128 deletions include/cell/copy/warp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,18 +140,44 @@ DEVICE int warp_col_id() {
}
}

template <typename BaseTile_, typename Tile_, typename WarpLayout_,
template <const int kSharedRows, const int kWarpRows,
const WarpReuse kMode = WarpReuse::kCont>
HOST_DEVICE constexpr int warp_tile_rows() {
if constexpr (kMode == WarpReuse::kCont) {
return kSharedRows / kWarpRows;
} else if constexpr (kMode == WarpReuse::kRowReuseCont) {
return kSharedRows / kWarpRows;
} else if constexpr (kMode == WarpReuse::kColReuseCont) {
return kSharedRows;
}
return -1;
}

template <const int kSharedCols, const int kWarpCols,
const WarpReuse kMode = WarpReuse::kCont>
HOST_DEVICE constexpr int warp_tile_cols() {
if constexpr (kMode == WarpReuse::kCont) {
return kSharedCols / kWarpCols;
} else if constexpr (kMode == WarpReuse::kRowReuseCont) {
return kSharedCols;
} else if constexpr (kMode == WarpReuse::kColReuseCont) {
return kSharedCols / kWarpCols;
}
return -1;
}

template <typename BaseShape_, typename Tile_, typename WarpLayout_,
const WarpReuse kMode_>
struct ExecCounter {
using BaseTile = BaseTile_;
using BaseShape = BaseShape_;
using Tile = Tile_;

static_assert(
Tile::kCols % BaseTile::kCols == 0,
Tile::kCols % BaseShape::kCols == 0,
"The number of shared memory columns must be divisible by the base "
"tile column.\n");
static_assert(
Tile::kRows % BaseTile::kRows == 0,
Tile::kRows % BaseShape::kRows == 0,
"The current implementation requires that the number of shared "
"memory rows be divisible by the base tile row.\n");

Expand All @@ -167,9 +193,9 @@ struct ExecCounter {
// load the shared memory rows. Therefore, `row_exec` is not divided
// by warps_per_row.
case WarpReuse::kColReuseCont:
return Tile::kRows / BaseTile::kRows;
return Tile::kRows / BaseShape::kRows;
default: // Cont, RowReuseCont hit this case.
return Tile::kRows / BaseTile::kRows / kWarpsPerRow;
return Tile::kRows / BaseShape::kRows / kWarpsPerRow;
}
}

Expand All @@ -179,116 +205,16 @@ struct ExecCounter {
// the shared memory columns. Therefore, `col_exec` is not divided
// by `warps_per_col`.
case WarpReuse::kRowReuseCont:
return Tile::kCols / BaseTile::kCols;
return Tile::kCols / BaseShape::kCols;
default: // Cont, ColReuseCont hit this case.
return Tile::kCols / BaseTile::kCols / kWarpsPerCol;
return Tile::kCols / BaseShape::kCols / kWarpsPerCol;
}
}

static constexpr int kRowExec = row_exec_count();
static constexpr int kColExec = col_exec_count();
};

/// @brief Determine the automatic shape of a single warp based on the shape of
/// the entire tile. The final warp tile shape is multiple of this atomic
/// shape.
template <typename DType, typename TileShape, const tl::Layout kType>
struct WarpBaseTileShape;

template <typename DType, typename TileShape>
struct WarpBaseTileShape<DType, TileShape, tl::Layout::kRowMajor> {
using AccessInfo = traits::AccessBase<DType>;

static constexpr int kTileRows = dim_size<0, TileShape>;
static constexpr int kTileCols = dim_size<1, TileShape>;

// In a row-major layout, columns are the contiguous dimension in memory. We
// enforce the use of 128-bit vectorized instructions for data loading by a
// single thread. This implies that the minimum number of columns should be
// at least 128 bits.
static constexpr int kMinCols =
AccessInfo::kAccessInBits / (sizeof(DType) * 8);

static_assert(kTileCols >= kMinCols, "The number of columns is too small.");

static_assert(kTileCols < AccessInfo::kExpectedSize ||
(kTileCols >= AccessInfo::kExpectedSize &&
kTileCols % AccessInfo::kExpectedSize == 0),
"The current implementation requires that the number of "
"columns of the tile be divisible by the cache line width.");

static constexpr int kCols = kTileCols >= AccessInfo::kExpectedSize
? AccessInfo::kExpectedSize
: kTileCols;

// number of columns in a warp
static constexpr int kColThreads = kCols / AccessInfo::kNumPerAccess;
static_assert(WARP_SIZE % kColThreads == 0,
"Fail to infer warp thread layout.");
static constexpr int kRowThreads = WARP_SIZE / kColThreads;

static constexpr int kRows = kRowThreads;
static_assert(kTileRows % kRowThreads == 0,
"The number of rows of the tile isn't evenly divisible by "
"the number of threads in a column.");

static constexpr int kNumel = kRows * kCols;

using WarpThreadLayout = tl::RowMajor<kRowThreads, kColThreads>;
};

template <typename DType, typename TileShape>
struct WarpBaseTileShape<DType, TileShape, tl::Layout::kColMajor> {
using AccessInfo = traits::AccessBase<DType>;

static constexpr int kTileRows = dim_size<0, TileShape>;
static constexpr int kTileCols = dim_size<1, TileShape>;

// In a column-major layout, columns are the contiguous dimension in memory.
// We enforce the use of 128-bit vectorized instructions for data loading by
// a single thread. This implies that the minimum number of columns should
// be at least 128 bits.
static constexpr int kMinRows =
AccessInfo::kAccessInBits / (sizeof(DType) * 8);

static_assert(kTileRows >= kMinRows, "The number of rows is too small.");

static_assert(kTileRows < AccessInfo::kExpectedSize ||
(kTileRows >= AccessInfo::kExpectedSize &&
kTileRows % AccessInfo::kExpectedSize == 0),
"The current implementation requires that the number of "
"rows of the tile be divisible by the cache line width.");

static constexpr int kRows = kTileRows >= AccessInfo::kExpectedSize
? AccessInfo::kExpectedSize
: kTileRows;

// number of rows in a warp
static constexpr int kRowThreads = kRows / AccessInfo::kNumPerAccess;
static_assert(WARP_SIZE % kRowThreads == 0,
"Fail to infer warp thread layout.");
static constexpr int kColThreads = WARP_SIZE / kRowThreads;

static constexpr int kCols = kColThreads;
static_assert(kTileCols % kColThreads == 0,
"The number of columns of the tile isn't evenly divisible by "
"the number of threads in a row.");

static constexpr int kNumel = kRows * kCols;

using WarpThreadLayout = tl::ColMajor<kRowThreads, kColThreads>;
};

/// @brief Pretty printer for the static shape information of a
/// `WarpBaseTileShape`. Note: This printer function works ONLY on the
/// host.
template <typename DType, typename TileShape, const tl::Layout kType>
static HOST std::ostream& operator<<(
std::ostream& out, const WarpBaseTileShape<DType, TileShape, kType>& tile) {
BaseTilePrettyPrinter::print(out, tile);
return out;
}

template <typename WarpLayout_, const WarpReuse kMode_>
struct GlobalOffsetHelper {
static constexpr WarpReuse kMode = kMode_;
Expand Down Expand Up @@ -347,62 +273,62 @@ struct GlobalOffsetHelper {
* not correctly reveal the physical layout of data in memory. This requires
* further special treatment.
*/
template <typename WarpLayout, typename WarpShape, typename Shared,
template <typename WarpLayout, typename BaseShape, typename Shared,
const WarpReuse kMode, const tl::Layout kType = Shared::kType,
const bool kIsSharedLayout = IsSharedLayout<Shared>>
struct SharedOffsetHelper;

template <typename WarpLayout_, typename WarpShape_, typename Shared_,
template <typename WarpLayout_, typename BaseShape_, typename Shared_,
const WarpReuse kMode_>
struct SharedOffsetHelper<WarpLayout_, WarpShape_, Shared_, kMode_,
struct SharedOffsetHelper<WarpLayout_, BaseShape_, Shared_, kMode_,
tl::Layout::kRowMajor, false> {
DEVICE int get_warp_offset() {
// TODO(KuangjuX): hotfix this.
return warp_row_id<WarpLayout>() * kRowStride * WarpShape::kRows *
return warp_row_id<WarpLayout>() * kRowStride * BaseShape::kRows *
Shared::kCols +
warp_col_id<WarpLayout>() * kColStride * WarpShape::kCols;
warp_col_id<WarpLayout>() * kColStride * BaseShape::kCols;
}

private:
using Shared = Shared_;
using WarpLayout = WarpLayout_;
using WarpShape = WarpShape_;
using BaseShape = BaseShape_;
static constexpr WarpReuse kMode = kMode_;

constexpr static int kTilePerRow = Shared::kRows / WarpShape::kRows;
constexpr static int kTilePerCol = Shared::kCols / WarpShape::kCols;
constexpr static int kTilePerRow = Shared::kRows / BaseShape::kRows;
constexpr static int kTilePerCol = Shared::kCols / BaseShape::kCols;

// TODO(KuangjuX): hotfix this.
constexpr static int kRowStride = kTilePerRow / tl::num_rows<WarpLayout>;
constexpr static int kColStride = kTilePerCol / tl::num_cols<WarpLayout>;
};

template <typename WarpLayout_, typename WarpShape_, typename Shared_,
template <typename WarpLayout_, typename BaseShape_, typename Shared_,
const WarpReuse kMode_>
struct SharedOffsetHelper<WarpLayout_, WarpShape_, Shared_, kMode_,
struct SharedOffsetHelper<WarpLayout_, BaseShape_, Shared_, kMode_,
tl::Layout::kColMajor, false> {
DEVICE int get_warp_offset() {
return warp_row_id<WarpLayout>() * kRowStride * WarpShape::kRows +
warp_col_id<WarpLayout>() * kColStride * WarpShape::kCols *
return warp_row_id<WarpLayout>() * kRowStride * BaseShape::kRows +
warp_col_id<WarpLayout>() * kColStride * BaseShape::kCols *
Shared::kRows;
}

private:
using Shared = Shared_;
using WarpLayout = WarpLayout_;
using WarpShape = WarpShape_;
using BaseShape = BaseShape_;
static constexpr WarpReuse kMode = kMode_;

constexpr static int kTilePerRow = Shared::kRows / WarpShape::kRows;
constexpr static int kTilePerCol = Shared::kCols / WarpShape::kCols;
constexpr static int kTilePerRow = Shared::kRows / BaseShape::kRows;
constexpr static int kTilePerCol = Shared::kCols / BaseShape::kCols;

constexpr static int kRowStride = kTilePerRow / tl::num_rows<WarpLayout>;
constexpr static int kColStride = kTilePerCol / tl::num_cols<WarpLayout>;
};

template <typename WarpLayout_, typename WarpShape_, typename Shared_,
template <typename WarpLayout_, typename BaseShape_, typename Shared_,
const WarpReuse kMode_, const tl::Layout kType>
struct SharedOffsetHelper<WarpLayout_, WarpShape_, Shared_, kMode_, kType,
struct SharedOffsetHelper<WarpLayout_, BaseShape_, Shared_, kMode_, kType,
true> {
using WarpLayout = WarpLayout_;

Expand All @@ -412,11 +338,11 @@ struct SharedOffsetHelper<WarpLayout_, WarpShape_, Shared_, kMode_, kType,

private:
using Shared = Shared_;
using WarpShape = WarpShape_;
using BaseShape = BaseShape_;
static constexpr WarpReuse kMode = kMode_;

constexpr static int kTilePerRow = Shared::kCols / WarpShape::kCols;
constexpr static int kTilePerCol = Shared::kRows / WarpShape::kRows;
constexpr static int kTilePerRow = Shared::kCols / BaseShape::kCols;
constexpr static int kTilePerCol = Shared::kRows / BaseShape::kRows;

constexpr static int kTilePerWarpRow =
kTilePerRow / tl::num_cols<WarpLayout>;
Expand Down
6 changes: 4 additions & 2 deletions include/traits/base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ template <typename Element>
struct AccessBase {
// the maximal width of vectorized access.
static constexpr int kAccessInBits = 128;
static constexpr int kElementBits = cutlass::sizeof_bits<Element>::value;
static constexpr int kAccessInBytes = kAccessInBits / 8;

static constexpr int kElementBits = sizeof(Element) * 8;
static constexpr int kNumPerAccess = kAccessInBits / kElementBits;

// the width of memory transaction
// the width of memory transaction, Shared memory cacheline width.
static constexpr int kMemTransWidth = 1024; // 1024 bits, 128 bytes

// The ideal number of columns for a single warp to load.
Expand Down
Loading

0 comments on commit 280ecd8

Please sign in to comment.