From c4e701bbc27f4272ed897c156abd282b11f9d9d9 Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Mon, 1 Jun 2026 14:11:11 -0700 Subject: [PATCH 1/2] add tileN = 8,16 --- .../collective/builders/sm120_builder.inl | 4 +- .../sm120_blockscaled_mma_builder.inl | 11 ++-- .../gemm/collective/builders/sm120_common.inl | 55 ++++++++++++++----- 3 files changed, 51 insertions(+), 19 deletions(-) diff --git a/include/cutlass/epilogue/collective/builders/sm120_builder.inl b/include/cutlass/epilogue/collective/builders/sm120_builder.inl index 200d79f673..1a0f9f621c 100644 --- a/include/cutlass/epilogue/collective/builders/sm120_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm120_builder.inl @@ -191,7 +191,9 @@ sm120_compute_tile_shape_or_override() { } } else { - return Shape<_64, _32>{}; + constexpr int EpiN = CTA_N < 32 ? CTA_N : 32; + constexpr int EpiM = CTA_N < 16 ? 128 : 64; + return Shape, Int>{}; } } } // EpilogueTileAuto diff --git a/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl b/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl index 8c32c32204..0d227cd054 100755 --- a/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl @@ -98,7 +98,7 @@ struct CollectiveBuilder< static_assert(cute::is_static_v, "Cluster has to be static"); static_assert(detail::blockscaled::check_input_datatypes(), "Incorrect input types"); static_assert(cute::size(ClusterShape_MNK{}) == Int<1>{}, "no programmatic multicast on this arch"); - static_assert(size<1>(TileShape_MNK{}) >= 32, "Invalid tile shape N."); + static_assert(size<1>(TileShape_MNK{}) >= 8, "Invalid tile shape N."); static constexpr auto Instr = detail::blockscaled::select_instr(); static constexpr bool UseMxf8f6f4 = Instr == detail::blockscaled::BlockScaledInstr::MXF4F6F8; using PermTileM = decltype(cute::min(size<0>(TileShape_MNK{}), _128{})); - using PermTileN = decltype(detail::sm120_tile_n_permute_selector()); + using PermTileN = decltype(detail::sm120_tile_n_permute_selector(TileShape_MNK{})>()); using PermTileK = cute::conditional_t<(UseMxf8f6f4 ), _32, _64>; @@ -125,7 +125,9 @@ struct CollectiveBuilder< "TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" ); using AtomLayoutMNK = cute::conditional_t>, Layout>>; + cute::conditional_t<(size<1>(TileShape_MNK{}) >= 16), Layout>, + Layout>>, + Layout>>; using TiledMma = decltype(cute::make_tiled_mma( cute::rr_blockscaled_op_selector_sm120()), SmemAllocTypeA>; using SmemCopyAtomB = Copy_Atom(TileShape_MNK{}) >()), SmemAllocTypeB>; using SmemCopyAtomSF = Copy_Atom, SmemAllocTypeSF>; // auto-vectorized LDS diff --git a/include/cutlass/gemm/collective/builders/sm120_common.inl b/include/cutlass/gemm/collective/builders/sm120_common.inl index a1ccdd1e58..7a8172cadd 100644 --- a/include/cutlass/gemm/collective/builders/sm120_common.inl +++ b/include/cutlass/gemm/collective/builders/sm120_common.inl @@ -73,7 +73,8 @@ sm120_rr_smem_copy_selector_A() { template < class ElementA, class ElementB, - bool UseF8f6f4 + bool UseF8f6f4, + int TileShapeN = 32 > CUTLASS_HOST_DEVICE constexpr auto @@ -83,14 +84,38 @@ sm120_rr_smem_copy_selector_B() { return SM100_SU6_DU8x16_x4_LDSM_N{}; } else if constexpr (sizeof_bits_v == 4) { - return SM100_SU4_DU8x16_x4_LDSM_N{}; + if constexpr (TileShapeN < 16) { + return SM100_SU4_DU8x16_x1_LDSM_N{}; + } + else if constexpr (TileShapeN < 32) { + return SM100_SU4_DU8x16_x2_LDSM_N{}; + } + else { + return SM100_SU4_DU8x16_x4_LDSM_N{}; + } } else { - return SM75_U32x4_LDSM_N{}; + if constexpr (TileShapeN < 16) { + return SM75_U32x1_LDSM_N{}; + } + else if constexpr (TileShapeN < 32) { + return SM75_U32x2_LDSM_N{}; + } + else { + return SM75_U32x4_LDSM_N{}; + } } - } + } else { - return SM75_U32x4_LDSM_N{}; + if constexpr (TileShapeN < 16) { + return SM75_U32x1_LDSM_N{}; + } + else if constexpr (TileShapeN < 32) { + return SM75_U32x2_LDSM_N{}; + } + else { + return SM75_U32x4_LDSM_N{}; + } } } @@ -140,22 +165,24 @@ sm120_rr_smem_selector_sparse() { } } -template +template CUTLASS_HOST_DEVICE constexpr auto sm120_tile_n_permute_selector() { - // VS = 16 - if constexpr (SFVectorSize == 16) { - // Permute in the N mode to allow a warp to own all the elements needed for SF reduction + static_assert(SFVectorSize == 16 || SFVectorSize == 32, + "Unsupported SFVectorSize for SM120 collective builder."); + if constexpr (TileShapeN >= 32) { return cute::Layout, cute::Stride<_1, _16,_8>>{}; } - // VS = 32 - else if constexpr (SFVectorSize == 32) { - return cute::Layout, cute::Stride<_1, _16,_8>>{}; + else if constexpr (TileShapeN == 16) { + return cute::Layout, cute::Stride<_1,_8>>{}; + } + else if constexpr (TileShapeN == 8) { + return cute::Layout, cute::Stride<_1>>{}; } else { - static_assert(cutlass::detail::dependent_false>, - "Unsupported SFVectorSize for SM120 collective builder."); + static_assert(cutlass::detail::dependent_false>, + "TileShape_N must be at least 8 for SM120 blockscaled."); } } From e0ee3e40734c46123d28ece3536fdff714d50cb1 Mon Sep 17 00:00:00 2001 From: Brayden Zhong Date: Thu, 11 Jun 2026 20:05:30 -0700 Subject: [PATCH 2/2] add a guard --- .../collective/builders/sm120_blockscaled_mma_builder.inl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl b/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl index 0d227cd054..2835751f41 100755 --- a/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm120_blockscaled_mma_builder.inl @@ -73,6 +73,10 @@ struct CollectiveBuilder< cute::is_base_of_v || cute::is_same_v) && + !((cute::is_base_of_v || + cute::is_base_of_v) && + cute::tuple_element_t<1, TileShape_MNK>::value < 16) + && // Alignment check detail::sm1xx_blockscaled_gemm_is_aligned::data_type, AlignmentA,