Skip to content

Commit

Permalink
Also split by K
Browse files Browse the repository at this point in the history
I think this covers the motivation for #3616
  • Loading branch information
jacobhinkle committed Jan 2, 2025
1 parent dce16ad commit f5e084c
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 22 deletions.
62 changes: 43 additions & 19 deletions csrc/scheduler/hopper_multi_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,53 @@ void HopperMultipleMatmulScheduler::transformLikeMmaOutput(
bool is_mma_result) {
// TODO Add constraints

auto apply_k_dim_offset = [is_mma_result](int64_t idx) constexpr {
return (is_mma_result) ? idx - 1 : idx;
};

// The input is originally block tiled so that the inner dims are the CTA tile
// size
// Original: [..., M, N(, K)]
// We split this into warp tiles then instruction tiles
tv->split(apply_k_dim_offset(-2), params_->tile_sizes.warp_tile.m);
tv->split(apply_k_dim_offset(-2), getM(params_->mma_macro));
tv->split(apply_k_dim_offset(-1), params_->tile_sizes.warp_tile.n);
tv->split(apply_k_dim_offset(-1), getN(params_->mma_macro));
// After Split: [..., Mo, Mw, Mi, No, Nw, Nwi]
tv->reorder({
{apply_k_dim_offset(-3), apply_k_dim_offset(-5)},
{apply_k_dim_offset(-2), apply_k_dim_offset(-3)},
});
// After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni]

tv->merge(apply_k_dim_offset(-6));
// After Merge: [..., Mo * No, Mio, Nio, Mii, Nii]
tv->axis(apply_k_dim_offset(-5))->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni]
if (is_mma_result) {
// Original: [..., M, N, K]
tv->split(-3, params_->tile_sizes.warp_tile.m);
tv->split(-3, getM(params_->mma_macro));
tv->split(-2, params_->tile_sizes.warp_tile.n);
tv->split(-2, getN(params_->mma_macro));
// K dimension is present for mma_result
tv->split(-1, params_->tile_sizes.warp_tile.k);
tv->split(-1, getK(params_->mma_macro));
// After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Ko, Kw, Ki]
tv->reorder({
{-9, -9}, // Mo
{-8, -6}, // Mw
{-7, -3}, // Mi
{-6, -8}, // No
{-5, -5}, // Nw
{-4, -2}, // Ni
{-3, -7}, // Ko
{-2, -4}, // Kw
{-1, -1}, // Ki
});
// After Reorder: [..., Mo, No, Ko, Mw, Nw, Kw, Mi, Ni, Ki]
tv->merge(-9);
// After Merge: [..., Mo * No, Ko, Mw, Nw, Kw, Mi, Ni, Ki]
tv->axis(-8)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo * No (TIDy), Ko, Mw, Nw, Kw, Mi, Ni, Ki]
} else {
// Original: [..., M, N]
tv->split(-2, params_->tile_sizes.warp_tile.m);
tv->split(-2, getM(params_->mma_macro));
tv->split(-1, params_->tile_sizes.warp_tile.n);
tv->split(-1, getN(params_->mma_macro));
// After Split: [..., Mo, Mw, Mi, No, Nw, Ni]
tv->reorder({
{-3, -5},
{-2, -3},
});
// After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni]
tv->merge(-6);
// After Merge: [..., Mo * No, Mw, Nw, Mi, Ni]
tv->axis(-5)->parallelize(ParallelType::TIDy);
// After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni]
}
}

MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) {
Expand Down
6 changes: 3 additions & 3 deletions tests/cpp/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4275,8 +4275,8 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
MatMulTileOptions gemm_tile;
// Regardless of the instruction, this should result in 2 warp groups i.e. 256
// threads
gemm_tile.cta_tile = GemmTile(256, 256, 16);
gemm_tile.warp_tile = GemmTile(128, 128, 16);
gemm_tile.cta_tile = GemmTile(256, 256, 32);
gemm_tile.warp_tile = GemmTile(128, 128, 32);

MatmulParams mparams;
mparams.supported_vec_size = {8, 8, 8};
Expand All @@ -4286,7 +4286,7 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) {
mparams.async_gmem_load_operands = true;
mparams.circular_buffer_options.circular_buffer_smem_write = true;
mparams.circular_buffer_options.circular_buffer_smem_read = false;
mparams.circular_buffer_options.smem_circular_buffer_stage = 2;
mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
mparams.splitk_factor = 1;
// NOTE: disabling smem use for this test since we currently hit a bank
Expand Down

0 comments on commit f5e084c

Please sign in to comment.