Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/cute/arch/copy_sm90_desc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ to_CUtensorMapDataType() {
inline CUtensorMapSwizzle
to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) {
switch (t) {
default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!");
default: throw std::runtime_error("Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!");
case SmemSwizzleBits::DISABLE:
assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 0B swizzle bits.");
return CU_TENSOR_MAP_SWIZZLE_NONE;
Expand All @@ -251,7 +251,7 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) {
return CU_TENSOR_MAP_SWIZZLE_64B;
case SmemSwizzleBits::B128:
switch (b) {
default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!");
default: throw std::runtime_error("Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!");
case SmemSwizzleBase::SWIZZLE_BASE_16B: return CU_TENSOR_MAP_SWIZZLE_128B;

#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ > 6)))
Expand All @@ -265,7 +265,7 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) {
inline CUtensorMapFloatOOBfill
to_CUtensorMapFloatOOBfill(OOBFill const& t) {
switch(t) {
default: assert(false && "Unknown OOBFill!");
default: throw std::runtime_error("Unknown OOBFill!");
case OOBFill::ZERO: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
case OOBFill::CONSTANT: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA;
}
Expand All @@ -274,7 +274,7 @@ to_CUtensorMapFloatOOBfill(OOBFill const& t) {
inline CUtensorMapL2promotion
to_CUtensorMapL2promotion(L2Promotion const& t) {
switch(t) {
default: assert(false && "Unknown L2Promotion!");
default: throw std::runtime_error("Unknown L2Promotion!");
case L2Promotion::DISABLE: return CU_TENSOR_MAP_L2_PROMOTION_NONE;
case L2Promotion::B64: return CU_TENSOR_MAP_L2_PROMOTION_L2_64B;
case L2Promotion::B128: return CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
Expand Down
12 changes: 5 additions & 7 deletions include/cute/atom/copy_traits_sm90_tma_swizzle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,11 @@ TMA::SmemSwizzleBits
get_tma_swizzle_bits(Swizzle<B,M,S>)
{
if constexpr (M == 4) {
switch (B) {
default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle.");
case 3: return TMA::SmemSwizzleBits::B128;
case 2: return TMA::SmemSwizzleBits::B64;
case 1: return TMA::SmemSwizzleBits::B32;
case 0: return TMA::SmemSwizzleBits::DISABLE;
}
static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle.");
if constexpr (B == 3) { return TMA::SmemSwizzleBits::B128; }
if constexpr (B == 2) { return TMA::SmemSwizzleBits::B64; }
if constexpr (B == 1) { return TMA::SmemSwizzleBits::B32; }
if constexpr (B == 0) { return TMA::SmemSwizzleBits::DISABLE; }
} else

if constexpr (M == 5 || M == 6) {
Expand Down