|
8 | 8 |
|
9 | 9 | #include <scheduler/pointwise_tma.h> |
10 | 10 |
|
| 11 | +#include <ir/utils.h> |
| 12 | +#include <scheduler/debug_utils.h> |
| 13 | +#include <scheduler/pointwise_utils.h> |
| 14 | +#include <scheduler/runtime_info.h> |
| 15 | +#include <scheduler/tools/inlining.h> |
| 16 | +#include <scheduler/utils.h> |
| 17 | +#include <transform_iter.h> |
| 18 | +#include <transform_replay.h> |
| 19 | + |
11 | 20 | namespace nvfuser { |
12 | 21 | namespace pointwise { |
13 | 22 | namespace tma { |
| 23 | +// TODO: This can be further relaxed to allow tensor views with fewer
| 24 | +// dimensions, e.g. an outer-broadcast input [B, I] could also be loaded with TMA.
| 25 | +bool isTvSuitableForTma(const TensorView* tv, int64_t n_valid_dims) { |
| 26 | + return scheduler_utils::nLogicalDims(tv) == n_valid_dims; |
| 27 | +}
| 28 | + |
| 29 | +// Returns the total number of bits when one element is loaded from each
| 30 | +// TMA-loaded input. Used to derive the number of elements to load from each
| 31 | +// input to satisfy the required bits in flight.
| 32 | +int64_t getInputBitsPerElement( |
| 33 | + const pointwise_utils::FusionRuntimeProperties& prop) { |
| 34 | + int64_t bits_per_element = 0; |
| 35 | + int64_t n_valid_dims = scheduler_utils::nLogicalDims(prop.largest_out); |
| 36 | + for (const auto& tv : prop.vectorizable_inputs_outputs) { |
| 37 | + if (tv->isFusionInput() && isTvSuitableForTma(tv, n_valid_dims)) { |
| 38 | + bits_per_element += |
| 39 | + dataTypeSizeBit(tv->getDataType().value(), prop.index_type); |
| 40 | + } |
| 41 | + } |
| 42 | + return bits_per_element; |
| 43 | +} |
14 | 44 |
|
15 | 45 | std::unique_ptr<PointwiseParams> getPointwiseHeuristics( |
16 | 46 | Fusion* fusion, |
17 | 47 | SchedulerRuntimeInfo& runtime_info, |
18 | | - HeuristicDataCache* data_cache) { |
19 | | - FusionGuard fg(fusion); |
| 48 | + HeuristicDataCache* data_cache, |
| 49 | + const pointwise_utils::FusionRuntimeProperties& prop) { |
| 50 | + // Hardware constants |
| 51 | + constexpr int64_t threads_per_warp = 32; |
| 52 | + constexpr int64_t max_size_per_tma_tile_dim = 256; |
20 | 53 | auto params = std::make_unique<PointwiseParams>(); |
21 | 54 | params->tag = "Pointwise TMA heuristics"; |
| 55 | + params->cparams.index_type = prop.index_type; |
22 | 56 | params->use_tma_load = true; |
23 | | - NVF_THROW("Schedule pointwise using TMA"); |
| 57 | + |
| 58 | + // Compute TMA inner domain size |
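| | +  // The flattened 1-D problem [n_elems] is viewed as a 2-D TMA domain
| | +  // [n_elems / tma_domain_inner, tma_domain_inner].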
| 59 | + constexpr int64_t target_inner_tma_domain_size = 512; |
| 60 | + int64_t tma_domain_inner = scheduler_utils::getInnerTmaDomainSize( |
| 61 | + prop.n_elems, |
| 62 | + target_inner_tma_domain_size, |
| 63 | + prop.min_dtype_size_bit_for_vectorization); |
| 64 | + NVF_ERROR( |
| 65 | + tma_domain_inner > 1 && prop.n_elems % tma_domain_inner == 0, |
| 66 | +      "Illegal TMA inner domain size: ",
| 67 | + tma_domain_inner, |
| 68 | + ", n_elems: ", |
| 69 | + prop.n_elems); |
| 70 | + const int64_t tma_outer_domain_size = prop.n_elems / tma_domain_inner; |
| 71 | + params->tma_domain_inner = tma_domain_inner; |
| 72 | + |
| 73 | +  // Set elements_per_cta: each CTA issues one TMA load operation, and this
| 74 | +  // sets the number of elements of that load. Assuming 8 CTAs per SM and the
| 75 | +  // empirical required bits in flight, this is only guidance; the actual tile
| 76 | +  // size is set by tma_tile_inner and tma_tile_outer.
| 77 | +  // Inner tile size: ensure there are at least 2 tiles in the inner TMA
| 78 | +  // dimension. Outer tile size: don't exceed the outer TMA dimension size.
| 79 | +  // Both are subject to the hardware constraint of 256 elements per dimension.
| 80 | + constexpr int64_t cta_per_sm = 8; |
| 81 | + int64_t bits_per_sm = scheduler_utils::getRequiredBitsInFlight(); |
| 82 | + int64_t bits_per_cta = bits_per_sm / cta_per_sm; |
| 83 | + int64_t bits_per_element = getInputBitsPerElement(prop); |
| 84 | + int64_t elements_per_cta = ceilDiv(bits_per_cta, bits_per_element); |
| 85 | + elements_per_cta = scheduler_utils::roundUpPow2Or8(elements_per_cta); |
| 86 | + int64_t max_tma_tile_inner = |
| 87 | + std::min(tma_domain_inner / 2, max_size_per_tma_tile_dim); |
| 88 | + int64_t max_tma_tile_outer = |
| 89 | + std::min(tma_outer_domain_size, max_size_per_tma_tile_dim); |
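| | +  // Start the inner tile at one warp of elements (or half the inner TMA
| | +  // domain if that is smaller) and double it while it fits within
| | +  // max_tma_tile_inner.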
| 90 | + int64_t tma_tile_inner = std::min(tma_domain_inner / 2, threads_per_warp); |
| 91 | + while (tma_tile_inner * 2 <= max_tma_tile_inner) { |
| 92 | + tma_tile_inner *= 2; |
| 93 | + } |
| 94 | + int64_t tma_tile_outer = |
| 95 | + std::min(elements_per_cta / tma_tile_inner, max_tma_tile_outer); |
| 96 | + params->tma_tile_inner = tma_tile_inner; |
| 97 | + params->tma_tile_outer = tma_tile_outer; |
| 98 | + |
| 99 | +  // Set the block tile size: the typical setup is 32 threads in the x-dim
| 100 | +  // and 4 in the y-dim, but don't exceed the TMA tile size in either dimension.
| 101 | + constexpr int64_t threads_per_cta = 128; |
| 102 | + int64_t bdimx = std::min(threads_per_warp, tma_tile_inner); |
| 103 | + int64_t bdimy = std::min(threads_per_cta / bdimx, tma_tile_outer); |
| 104 | + params->lparams.bindUnsafe(bdimx, ParallelType::TIDx); |
| 105 | + params->lparams.bindUnsafe(bdimy, ParallelType::TIDy); |
| 106 | + |
| 107 | + // set vectorization factor gmem <--> regs |
| 108 | + // [tma_tile_inner] is scheduled as [S, TIDx, Vect], so vectorization factor |
| 109 | + // can't exceed tma_tile_inner / bdimx. |
| 110 | + NVF_ERROR( |
| 111 | + tma_tile_inner % bdimx == 0, "tma_tile_inner must be divisible by bdimx"); |
| 112 | + constexpr int64_t max_vectorization_size_in_bit = 128; |
| 113 | + int64_t vect_factor_dtype = |
| 114 | + max_vectorization_size_in_bit / prop.max_dtype_size_bit_for_vectorization; |
| 115 | + int64_t vect_factor_tma_tile_size = tma_tile_inner / bdimx; |
| 116 | + params->vectorization_factor = |
| 117 | + std::min(vect_factor_dtype, vect_factor_tma_tile_size); |
| 118 | + |
| 119 | + // TMA store |
| 120 | + params->use_tma_store = false; |
| 121 | + |
| 122 | + if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { |
| 123 | + debug() << "\n==== Pointwise TMA Scheduler Heuristics ====\n"; |
| 124 | + debug() << "Domain sizes:\n"; |
| 125 | + debug() << " n_elems: " << prop.n_elems << "\n"; |
| 126 | + debug() << " tma_domain_inner: " << tma_domain_inner << "\n"; |
| 127 | + debug() << " tma_outer_domain_size: " << tma_outer_domain_size << "\n"; |
| 128 | + debug() << "\nMemory and CTA configuration:\n"; |
| 129 | + debug() << " cta_per_sm: " << cta_per_sm << "\n"; |
| 130 | + debug() << " bits_per_sm: " << bits_per_sm << "\n"; |
| 131 | + debug() << " bits_per_cta: " << bits_per_cta << "\n"; |
| 132 | + debug() << " bits_per_element: " << bits_per_element << "\n"; |
| 133 | + debug() << " elements_per_cta: " << elements_per_cta << "\n"; |
| 134 | + debug() << "\nTMA tile configuration:\n"; |
| 135 | + debug() << " max_size_per_tma_tile_dim: " << max_size_per_tma_tile_dim |
| 136 | + << "\n"; |
| 137 | + debug() << " max_tma_tile_inner: " << max_tma_tile_inner << "\n"; |
| 138 | + debug() << " tma_tile_inner: " << tma_tile_inner << "\n"; |
| 139 | + debug() << " tma_tile_outer: " << tma_tile_outer << "\n"; |
| 140 | + debug() << " tma_tile_size: " << (tma_tile_inner * tma_tile_outer) << "\n"; |
| 141 | + debug() << " use_tma_load: " << params->use_tma_load << "\n"; |
| 142 | + debug() << " use_tma_store: " << params->use_tma_store << "\n"; |
| 143 | + debug() << "\nThread block configuration:\n"; |
| 144 | + debug() << " blockDim.x (TIDx): " << bdimx << "\n"; |
| 145 | + debug() << " blockDim.y (TIDy): " << bdimy << "\n"; |
| 146 | + debug() << " threads_per_cta: " << (bdimx * bdimy) << "\n"; |
| 147 | + debug() << "\nVectorization:\n"; |
| 148 | + debug() << " max_dtype_size_bit: " |
| 149 | + << prop.max_dtype_size_bit_for_vectorization << "\n"; |
| 150 | + debug() << " min_dtype_size_bit: " |
| 151 | + << prop.min_dtype_size_bit_for_vectorization << "\n"; |
| 152 | + debug() << " max_vectorization_size_in_bit: " |
| 153 | + << max_vectorization_size_in_bit << "\n"; |
| 154 | + debug() << " vectorization_factor: " << params->vectorization_factor |
| 155 | + << "\n"; |
| 156 | + debug() << "============================================\n" << std::endl; |
| 157 | + } |
24 | 158 | return params; |
25 | 159 | } |
26 | 160 |
|
27 | 161 | // TODO: Inline intermediate operations (avoid inlining unrolled/vectorized |
28 | 162 | // input/output caches) |
29 | 163 | void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { |
30 | 164 | FusionGuard fg(fusion); |
31 | | - NVF_THROW("Schedule pointwise using TMA"); |
| 165 | + |
| 166 | +  // Always merge all dimensions without considering a break point; the break
| 167 | +  // point can be equivalently handled when setting the TMA domain sizes.
| 168 | + auto schedule_info_opt = |
| 169 | + pointwise_utils::commonPointwiseSchedule(fusion, /*break_point=*/0); |
| 170 | + if (!schedule_info_opt.has_value()) { |
| 171 | + // Zero-dimensional tensors, nothing to schedule |
| 172 | + return; |
| 173 | + } |
| 174 | + auto& schedule_info = schedule_info_opt.value(); |
| 175 | + |
| 176 | + auto& cached_inputs = schedule_info.cached_inputs; |
| 177 | + auto& cached_outputs = schedule_info.cached_outputs; |
| 178 | + TensorView* reference_tv = schedule_info.reference_tv; |
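| | +  // Inputs/outputs that share the reference tv's innermost dimension and are
| | +  // therefore candidates for vectorized access.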
| 179 | + auto inputs_outputs = |
| 180 | + scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true, true); |
| 181 | + std::unordered_set<TensorView*> vectorizable_io_tvs( |
| 182 | + inputs_outputs.begin(), inputs_outputs.end()); |
| 183 | + |
| 184 | +  // For each cached input, use a TMA load if it covers the full logical
| 185 | +  // domains of the reference tv; otherwise use LDG.
| 186 | + int64_t n_valid_dims = scheduler_utils::nLogicalDims(reference_tv); |
| 187 | + std::vector<TensorView*> tma_tvs; |
| 188 | + std::vector<TensorView*> ldg_tvs; |
| 189 | + for (const auto& [tv, _] : cached_inputs) { |
| 190 | + if (!isTvSuitableForTma(tv, n_valid_dims)) { |
| 191 | + ldg_tvs.push_back(tv); |
| 192 | + continue; |
| 193 | + } |
| 194 | + auto load_op = dynamic_cast<LoadStoreOp*>(tv->definition()); |
| 195 | + if (load_op) { |
| 196 | + load_op->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile); |
| 197 | + } |
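| | +    // tv is now the shared-memory TMA buffer; cacheAfter() inserts a copy
| | +    // (registers by default) between the smem buffer and its consumers.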
| 198 | + tv->setMemoryType(MemoryType::Shared); |
| 199 | + tv->cacheAfter(); |
| 200 | + tma_tvs.push_back(tv); |
| 201 | + } |
| 202 | + |
| 203 | + // Schedule the TMA domain [I0] -> [Do, Di] |
| 204 | + reference_tv->split(0, pparams->tma_domain_inner); |
| 205 | + |
| 206 | + // Schedule the TMA box/tile |
| 207 | + // [Do, Di] -> [Do/to, to, Di/ti, ti] |
| 208 | + reference_tv->split(1, pparams->tma_tile_inner); |
| 209 | + reference_tv->split(0, pparams->tma_tile_outer); |
| 210 | + |
| 211 | + // Propagate the TMA related transformation to all tensors. |
| 212 | + TransformPropagator propagator(reference_tv); |
| 213 | + MaxLogicalDomainInfoSpanningTree(reference_tv).traverse(&propagator); |
| 214 | + |
| 215 | +  // Parallelize the TMA tvs. After propagation, reset the reference back to
| 216 | +  // serial so the non-TMA part can be scheduled further below.
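| | +  // ParallelType::Bulk marks the TMA box dimensions (to, ti).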
| 218 | + auto outer_cord_pt = ParallelType::BIDy; |
| 219 | + auto inner_cord_pt = ParallelType::BIDx; |
| 220 | + if (pparams->flip_grid_binding) { |
| 221 | + std::swap(outer_cord_pt, inner_cord_pt); |
| 222 | + } |
| 223 | + reference_tv->axis(0)->parallelize(outer_cord_pt); |
| 224 | + reference_tv->axis(1)->parallelize(ParallelType::Bulk); |
| 225 | + reference_tv->axis(2)->parallelize(inner_cord_pt); |
| 226 | + reference_tv->axis(3)->parallelize(ParallelType::Bulk); |
| 227 | + scheduler_utils::parallelizeAllLike(reference_tv, tma_tvs); |
| 228 | + reference_tv->axis(1)->parallelize(ParallelType::Serial); |
| 229 | + reference_tv->axis(3)->parallelize(ParallelType::Serial); |
| 230 | + |
| 231 | +  // Further schedule the non-TMA part, starting from [Do/to, to, Di/ti, ti]:
| 232 | + // [Do/to, to, Di/ti, ti] -> [Do/to, to/y, y, Di/ti, ti/v/x, x, v] |
| 233 | + int64_t opos = 1, ipos = 3; |
| 234 | + reference_tv->split(ipos, pparams->vectorization_factor); |
| 235 | + reference_tv->split(ipos, pparams->lparams.bdimx()); |
| 236 | + reference_tv->split(opos, pparams->lparams.bdimy()); |
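| | +  // Resulting axes: [Do/to (0), to/y (1), y (2), Di/ti (3), ti/v/x (4),
| | +  // x (5), v (6)].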
| 237 | + |
| 238 | + // propagate transformation to non-tma tvs |
| 239 | + std::vector<TensorView*> non_tma_tvs = |
| 240 | + ir_utils::allTvsExcept(fusion, {tma_tvs.begin(), tma_tvs.end()}); |
| 241 | + TransformPropagator non_tma_propagator(reference_tv); |
| 242 | + SetSelector selector({non_tma_tvs.begin(), non_tma_tvs.end()}); |
| 243 | + MaxLogicalDomainInfoSpanningTree(reference_tv, &selector) |
| 244 | + .traverse(&non_tma_propagator); |
| 245 | + |
| 246 | + reference_tv->axis(0)->parallelize(outer_cord_pt); |
| 247 | + reference_tv->axis(2)->parallelize(ParallelType::TIDy); |
| 248 | + reference_tv->axis(3)->parallelize(inner_cord_pt); |
| 249 | + reference_tv->axis(5)->parallelize(ParallelType::TIDx); |
| 250 | + int64_t vect_pos = 6; // save position for vectorization |
| 251 | + scheduler_utils::parallelizeAllLike(reference_tv, non_tma_tvs); |
| 252 | + |
| 253 | + // vectorize regs -> global |
| 254 | + if (!pparams->use_tma_store && pparams->vectorization_factor > 1) { |
| 255 | + for (const auto& [_, original_idx] : cached_outputs) { |
| 256 | + auto output_tv = |
| 257 | + dynamic_cast<TensorView*>(fusion->outputs().at(original_idx)); |
| 258 | + if (output_tv && vectorizable_io_tvs.contains(output_tv)) { |
| 259 | + output_tv->axis(vect_pos)->parallelize(ParallelType::Vectorize); |
| 260 | + } |
| 261 | + } |
| 262 | + } |
| 263 | + |
| 264 | + // vectorize global -> regs |
| 265 | + for (auto ldg_tv : ldg_tvs) { |
| 266 | + if (vectorizable_io_tvs.contains(ldg_tv)) { |
| 267 | + ldg_tv->axis(vect_pos)->parallelize(ParallelType::Vectorize); |
| 268 | + } |
| 269 | + } |
| 270 | + |
| 271 | + // inline all |
| 272 | + inlineMost(); |
32 | 273 | } |
| 274 | + |
33 | 275 | } // namespace tma |
34 | 276 | } // namespace pointwise |
35 | 277 | } // namespace nvfuser |