forked from microsoft/TileFusion
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
…ft#41) * Add a swizzle functor. * Update swizzle. * Add SwizzleLayout. * Add swizzle test. * Update swizzle test. * Add assert for swizzle. * Update swizzle layout test. * Remove unused codes. * Add a nested basetile swizzle layout. * Update Nested Swizzle Layout test. * small fix. * follow comments.
- Loading branch information
Showing
5 changed files
with
212 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
#include "cuda_utils.hpp" | ||
|
||
#include <cassert> | ||
|
||
namespace tilefusion::cell { | ||
/** | ||
* @brief A swizzle functor. | ||
* A Swizzle can handle 2^B * 2^S * 2^M elements. | ||
*/ | ||
template <const int kB, const int kM, const int kS> | ||
struct Swizzle { | ||
static constexpr int Bbits = kB; | ||
static constexpr int Mbits = kM; | ||
static constexpr int Sbits = kS; | ||
/** | ||
* @brief Apply the swizzle to an index. | ||
* | ||
* @param idx The index in a swizzle chunk of 2^B * 2^S * 2^M elements. | ||
* @return The swizzled index. | ||
*/ | ||
HOST_DEVICE int apply(int idx) const { | ||
// | Bbits | Sbits | Mbits | | ||
// Mbits as mask for the lower bits. | ||
|
||
assert(idx < (1 << (Bbits + Mbits + Sbits))); | ||
|
||
int bs = idx >> Mbits; | ||
// (b, s) as a 2d coordinate. | ||
int y = bs & ((1 << Sbits) - 1); | ||
int x = bs >> Sbits; | ||
|
||
int swizzled_y = x ^ y; | ||
|
||
// Use swizzled_y instead of y and build swizzled idx. | ||
return (x << (Mbits + Sbits)) | (swizzled_y << Mbits) | | ||
(idx & ((1 << Mbits) - 1)); | ||
} | ||
}; | ||
|
||
/** | ||
* @brief Swizzled Layout. | ||
* | ||
* @tparam Layout_ The layout to swizzle. | ||
* @tparam kB The number of bits for B. | ||
* @tparam kM The number of bits for M. | ||
* @tparam kS The number of bits for S. | ||
*/ | ||
template <typename Layout_, const int kB = 3, const int kM = 3, | ||
const int kS = 3> | ||
struct SwizzledLayout { | ||
static constexpr int Bbits = kB; | ||
static constexpr int Mbits = kM; | ||
static constexpr int Sbits = kS; | ||
|
||
using Layout = Layout_; | ||
using Swizzle = Swizzle<Bbits, Mbits, Sbits>; | ||
|
||
/** | ||
* @brief Apply the swizzle in a layout. | ||
* | ||
* @param x Row dimension index, with a total of 2^B rows. | ||
* @param y Column dimension index, with a total of 2^S * 2^M columns. | ||
*/ | ||
HOST_DEVICE auto operator()(int x, int y) const { | ||
int idx = (x << (Mbits + Sbits)) | y; | ||
|
||
assert(idx < (1 << (Bbits + Mbits + Sbits))); | ||
|
||
int swizzled_idx = swizzle_.apply(idx); | ||
int swizzled_x = swizzled_idx >> (Mbits + Sbits); | ||
int swizzled_y = swizzled_idx & ((1 << (Mbits + Sbits)) - 1); | ||
return Layout{}(swizzled_x, swizzled_y); | ||
} | ||
|
||
private: | ||
Swizzle swizzle_; | ||
}; | ||
|
||
} // namespace tilefusion::cell |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#include "common/test_utils.hpp" | ||
#include "types/mod.hpp" | ||
|
||
namespace tilefusion::testing { | ||
using namespace cell; | ||
namespace tl = tile_layout; | ||
|
||
/// Linearizes the 2-D coordinate (x, y) using `width` as the row pitch.
int flatten(int x, int y, int width) {
    const int row_offset = x * width;
    return row_offset + y;
}
|
||
/**
 * @brief Reference swizzle computed directly from a (row, column) coordinate.
 *
 * The column is split into a high group (`y >> kM`) and its low kM bits. The
 * group is XOR-ed with the row, and the pieces are packed back as
 * | row | swizzled group | low bits |. kB is accepted for signature parity
 * but does not enter the computation.
 *
 * @param x Row index.
 * @param y Column index.
 * @return The packed swizzled index.
 */
template <const int kB, const int kM, const int kS>
int swizzle_ref(int x, int y) {
    const int low_mask = (1 << kM) - 1;

    const int row = x;
    const int col_group = y >> kM;

    const int row_bits = row << (kM + kS);
    const int group_bits = (row ^ col_group) << kM;
    const int low_bits = y & low_mask;

    return row_bits | group_bits | low_bits;
}
|
||
template <const int kB, const int kM, const int kS> | ||
int2 test_swizzle(int x, int y) { | ||
Swizzle<kB, kM, kS> swizzle; | ||
int idx = flatten(x, y, 1 << (kS + kM)); | ||
int swizzled_idx = swizzle.apply(idx); | ||
|
||
int ref_swizzled_idx = swizzle_ref<kB, kM, kS>(x, y); | ||
|
||
#ifdef DEBUG | ||
printf("idx: %d, swizzled_idx: %d, ref_swizzled_idx: %d\n", idx, | ||
swizzled_idx, ref_swizzled_idx); | ||
#endif | ||
|
||
return make_int2(swizzled_idx, ref_swizzled_idx); | ||
} | ||
|
||
// Checks Swizzle::apply against swizzle_ref for every coordinate of one
// swizzle chunk, and checks that the swizzle is a permutation of the chunk.
// Sampling only columns below 2^kM would leave the XOR untested, because the
// column group is zero for all of them.
TEST(TESTSwizzle, test_swizzle_apply) {
    const int kB = 3;
    const int kM = 3;
    const int kS = 3;

    const int kRows = 1 << kB;
    const int kCols = 1 << (kM + kS);
    const int kNumel = kRows * kCols;

    // seen[i] becomes true once some input index has been mapped to i.
    bool seen[kNumel] = {};

    for (int x = 0; x < kRows; ++x) {
        for (int y = 0; y < kCols; ++y) {
            int2 result = test_swizzle<kB, kM, kS>(x, y);

            EXPECT_EQ(result.x, result.y) << "x = " << x << ", y = " << y;

            // The swizzled index must stay inside the chunk ...
            ASSERT_GE(result.x, 0) << "x = " << x << ", y = " << y;
            ASSERT_LT(result.x, kNumel) << "x = " << x << ", y = " << y;

            // ... and no two inputs may land on the same output.
            EXPECT_FALSE(seen[result.x])
                << "duplicate swizzled index " << result.x << " at x = " << x
                << ", y = " << y;
            seen[result.x] = true;
        }
    }
}
|
||
// Checks a swizzled row-major layout against swizzle_ref for every coordinate
// of one swizzle chunk, so that non-zero column groups are exercised too.
// NOTE(review): this relies on tl::RowMajor<kRows, kCols> mapping (i, j) to
// i * kCols + j, which is what the original spot checks already assumed.
TEST(TESTSwizzle, test_naive_swizzle_layout) {
    const int kB = 3;
    const int kM = 3;
    const int kS = 3;

    const int kRows = 1 << kB;
    const int kCols = 1 << (kM + kS);

    using NaiveRowMajorLayout = tl::RowMajor<kRows, kCols>;
    using NaiveSwizzledRowMajorLayout =
        SwizzledLayout<NaiveRowMajorLayout, kB, kM, kS>;

    NaiveSwizzledRowMajorLayout naive_row_major_swizzled_layout;

    for (int x = 0; x < kRows; ++x) {
        for (int y = 0; y < kCols; ++y) {
            EXPECT_EQ((naive_row_major_swizzled_layout(x, y)),
                      (swizzle_ref<kB, kM, kS>(x, y)))
                << "x = " << x << ", y = " << y;
        }
    }
}
|
||
// Checks that swizzling a nested base-tile shared layout is the same as
// evaluating the plain layout at the swizzled (row, column) coordinate.
TEST(TESTSwizzle, test_nested_basetile_swizzle_layout) {
    const int kB = 3;
    const int kM = 3;
    const int kS = 3;

    const int kRows = 1 << kB;
    const int kCols = 1 << (kM + kS);

    using NestedBaseTileLayout =
        tl::detail::SharedLayout<kRows, kCols, kCols * 16, 16,
                                 tl::Layout::kRowMajor>;
    using NestedBaseTileSwizzledLayout =
        SwizzledLayout<NestedBaseTileLayout, kB, kM, kS>;

    NestedBaseTileLayout nested_base_tile_layout;
    NestedBaseTileSwizzledLayout nested_base_tile_swizzled_layout;

    int idx_0_0 = nested_base_tile_layout(0, 0);
    int idx_1_0 = nested_base_tile_layout(1, 0);
    int idx_1_4 = nested_base_tile_layout(1, 4);
    int idx_2_0 = nested_base_tile_layout(2, 0);
    int idx_2_4 = nested_base_tile_layout(2, 4);

    int swizzled_idx_0_0 = nested_base_tile_swizzled_layout(0, 0);
    int swizzled_idx_1_0 = nested_base_tile_swizzled_layout(1, 0);
    int swizzled_idx_1_4 = nested_base_tile_swizzled_layout(1, 4);
    int swizzled_idx_2_0 = nested_base_tile_swizzled_layout(2, 0);
    int swizzled_idx_2_4 = nested_base_tile_swizzled_layout(2, 4);

#ifdef DEBUG
    printf("idx_0_0: %d, swizzled_idx_0_0: %d\n", idx_0_0, swizzled_idx_0_0);
    printf("idx_1_0: %d, swizzled_idx_1_0: %d\n", idx_1_0, swizzled_idx_1_0);
    printf("idx_1_4: %d, swizzled_idx_1_4: %d\n", idx_1_4, swizzled_idx_1_4);
    printf("idx_2_0: %d, swizzled_idx_2_0: %d\n", idx_2_0, swizzled_idx_2_0);
    printf("idx_2_4: %d, swizzled_idx_2_4: %d\n", idx_2_4, swizzled_idx_2_4);
#endif

    // Expected value: split the reference swizzled index back into a
    // (row, column) pair and evaluate the un-swizzled layout there.
    auto expected = [&](int x, int y) {
        const int ref = swizzle_ref<kB, kM, kS>(x, y);
        const int ref_x = ref >> (kM + kS);
        const int ref_y = ref & ((1 << (kM + kS)) - 1);
        return static_cast<int>(nested_base_tile_layout(ref_x, ref_y));
    };

    // Row 0 XORs every column group with zero, so (0, 0) is left in place.
    EXPECT_EQ(swizzled_idx_0_0, idx_0_0);

    EXPECT_EQ(swizzled_idx_0_0, expected(0, 0));
    EXPECT_EQ(swizzled_idx_1_0, expected(1, 0));
    EXPECT_EQ(swizzled_idx_1_4, expected(1, 4));
    EXPECT_EQ(swizzled_idx_2_0, expected(2, 0));
    EXPECT_EQ(swizzled_idx_2_4, expected(2, 4));
}
|
||
} // namespace tilefusion::testing |