Skip to content

Commit

Permalink
feat(types): Implement Swizzle functor and SwizzleLayout (microso…
Browse files Browse the repository at this point in the history
…ft#41)

* Add a swizzle functor.

* Update swizzle.

* Add SwizzleLayout.

* Add swizzle test.

* Update swizzle test.

* Add assert for swizzle.

* Update swizzle layout test.

* Remove unused codes.

* Add a nested basetile swizzle layout.

* Update Nested Swizzle Layout test.

* small fix.

* follow comments.
  • Loading branch information
KuangjuX authored Jan 18, 2025
1 parent 8a077e7 commit 953b6e0
Show file tree
Hide file tree
Showing 5 changed files with 212 additions and 3 deletions.
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
"random": "cpp",
"limits": "cpp",
"semaphore": "cpp",
"regex": "cpp"
"regex": "cpp",
"tuple": "cpp"
},
"gotoSymbolStack.currentStackPosition": 0,
"gotoSymbolStack.maxStackPosition": 0,
Expand Down
4 changes: 2 additions & 2 deletions include/types/layout.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ struct SharedLayout {

static constexpr Layout kType = kType_;

DEVICE int operator()(int i, int j) const {
HOST_DEVICE int operator()(int i, int j) const {
int tile_x = i / BaseShape::kRows;
int tile_y = j / BaseShape::kCols;

Expand Down Expand Up @@ -288,7 +288,7 @@ struct MatrixLayout {
static constexpr Layout kType =
kColStride == 1 ? Layout::kRowMajor : Layout::kColMajor;

DEVICE int operator()(int i, int j) const {
HOST_DEVICE int operator()(int i, int j) const {
return i * kRowStride + j * kColStride;
}
};
Expand Down
1 change: 1 addition & 0 deletions include/types/mod.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@
#include "types/register.hpp"
#include "types/shared.hpp"
#include "types/shared_tile_iterator.hpp"
#include "types/swizzle.hpp"
#include "types/tile_shape.hpp"
84 changes: 84 additions & 0 deletions include/types/swizzle.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "cuda_utils.hpp"

#include <cassert>

namespace tilefusion::cell {
/**
* @brief A swizzle functor.
* A Swizzle can handle 2^B * 2^S * 2^M elements.
*/
template <const int kB, const int kM, const int kS>
struct Swizzle {
static constexpr int Bbits = kB;
static constexpr int Mbits = kM;
static constexpr int Sbits = kS;
/**
* @brief Apply the swizzle to an index.
*
* @param idx The index in a swizzle chunk of 2^B * 2^S * 2^M elements.
* @return The swizzled index.
*/
HOST_DEVICE int apply(int idx) const {
// | Bbits | Sbits | Mbits |
// Mbits as mask for the lower bits.

assert(idx < (1 << (Bbits + Mbits + Sbits)));

int bs = idx >> Mbits;
// (b, s) as a 2d coordinate.
int y = bs & ((1 << Sbits) - 1);
int x = bs >> Sbits;

int swizzled_y = x ^ y;

// Use swizzled_y instead of y and build swizzled idx.
return (x << (Mbits + Sbits)) | (swizzled_y << Mbits) |
(idx & ((1 << Mbits) - 1));
}
};

/**
* @brief Swizzled Layout.
*
* @tparam Layout_ The layout to swizzle.
* @tparam kB The number of bits for B.
* @tparam kM The number of bits for M.
* @tparam kS The number of bits for S.
*/
template <typename Layout_, const int kB = 3, const int kM = 3,
const int kS = 3>
struct SwizzledLayout {
static constexpr int Bbits = kB;
static constexpr int Mbits = kM;
static constexpr int Sbits = kS;

using Layout = Layout_;
using Swizzle = Swizzle<Bbits, Mbits, Sbits>;

/**
* @brief Apply the swizzle in a layout.
*
* @param x Row dimension index, with a total of 2^B rows.
* @param y Column dimension index, with a total of 2^S * 2^M columns.
*/
HOST_DEVICE auto operator()(int x, int y) const {
int idx = (x << (Mbits + Sbits)) | y;

assert(idx < (1 << (Bbits + Mbits + Sbits)));

int swizzled_idx = swizzle_.apply(idx);
int swizzled_x = swizzled_idx >> (Mbits + Sbits);
int swizzled_y = swizzled_idx & ((1 << (Mbits + Sbits)) - 1);
return Layout{}(swizzled_x, swizzled_y);
}

private:
Swizzle swizzle_;
};

} // namespace tilefusion::cell
123 changes: 123 additions & 0 deletions tests/cpp/types/test_swizzle.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "common/test_utils.hpp"
#include "types/mod.hpp"

namespace tilefusion::testing {
using namespace cell;
namespace tl = tile_layout;

int flatten(int x, int y, int width) { return x * width + y; }

template <const int kB, const int kM, const int kS>
int swizzle_ref(int x, int y) {
int b = x;
int s = y >> kM;

int swizzled_s = b ^ s;
int swizzle_idx =
(b << (kM + kS)) | (swizzled_s << kM) | (y & ((1 << kM) - 1));

return swizzle_idx;
}

template <const int kB, const int kM, const int kS>
int2 test_swizzle(int x, int y) {
Swizzle<kB, kM, kS> swizzle;
int idx = flatten(x, y, 1 << (kS + kM));
int swizzled_idx = swizzle.apply(idx);

int ref_swizzled_idx = swizzle_ref<kB, kM, kS>(x, y);

#ifdef DEBUG
printf("idx: %d, swizzled_idx: %d, ref_swizzled_idx: %d\n", idx,
swizzled_idx, ref_swizzled_idx);
#endif

return make_int2(swizzled_idx, ref_swizzled_idx);
}

TEST(TESTSwizzle, test_swizzle_apply) {
const int kB = 3;
const int kM = 3;
const int kS = 3;

int2 swizzled_idx_0_0 = test_swizzle<kB, kM, kS>(0, 0);
int2 swizzled_idx_1_0 = test_swizzle<kB, kM, kS>(1, 0);
int2 swizzled_idx_1_4 = test_swizzle<kB, kM, kS>(1, 4);
int2 swizzled_idx_2_0 = test_swizzle<kB, kM, kS>(2, 0);
int2 swizzled_idx_2_4 = test_swizzle<kB, kM, kS>(2, 4);

EXPECT_EQ(swizzled_idx_0_0.x, swizzled_idx_0_0.y);
EXPECT_EQ(swizzled_idx_1_0.x, swizzled_idx_1_0.y);
EXPECT_EQ(swizzled_idx_1_4.x, swizzled_idx_1_4.y);
EXPECT_EQ(swizzled_idx_2_0.x, swizzled_idx_2_0.y);
EXPECT_EQ(swizzled_idx_2_4.x, swizzled_idx_2_4.y);
}

TEST(TESTSwizzle, test_naive_swizzle_layout) {
const int kB = 3;
const int kM = 3;
const int kS = 3;

const int kRows = 1 << kB;
const int kCols = 1 << (kM + kS);

using NaiveRowMajorLayout = tl::RowMajor<kRows, kCols>;
using NaiveSwizzledRowMajorLayout =
SwizzledLayout<NaiveRowMajorLayout, kB, kM, kS>;

NaiveSwizzledRowMajorLayout naive_row_major_swizzled_layout;

EXPECT_EQ((naive_row_major_swizzled_layout(0, 0)),
(swizzle_ref<kB, kM, kS>(0, 0)));
EXPECT_EQ((naive_row_major_swizzled_layout(1, 0)),
(swizzle_ref<kB, kM, kS>(1, 0)));
EXPECT_EQ((naive_row_major_swizzled_layout(1, 4)),
(swizzle_ref<kB, kM, kS>(1, 4)));
EXPECT_EQ((naive_row_major_swizzled_layout(2, 0)),
(swizzle_ref<kB, kM, kS>(2, 0)));
EXPECT_EQ((naive_row_major_swizzled_layout(2, 4)),
(swizzle_ref<kB, kM, kS>(2, 4)));
}

TEST(TESTSwizzle, test_nested_basetile_swizzle_layout) {
const int kB = 3;
const int kM = 3;
const int kS = 3;

const int kRows = 1 << kB;
const int kCols = 1 << (kM + kS);

using NestedBaseTileLayout =
tl::detail::SharedLayout<kRows, kCols, kCols * 16, 16,
tl::Layout::kRowMajor>;
using NestedBaseTileSwizzledLayout =
SwizzledLayout<NestedBaseTileLayout, kB, kM, kS>;

NestedBaseTileLayout nested_base_tile_layout;
NestedBaseTileSwizzledLayout nested_base_tile_swizzled_layout;

int idx_0_0 = nested_base_tile_layout(0, 0);
int idx_1_0 = nested_base_tile_layout(1, 0);
int idx_1_4 = nested_base_tile_layout(1, 4);
int idx_2_0 = nested_base_tile_layout(2, 0);
int idx_2_4 = nested_base_tile_layout(2, 4);

int swizzled_idx_0_0 = nested_base_tile_swizzled_layout(0, 0);
int swizzled_idx_1_0 = nested_base_tile_swizzled_layout(1, 0);
int swizzled_idx_1_4 = nested_base_tile_swizzled_layout(1, 4);
int swizzled_idx_2_0 = nested_base_tile_swizzled_layout(2, 0);
int swizzled_idx_2_4 = nested_base_tile_swizzled_layout(2, 4);

#ifdef DEBUG
printf("idx_0_0: %d, swizzled_idx_0_0: %d\n", idx_0_0, swizzled_idx_0_0);
printf("idx_1_0: %d, swizzled_idx_1_0: %d\n", idx_1_0, swizzled_idx_1_0);
printf("idx_1_4: %d, swizzled_idx_1_4: %d\n", idx_1_4, swizzled_idx_1_4);
printf("idx_2_0: %d, swizzled_idx_2_0: %d\n", idx_2_0, swizzled_idx_2_0);
printf("idx_2_4: %d, swizzled_idx_2_4: %d\n", idx_2_4, swizzled_idx_2_4);
#endif
}

} // namespace tilefusion::testing

0 comments on commit 953b6e0

Please sign in to comment.