From f5e084c8ae62227b9f54876ff525a107ff2f64de Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 2 Jan 2025 10:32:45 -0500 Subject: [PATCH] Also split by K I think this covers the motivation for #3616 --- csrc/scheduler/hopper_multi_matmul.cpp | 62 ++++++++++++++++++-------- tests/cpp/test_matmul.cpp | 6 +-- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index aeb0fbc0325..022148d4a84 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -34,29 +34,53 @@ void HopperMultipleMatmulScheduler::transformLikeMmaOutput( bool is_mma_result) { // TODO Add constraints - auto apply_k_dim_offset = [is_mma_result](int64_t idx) constexpr { - return (is_mma_result) ? idx - 1 : idx; - }; - // The input is originally block tiled so that the inner dims are the CTA tile // size // Original: [..., M, N(, K)] // We split this into warp tiles then instruction tiles - tv->split(apply_k_dim_offset(-2), params_->tile_sizes.warp_tile.m); - tv->split(apply_k_dim_offset(-2), getM(params_->mma_macro)); - tv->split(apply_k_dim_offset(-1), params_->tile_sizes.warp_tile.n); - tv->split(apply_k_dim_offset(-1), getN(params_->mma_macro)); - // After Split: [..., Mo, Mw, Mi, No, Nw, Nwi] - tv->reorder({ - {apply_k_dim_offset(-3), apply_k_dim_offset(-5)}, - {apply_k_dim_offset(-2), apply_k_dim_offset(-3)}, - }); - // After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni] - - tv->merge(apply_k_dim_offset(-6)); - // After Merge: [..., Mo * No, Mio, Nio, Mii, Nii] - tv->axis(apply_k_dim_offset(-5))->parallelize(ParallelType::TIDy); - // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni] + if (is_mma_result) { + // Original: [..., M, N, K] + tv->split(-3, params_->tile_sizes.warp_tile.m); + tv->split(-3, getM(params_->mma_macro)); + tv->split(-2, params_->tile_sizes.warp_tile.n); + tv->split(-2, getN(params_->mma_macro)); + // K dimension is present for mma_result 
+ tv->split(-1, params_->tile_sizes.warp_tile.k); + tv->split(-1, getK(params_->mma_macro)); + // After Split: [..., Mo, Mw, Mi, No, Nw, Ni, Ko, Kw, Ki] + tv->reorder({ + {-9, -9}, // Mo + {-8, -6}, // Mw + {-7, -3}, // Mi + {-6, -8}, // No + {-5, -5}, // Nw + {-4, -2}, // Ni + {-3, -7}, // Ko + {-2, -4}, // Kw + {-1, -1}, // Ki + }); + // After Reorder: [..., Mo, No, Ko, Mw, Nw, Kw, Mi, Ni, Ki] + tv->merge(-9); + // After Merge: [..., Mo * No, Ko, Mw, Nw, Kw, Mi, Ni, Ki] + tv->axis(-8)->parallelize(ParallelType::TIDy); + // After Parallelize: [..., Mo * No (TIDy), Ko, Mw, Nw, Kw, Mi, Ni, Ki] + } else { + // Original: [..., M, N] + tv->split(-2, params_->tile_sizes.warp_tile.m); + tv->split(-2, getM(params_->mma_macro)); + tv->split(-1, params_->tile_sizes.warp_tile.n); + tv->split(-1, getN(params_->mma_macro)); + // After Split: [..., Mo, Mw, Mi, No, Nw, Ni] + tv->reorder({ + {-3, -5}, + {-2, -3}, + }); + // After Reorder: [..., Mo, No, Mw, Nw, Mi, Ni] + tv->merge(-6); + // After Merge: [..., Mo * No, Mw, Nw, Mi, Ni] + tv->axis(-5)->parallelize(ParallelType::TIDy); + // After Parallelize: [..., Mo * No (TIDy), Mw, Nw, Mi, Ni] + } } MatmulDimRole HopperMultipleMatmulScheduler::findMatmulDimRole(IterDomain* id) { diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 521913772f1..2378eed5599 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -4275,8 +4275,8 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) { MatMulTileOptions gemm_tile; // Regardless of the instruction, this should result in 2 warp groups i.e. 
256 // threads - gemm_tile.cta_tile = GemmTile(256, 256, 16); - gemm_tile.warp_tile = GemmTile(128, 128, 16); + gemm_tile.cta_tile = GemmTile(256, 256, 32); + gemm_tile.warp_tile = GemmTile(128, 128, 32); MatmulParams mparams; mparams.supported_vec_size = {8, 8, 8}; @@ -4286,7 +4286,7 @@ TEST_F(HopperMatmulTest, HSH_NT_UseScheduler_MultipleInstructionsPerWarpTile) { mparams.async_gmem_load_operands = true; mparams.circular_buffer_options.circular_buffer_smem_write = true; mparams.circular_buffer_options.circular_buffer_smem_read = false; - mparams.circular_buffer_options.smem_circular_buffer_stage = 2; + mparams.circular_buffer_options.smem_circular_buffer_stage = 4; mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; mparams.splitk_factor = 1; // NOTE: disabling smem use for this test since we currrently hit a bank