From d53be4504756e935a4ea7781ec5d441bf523cb83 Mon Sep 17 00:00:00 2001 From: Protonu Date: Sat, 14 Dec 2024 09:46:03 -0500 Subject: [PATCH] Schedule epilogue (for Hopper Matmul) by propagation backward from output - smem epilogue not supported. (#3580) This adds support for scheduling epilogue for the hopper matmul scheduler. We don't support smem epilogue as yet. We also don't honor the vectorization_factor as yet for the store to output. That'll be covered in a separate PR. --- csrc/scheduler/hopper_multi_matmul.cpp | 171 ++++++------------------- csrc/scheduler/hopper_multi_matmul.h | 7 - tests/cpp/test_matmul_scheduler.cpp | 76 ++++++++++- 3 files changed, 111 insertions(+), 143 deletions(-) diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index e23587abf10..91048d3374c 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -434,107 +434,52 @@ void HopperMultipleMatmulScheduler::scheduleMmaResults() { } } -void HopperMultipleMatmulScheduler::scheduleOutputTensor(TensorView* c) { - const MatMulTileOptions& gemm_tile = params_->tile_sizes; - const int64_t vectorization_factor = params_->supported_vec_size.epilogue; - // input tensor is in the form of [Mo,No,cta_tile_m,cta_tile_n] - mma_utils::checkConcreteStaticDim(c->axis(-2)); - mma_utils::checkConcreteStaticDim(c->axis(-1)); - const int64_t tile_size_m = c->axis(-2)->extent()->evaluate().as(); - const int64_t tile_size_n = c->axis(-1)->extent()->evaluate().as(); - NVF_ERROR( - tile_size_m == gemm_tile.cta_tile.m, - "Actual tile size at axis(-2) in output tensor is different from CTA tile size! Expected: ", - gemm_tile.cta_tile.m, - ", actual: ", - tile_size_m); - NVF_ERROR( - tile_size_n == gemm_tile.cta_tile.n, - "Actual tile size at axis(-1) in output tensor is different from CTA tile size! 
Expected: ", - gemm_tile.cta_tile.n, - ", actual: ", - tile_size_n); - const int64_t tot_elements = tile_size_m * tile_size_n; - constexpr int64_t warp_size = 32l; - const int64_t tidx = warp_size; - const int64_t tidy = gemm_tile.cta_tile.n / gemm_tile.warp_tile.n; - const int64_t tidz = gemm_tile.cta_tile.m / gemm_tile.warp_tile.m; - // step-1, merge last 2 dims - c->merge(-2); - // [Mo, No, m*n] - - // step-2, set vectorization to maximum - // We have fixed tidx, tidy, and tidz, so we need to make sure that the - // output tensor is divisible by tidx * tidy * tidz * vectorization_factor - NVF_ERROR( - tot_elements % (tidx * tidy * tidz * vectorization_factor) == 0, - "Output tensor cannot be fully vectorized! tot_elements:", - tot_elements, - ", tidx: ", - tidx, - ", tidy: ", - tidy, - ", tidz: ", - tidz, - ", vectorization_factor: ", - vectorization_factor); - c->split(-1, vectorization_factor); - c->axis(-1)->parallelize(ParallelType::Vectorize); - // [Mo, No, m*n/vect, vect] - - // step-3, Split out a warp for TIDx - c->split(-2, tidx); - c->axis(-2)->parallelize(ParallelType::TIDx); - // [Mo, No, m*n/vect/TIDx, TIDx, vect] - - // step-4, Split out for TIDy and TIDz - // TIDy = cta_tile_n/warp_tile_n - // TIDz = cta_tile_m/warp_tile_m - c->split(-3, tidy); - c->axis(-3)->parallelize(ParallelType::TIDy); - - c->split(-4, tidz); - c->axis(-4)->parallelize(ParallelType::TIDz); - // [Mo, No, m*n/vect/TIDx/TIDy/TIDz, TIDz, TIDy, TIDx, vect] +void HopperMultipleMatmulScheduler::scheduleEpilogue() { + std::vector cached_tvs; - for (TensorView* mma_result : mma_results_) { - // step-5, Parallel first 2 dims same as mma_result - scheduler_utils::parallelizeAllLike( - mma_result, - 2, - {c}, - {ParallelType::BIDx, ParallelType::BIDy, ParallelType::BIDz}); + // Propagate to (not including) the splitk output if there is a splitk + // else this is just mma_results_ + std::vector propagate_to = + splitk_sums_.empty() ? 
mma_results_ : splitk_sums_; + if (tensor_roles_.count(MatmulTensorRole::EPILOGUE_INPUT)) { + auto& c_tvs = tensor_roles_.at(MatmulTensorRole::EPILOGUE_INPUT); + // Load/cache the epilogue inputs if there are any. + for (auto* c : c_tvs) { + cached_tvs.push_back(c->cacheAfter()); + } + propagate_to.insert(propagate_to.end(), c_tvs.begin(), c_tvs.end()); } -} -void HopperMultipleMatmulScheduler::scheduleEpilogue() { - // TODO: schedule epilogue by propagation backward from dc if (!params_->use_smem_epilogue) { for (Val* dv : fusion_->outputs()) { auto* d = dv->as(); NVF_ERROR(d->definition() && d->definition()->isA()); - auto* dc = d->definition()->input(0)->as(); - std::vector tvs_to_schedule{d}; - if (std::find(mma_results_.begin(), mma_results_.end(), dc) == - mma_results_.end()) { - // Skip scheduling dc if it is an mma_result. This can happen if we are - // not casting back to half-precision in the output - tvs_to_schedule.push_back(dc); - } - - // Block Schedule and Parallelize - blockTileTensors(tvs_to_schedule); - parallelizeBlocks(tvs_to_schedule); - - // Apply mma common transformation - for (auto tv : tvs_to_schedule) { - transformLikeMmaOutput(tv, /*is_mma_result=*/false); - auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( - tv->getLoopDomain()); - tv->setLoopDomain(s.as()); - } + // Schedule the output TV and propagate it back to the outputs of the Mma + // op. + blockTileTensors({d}); + parallelizeBlocks({d}); + transformLikeMmaOutput(d, /*is_mma_result=*/false); + + auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( + d->getLoopDomain()); + d->setLoopDomain(s.as()); + + // TODO: We need to check bank conflicts in this path. + scheduler_utils::BoundedDirectionalTransformPropagator::backward( + d, + -1, + propagate_to, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType()); + + // We don't respect vectorization_factor as yet. We vectorize the + // inner-dim with extent 2. 
+ // TODO: support vectorization_factor. d->axis(-1)->parallelize(ParallelType::Vectorize); + if (!cached_tvs.empty()) { + scheduler_utils::parallelizeAllLike(d, -1, cached_tvs); + } } } else { constexpr int64_t stmatrix_tile_m = 16; @@ -609,48 +554,6 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { } } -//! Propagates transformations from fusion output to fusion tv inputs that are -//! producers in the epilogue. Transformations' propagation aims at input tvs -//! which are not assigned to core roles, that is, are not MMA inputs. -void HopperMultipleMatmulScheduler::scheduleFusionInputsForEpilogue() { - std::vector cached_tvs; - - // Handling transformations in fusion input tvs with assigned EPILOGUE_INPUT - // role by propagating fusion output transformations through cached views - // of EPILOGUE_INPUT fusion input tvs and by setting vectorization of the - // inner most iterdomain of these cached views - if (tensor_roles_.count(MatmulTensorRole::EPILOGUE_INPUT)) { - auto& c_tvs = tensor_roles_.at(MatmulTensorRole::EPILOGUE_INPUT); - - // The system supports only scenario where there is only one fusion output - // with assigned OUTPUT role, this condition is already verified so there - // is no need for an additional checks here - auto output_d = tensor_roles_.at(MatmulTensorRole::OUTPUT).front(); - for (auto* c : c_tvs) { - cached_tvs.push_back(c->cacheAfter()); - } - - scheduler_utils::BoundedDirectionalTransformPropagator::backward( - output_d, -1, c_tvs); - - std::unordered_set parallel_types = {}; - if (params_->use_smem_epilogue) { - // In cases where smem epilogue feature is enabled, the vectorization - // of domains will be propagated to fusion inputs that are epilogue - // inputs, this may result in unaligned memory reads. Vectorization is - // explicitly excluded form parallelization types to avoid this issue. - // This should be changed when vectorization analysis is available and - // enabled for matmul scheduler. 
- parallel_types = allParallelTypesExcept({ParallelType::Vectorize}); - } - scheduler_utils::parallelizeAllLike( - output_d, -1, cached_tvs, parallel_types); - - // The cached EPILOGUE_INPUT tvs are not needed anymore - cached_tvs.clear(); - } -} - void HopperMultipleMatmulScheduler::scheduleSplitKSum() { if (params_->splitk_factor == 1) { return; diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h index 5eab0f4fbed..295b55ee96e 100644 --- a/csrc/scheduler/hopper_multi_matmul.h +++ b/csrc/scheduler/hopper_multi_matmul.h @@ -171,15 +171,8 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { void scheduleMmaResults(); - void scheduleOutputTensor(TensorView* c); - void scheduleEpilogue(); - //! Propagates transformations from fusion output to fusion tv inputs that are - //! producers in the epilogue. Transformations' propagation aims at input tvs - //! which are not assigned to core roles, that is, are not MMA inputs. - void scheduleFusionInputsForEpilogue(); - void scheduleSplitKSum(); void setUpInlining(); diff --git a/tests/cpp/test_matmul_scheduler.cpp b/tests/cpp/test_matmul_scheduler.cpp index a6ddb8d0ca8..0ffde4364c1 100644 --- a/tests/cpp/test_matmul_scheduler.cpp +++ b/tests/cpp/test_matmul_scheduler.cpp @@ -3296,8 +3296,7 @@ class HopperMatmulSchedulerTest KernelExecutor ke; ke.compile(fusion, inputs, LaunchParams(), matmul_cparams); auto nvf_out = ke.run(inputs); - // NOTE Relax tolerances for split-k case - EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-3, 1e-3)); + EXPECT_TRUE(at::allclose(nvf_out.at(0), tref, 1e-2, 1e-2)); } protected: @@ -3377,6 +3376,79 @@ TEST_P(HopperMatmulSchedulerTest, FusedMultiplySum) { tref = atMatmul(A.squeeze(), B.squeeze(), layout); } +// TODO: Remove this test once the architecture agnostic can be +// run on hopper. 
+TEST_P(HopperMatmulSchedulerTest, FusedMultiplySumBiasNeg) {
+  if (use_smem_epilogue) {
+    GTEST_SKIP()
+        << "TODO: We don't support smem epilogue in the Hopper matmul scheduler right now";
+  }
+  const auto& [A, B] =
+      matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype));
+  const auto& C = matmulAtInput2D(
+      layout, TensorMatmulPos::Bias, data_type_to_aten(dtype), M, N, K);
+  inputs = {A, B, C};
+  // Builds multiply-sum -> bias add -> neg -> cast; tref is the ATen reference.
+  TensorView* tv0 = nullptr;
+  TensorView* tv1 = nullptr;
+  std::unordered_map<int64_t, int64_t> old2new;
+  int64_t k_axis = 0;
+
+  switch (layout) {
+    case MmaLayout::TT:
+      // Inner dims KN, order is MKN
+      tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
+      tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
+      old2new = {{-2, -1}, {-1, -2}};
+      k_axis = -2;
+      break;
+    case MmaLayout::TN:
+      // Inner dims KK, order is MNK
+      tv0 = makeContigConcreteTensor({-1, 1, -1}, dtype);
+      tv1 = makeContigConcreteTensor({1, -1, -1}, dtype);
+      old2new = {};
+      k_axis = -1;
+      break;
+    case MmaLayout::NT:
+      // Inner dims MN, order is KMN
+      tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype);
+      tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype);
+      old2new = {{-3, -1}};
+      k_axis = -3;
+      break;
+    case MmaLayout::NN:
+      // Inner dims MK, order is NKM
+      tv0 = makeContigConcreteTensor({1, -1, -1}, dtype);
+      tv1 = makeContigConcreteTensor({-1, -1, 1}, dtype);
+      old2new = {{-1, -3}};
+      k_axis = -2;
+      break;
+  }
+  TensorView* tv2 = makeContigConcreteTensor({-1}, dtype);
+  // Register the two operands and the 1-D bias (tv2) as fusion inputs.
+  fusion->addInput(tv0);
+  fusion->addInput(tv1);
+  fusion->addInput(tv2);
+  // Multiply-sum over the K axis selected for this layout above.
+  auto tv3 = fusedMultiplySum(tv0, tv1, {k_axis});
+
+  // Reorder the accumulator as [M, N, K]
+  tv3->reorder(old2new);
+  tv3->commitLeafToLogical();
+  // Epilogue: bias (via maybeCastOp to Float) is added, negated, cast to dtype.
+  auto* tv4 = maybeCastOp(DataType::Float, tv2);
+  auto* tv5 = biasEpilogue(tv3, tv4);
+  auto* tv6 = neg(tv5);
+  auto* tv7 = castOp(dtype, tv6);
+  fusion->addOutput(tv7);
+  // The reference mirrors the epilogue and is cast to dtype to match tv7.
+  tref = atBiasEpilogue(
+             atMatmul(A.squeeze(), B.squeeze(), layout),
+             C.to(data_type_to_aten(DataType::Float)))
+             .neg_()
+             .to(data_type_to_aten(dtype));
+}
+
 INSTANTIATE_TEST_SUITE_P(
     General,
     HopperMatmulSchedulerTest,