diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp index 095cde5b10..9b9819ce2e 100644 --- a/include/cute/arch/copy_sm90_desc.hpp +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -239,7 +239,7 @@ to_CUtensorMapDataType() { inline CUtensorMapSwizzle to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) { switch (t) { - default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); + default: throw std::runtime_error("Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); case SmemSwizzleBits::DISABLE: assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 0B swizzle bits."); return CU_TENSOR_MAP_SWIZZLE_NONE; @@ -251,7 +251,7 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) { return CU_TENSOR_MAP_SWIZZLE_64B; case SmemSwizzleBits::B128: switch (b) { - default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); + default: throw std::runtime_error("Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); case SmemSwizzleBase::SWIZZLE_BASE_16B: return CU_TENSOR_MAP_SWIZZLE_128B; #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ > 6))) @@ -265,7 +265,7 @@ to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) { inline CUtensorMapFloatOOBfill to_CUtensorMapFloatOOBfill(OOBFill const& t) { switch(t) { - default: assert(false && "Unknown OOBFill!"); + default: throw std::runtime_error("Unknown OOBFill!"); case OOBFill::ZERO: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; case OOBFill::CONSTANT: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA; } @@ -274,7 +274,7 @@ to_CUtensorMapFloatOOBfill(OOBFill const& t) { inline CUtensorMapL2promotion to_CUtensorMapL2promotion(L2Promotion const& t) { switch(t) { - default: assert(false && "Unknown L2Promotion!"); + default: throw std::runtime_error("Unknown L2Promotion!"); case L2Promotion::DISABLE: return CU_TENSOR_MAP_L2_PROMOTION_NONE; case L2Promotion::B64: return CU_TENSOR_MAP_L2_PROMOTION_L2_64B; case L2Promotion::B128: return CU_TENSOR_MAP_L2_PROMOTION_L2_128B; diff --git a/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp b/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp index 9a447896cc..90d616df2e 100644 --- a/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp +++ b/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp @@ -48,13 +48,11 @@ TMA::SmemSwizzleBits get_tma_swizzle_bits(Swizzle) { if constexpr (M == 4) { - switch (B) { - default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); - case 3: return TMA::SmemSwizzleBits::B128; - case 2: return TMA::SmemSwizzleBits::B64; - case 1: return TMA::SmemSwizzleBits::B32; - case 0: return TMA::SmemSwizzleBits::DISABLE; - } + static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); + if constexpr (B == 3) { return TMA::SmemSwizzleBits::B128; } + if constexpr (B == 2) { return TMA::SmemSwizzleBits::B64; } + if constexpr (B == 1) { return TMA::SmemSwizzleBits::B32; } + if constexpr (B == 0) { return TMA::SmemSwizzleBits::DISABLE; } } else if constexpr (M == 5 || M == 6) {