Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,9 @@ sm120_compute_tile_shape_or_override() {
}
}
else {
return Shape<_64, _32>{};
constexpr int EpiN = CTA_N < 32 ? CTA_N : 32;
constexpr int EpiM = CTA_N < 16 ? 128 : 64;
return Shape<Int<EpiM>, Int<EpiN>>{};
}
}
} // EpilogueTileAuto
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ struct CollectiveBuilder<
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedCooperative, BuilderScheduleTag> ||
cute::is_same_v<KernelScheduleAuto, BuilderScheduleTag>)
&&
!((cute::is_base_of_v<KernelTmaWarpSpecializedPingpong, BuilderScheduleTag> ||
cute::is_base_of_v<KernelPtrArrayTmaWarpSpecializedPingpong, BuilderScheduleTag>) &&
cute::tuple_element_t<1, TileShape_MNK>::value < 16)
&&
// Alignment check
detail::sm1xx_blockscaled_gemm_is_aligned<typename detail::blockscaled::blockscaled_type<BuilderScheduleTag, ElementPairA>::data_type,
AlignmentA,
Expand All @@ -98,7 +102,7 @@ struct CollectiveBuilder<
static_assert(cute::is_static_v<ClusterShape_MNK>, "Cluster has to be static");
static_assert(detail::blockscaled::check_input_datatypes<BuilderScheduleTag, ElementPairA, ElementPairB, UmmaMajorA, UmmaMajorB>(), "Incorrect input types");
static_assert(cute::size(ClusterShape_MNK{}) == Int<1>{}, "no programmatic multicast on this arch");
static_assert(size<1>(TileShape_MNK{}) >= 32, "Invalid tile shape N.");
static_assert(size<1>(TileShape_MNK{}) >= 8, "Invalid tile shape N.");

static constexpr auto Instr = detail::blockscaled::select_instr<ElementPairA,
ElementPairB,
Expand All @@ -108,7 +112,7 @@ struct CollectiveBuilder<
BuilderScheduleTag>();
static constexpr bool UseMxf8f6f4 = Instr == detail::blockscaled::BlockScaledInstr::MXF4F6F8;
using PermTileM = decltype(cute::min(size<0>(TileShape_MNK{}), _128{}));
using PermTileN = decltype(detail::sm120_tile_n_permute_selector<SFVectorSize>());
using PermTileN = decltype(detail::sm120_tile_n_permute_selector<SFVectorSize, size<1>(TileShape_MNK{})>());
using PermTileK = cute::conditional_t<(UseMxf8f6f4
), _32, _64>;

Expand All @@ -125,7 +129,9 @@ struct CollectiveBuilder<
"TileSize and MNK Major does not met with MMA Mix 8-bit TMA load requirement" );

using AtomLayoutMNK = cute::conditional_t<IsCooperative,
Layout<Shape<_4,_2,_1>>, Layout<Shape<_2,_2,_1>>>;
cute::conditional_t<(size<1>(TileShape_MNK{}) >= 16), Layout<Shape<_4,_2,_1>>,
Layout<Shape<_8,_1,_1>>>,
Layout<Shape<_2,_2,_1>>>;

using TiledMma = decltype(cute::make_tiled_mma(
cute::rr_blockscaled_op_selector_sm120<ElementA,
Expand Down Expand Up @@ -166,7 +172,8 @@ struct CollectiveBuilder<
>()), SmemAllocTypeA>;
using SmemCopyAtomB = Copy_Atom<decltype(detail::sm120_rr_smem_copy_selector_B<ElementA,
ElementB,
UseMxf8f6f4
UseMxf8f6f4,
size<1>(TileShape_MNK{})
>()), SmemAllocTypeB>;

using SmemCopyAtomSF = Copy_Atom<UniversalCopy<SmemAllocTypeSF>, SmemAllocTypeSF>; // auto-vectorized LDS
Expand Down
55 changes: 41 additions & 14 deletions include/cutlass/gemm/collective/builders/sm120_common.inl
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ sm120_rr_smem_copy_selector_A() {
template <
class ElementA,
class ElementB,
bool UseF8f6f4
bool UseF8f6f4,
int TileShapeN = 32
>
CUTLASS_HOST_DEVICE constexpr
auto
Expand All @@ -83,14 +84,38 @@ sm120_rr_smem_copy_selector_B() {
return SM100_SU6_DU8x16_x4_LDSM_N{};
}
else if constexpr (sizeof_bits_v<ElementB> == 4) {
return SM100_SU4_DU8x16_x4_LDSM_N{};
if constexpr (TileShapeN < 16) {
return SM100_SU4_DU8x16_x1_LDSM_N{};
}
else if constexpr (TileShapeN < 32) {
return SM100_SU4_DU8x16_x2_LDSM_N{};
}
else {
return SM100_SU4_DU8x16_x4_LDSM_N{};
}
}
else {
return SM75_U32x4_LDSM_N{};
if constexpr (TileShapeN < 16) {
return SM75_U32x1_LDSM_N{};
}
else if constexpr (TileShapeN < 32) {
return SM75_U32x2_LDSM_N{};
}
else {
return SM75_U32x4_LDSM_N{};
}
}
}
}
else {
return SM75_U32x4_LDSM_N{};
if constexpr (TileShapeN < 16) {
return SM75_U32x1_LDSM_N{};
}
else if constexpr (TileShapeN < 32) {
return SM75_U32x2_LDSM_N{};
}
else {
return SM75_U32x4_LDSM_N{};
}
}
}

Expand Down Expand Up @@ -140,22 +165,24 @@ sm120_rr_smem_selector_sparse() {
}
}

template <int SFVectorSize>
template <int SFVectorSize, int TileShapeN = 32>
CUTLASS_HOST_DEVICE constexpr
auto
sm120_tile_n_permute_selector() {
// VS = 16
if constexpr (SFVectorSize == 16) {
// Permute in the N mode to allow a warp to own all the elements needed for SF reduction
static_assert(SFVectorSize == 16 || SFVectorSize == 32,
"Unsupported SFVectorSize for SM120 collective builder.");
if constexpr (TileShapeN >= 32) {
return cute::Layout<cute::Shape<_8,_2,_2>, cute::Stride<_1, _16,_8>>{};
}
// VS = 32
else if constexpr (SFVectorSize == 32) {
return cute::Layout<cute::Shape<_8,_2,_2>, cute::Stride<_1, _16,_8>>{};
else if constexpr (TileShapeN == 16) {
return cute::Layout<cute::Shape<_8,_2>, cute::Stride<_1,_8>>{};
}
else if constexpr (TileShapeN == 8) {
return cute::Layout<cute::Shape<_8>, cute::Stride<_1>>{};
}
else {
static_assert(cutlass::detail::dependent_false<cute::C<SFVectorSize>>,
"Unsupported SFVectorSize for SM120 collective builder.");
static_assert(cutlass::detail::dependent_false<cute::C<TileShapeN>>,
"TileShape_N must be at least 8 for SM120 blockscaled.");
}
}

Expand Down