
Commit dbc4bba

add auto tma scheduler
1 parent bfb8add commit dbc4bba

File tree

4 files changed: +368 -29 lines

csrc/scheduler/pointwise.cpp

Lines changed: 65 additions & 2 deletions
@@ -8,6 +8,7 @@
 
 #include <scheduler/pointwise.h>
 
+#include <ATen/cuda/CUDAContext.h>
 #include <instrumentation.h>
 #include <scheduler/debug_utils.h>
 #include <scheduler/pointwise_non_tma.h>
@@ -16,7 +17,6 @@
 #include <scheduler/registry_utils.h>
 #include <scheduler/runtime_info.h>
 #include <scheduler/utils.h>
-
 #include <ranges>
 
 namespace nvfuser {
@@ -329,6 +329,65 @@ bool PointWiseScheduler::canScheduleRunTime(
   return true;
 }
 
+namespace {
+
+// TODO: refine this function to check contiguity, broadcast, reshape, etc.
+bool mayHaveTmaCompatibleInputs(
+    const pointwise_utils::FusionRuntimeProperties& prop) {
+  for (auto tv : prop.vectorizable_inputs_outputs) {
+    if (!tv->isFusionInput()) {
+      continue;
+    }
+    auto dtype_bits =
+        dataTypeSizeBit(tv->getDataType().value(), prop.index_type);
+    // The actual element count should consider the break point and be computed
+    // individually for each input. Here the largest output is used as the
+    // reference: if this largest value fails the checks, it is guaranteed that
+    // no input is suitable for TMA, since all inputs are no larger than the
+    // largest output in pointwise.
+    auto elem_count = prop.n_elems;
+    auto total_bits = elem_count * dtype_bits;
+    // function-condition-1: TMA requires a size divisible by 16 bytes (128
+    // bits).
+    if (total_bits % 128 != 0) {
+      continue;
+    }
+    // function-condition-2: we only do 2D TMA, which requires at least 2 boxes
+    // in the inner dimension, each with 16 bytes. This implies a minimum inner
+    // TMA domain size of 2 * 16 bytes. We should also skip if the inner TMA
+    // domain size is exactly the element count: that would make the outer TMA
+    // domain 1, which is not a valid 2D TMA domain.
+    const int64_t min_inner_tma_domain_size = 2 * 128 / dtype_bits;
+    if (elem_count % min_inner_tma_domain_size != 0 ||
+        elem_count == min_inner_tma_domain_size) {
+      continue;
+    }
+    // TODO: check reshape, contiguity, allocation domain, etc.
+    // function-condition-3: reshape, contiguity, allocation domain, etc.
+    // TODO: performance checks
+    // performance-condition-1: input size is too small
+    // performance-condition-2: inner TMA domain size is too small
+
+    // Passed all preliminary checks; this input may be suitable for TMA.
+    return true;
+  }
+  return false;
+}
+
+// Preliminary check if TMA can be used for the fusion. Serves as a fast path to
+// avoid computing heuristics if TMA is obviously not possible. Passing this
+// check does not guarantee that TMA will be used, as the actual TMA usage is
+// determined by the heuristics.
+bool mayUseTma(const pointwise_utils::FusionRuntimeProperties& prop) {
+  // Hardware: don't use TMA on pre-Blackwell GPUs.
+  if (at::cuda::getCurrentDeviceProperties()->major < 10) {
+    return false;
+  }
+  // Inputs
+  if (!mayHaveTmaCompatibleInputs(prop)) {
+    return false;
+  }
+  return true;
+}
+
+} // namespace
+
 std::unique_ptr<HeuristicParams> PointWiseScheduler::computeHeuristics(
     Fusion* fusion,
     SchedulerRuntimeInfo& runtime_info,
@@ -346,11 +405,15 @@ std::unique_ptr<HeuristicParams> PointWiseScheduler::computeHeuristics(
   }
   const auto& prop = prop_opt.value();
 
-  bool use_tma = false;
+  bool use_tma = mayUseTma(prop) && isOptionEnabled(EnableOption::TmaPointwise);
   std::unique_ptr<HeuristicParams> pparams = nullptr;
   if (use_tma) {
     pparams = pointwise::tma::getPointwiseHeuristics(
-        fusion, runtime_info, data_cache);
+        fusion, runtime_info, data_cache, prop);
   } else {
     pparams = pointwise::non_tma::getPointwiseHeuristics(
         fusion, runtime_info, data_cache, prop);
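
The two arithmetic gates in mayHaveTmaCompatibleInputs reduce to simple integer checks. As a standalone illustration (not part of the commit; the function name and sample sizes below are hypothetical), this sketch replays them for an fp16 input:

#include <cstdint>
#include <cstdio>

// Mirrors function-condition-1 and function-condition-2 above for one input.
bool passesPreliminaryTmaChecks(int64_t elem_count, int64_t dtype_bits) {
  // condition 1: total size must be divisible by 16 bytes (128 bits)
  if ((elem_count * dtype_bits) % 128 != 0) {
    return false;
  }
  // condition 2: the inner TMA domain must fit at least two 16-byte boxes,
  // and must not consume all elements (the outer TMA domain must exceed 1)
  const int64_t min_inner = 2 * 128 / dtype_bits; // 16 elements for fp16
  return elem_count % min_inner == 0 && elem_count != min_inner;
}

int main() {
  std::printf("%d\n", passesPreliminaryTmaChecks(1 << 20, 16)); // 1: 1M elems pass
  std::printf("%d\n", passesPreliminaryTmaChecks(16, 16)); // 0: outer domain would be 1
  std::printf("%d\n", passesPreliminaryTmaChecks(24, 16)); // 0: 24 % 16 != 0
  return 0;
}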

csrc/scheduler/pointwise_tma.cpp

Lines changed: 246 additions & 4 deletions
@@ -8,28 +8,270 @@
 
 #include <scheduler/pointwise_tma.h>
 
+#include <ir/utils.h>
+#include <scheduler/debug_utils.h>
+#include <scheduler/pointwise_utils.h>
+#include <scheduler/runtime_info.h>
+#include <scheduler/tools/inlining.h>
+#include <scheduler/utils.h>
+#include <transform_iter.h>
+#include <transform_replay.h>
+
 namespace nvfuser {
 namespace pointwise {
 namespace tma {
+
+// TODO: This can be further relaxed to allow more tensor views with fewer
+// dimensions, e.g. an outer-broadcast input [B, I] can also be loaded with
+// TMA.
+bool isTvSuitableForTma(const TensorView* tv, int64_t n_valid_dims) {
+  return scheduler_utils::nLogicalDims(tv) == n_valid_dims;
+}
+
+// Returns the total bits loaded when reading one element from each TMA-loaded
+// input. Used to derive the number of elements we should load from each input
+// to satisfy the required bits in flight.
+int64_t getInputBitsPerElement(
+    const pointwise_utils::FusionRuntimeProperties& prop) {
+  int64_t bits_per_element = 0;
+  int64_t n_valid_dims = scheduler_utils::nLogicalDims(prop.largest_out);
+  for (const auto& tv : prop.vectorizable_inputs_outputs) {
+    if (tv->isFusionInput() && isTvSuitableForTma(tv, n_valid_dims)) {
+      bits_per_element +=
+          dataTypeSizeBit(tv->getDataType().value(), prop.index_type);
+    }
+  }
+  return bits_per_element;
+}
 
 std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
     Fusion* fusion,
     SchedulerRuntimeInfo& runtime_info,
-    HeuristicDataCache* data_cache) {
-  FusionGuard fg(fusion);
+    HeuristicDataCache* data_cache,
+    const pointwise_utils::FusionRuntimeProperties& prop) {
+  // Hardware constants
+  constexpr int64_t threads_per_warp = 32;
+  constexpr int64_t max_size_per_tma_tile_dim = 256;
   auto params = std::make_unique<PointwiseParams>();
   params->tag = "Pointwise TMA heuristics";
+  params->cparams.index_type = prop.index_type;
   params->use_tma_load = true;
-  NVF_THROW("Schedule pointwise using TMA");
+
+  // Compute the TMA inner domain size.
+  constexpr int64_t target_inner_tma_domain_size = 512;
+  int64_t tma_domain_inner = scheduler_utils::getInnerTmaDomainSize(
+      prop.n_elems,
+      target_inner_tma_domain_size,
+      prop.min_dtype_size_bit_for_vectorization);
+  NVF_ERROR(
+      tma_domain_inner > 1 && prop.n_elems % tma_domain_inner == 0,
+      "Illegal TMA inner domain size: ",
+      tma_domain_inner,
+      ", n_elems: ",
+      prop.n_elems);
+  const int64_t tma_outer_domain_size = prop.n_elems / tma_domain_inner;
+  params->tma_domain_inner = tma_domain_inner;
+
+  // Set elements_per_cta: each CTA issues one TMA load operation; this sets
+  // the number of elements of that TMA load, assuming 8 CTAs per SM and the
+  // empirical required bits in flight. It is only guidance: the actual tile
+  // size is set by tma_tile_inner and tma_tile_outer.
+  // Inner tile size: ensure we have at least 2 tiles in the inner TMA
+  // dimension. Outer tile size: don't exceed the outer TMA dimension size.
+  // Both are subject to the hardware constraint of 256 elements per dimension.
+  constexpr int64_t cta_per_sm = 8;
+  int64_t bits_per_sm = scheduler_utils::getRequiredBitsInFlight();
+  int64_t bits_per_cta = bits_per_sm / cta_per_sm;
+  int64_t bits_per_element = getInputBitsPerElement(prop);
+  int64_t elements_per_cta = ceilDiv(bits_per_cta, bits_per_element);
+  elements_per_cta = scheduler_utils::roundUpPow2Or8(elements_per_cta);
+  int64_t max_tma_tile_inner =
+      std::min(tma_domain_inner / 2, max_size_per_tma_tile_dim);
+  int64_t max_tma_tile_outer =
+      std::min(tma_outer_domain_size, max_size_per_tma_tile_dim);
+  int64_t tma_tile_inner = std::min(tma_domain_inner / 2, threads_per_warp);
+  while (tma_tile_inner * 2 <= max_tma_tile_inner) {
+    tma_tile_inner *= 2;
+  }
+  int64_t tma_tile_outer =
+      std::min(elements_per_cta / tma_tile_inner, max_tma_tile_outer);
+  params->tma_tile_inner = tma_tile_inner;
+  params->tma_tile_outer = tma_tile_outer;
+
+  // Set the block tile size. The typical setup is 32 in the x-dim and 4 in the
+  // y-dim, but don't go beyond the TMA tile size in each dimension.
+  constexpr int64_t threads_per_cta = 128;
+  int64_t bdimx = std::min(threads_per_warp, tma_tile_inner);
+  int64_t bdimy = std::min(threads_per_cta / bdimx, tma_tile_outer);
+  params->lparams.bindUnsafe(bdimx, ParallelType::TIDx);
+  params->lparams.bindUnsafe(bdimy, ParallelType::TIDy);
+
+  // Set the vectorization factor for gmem <--> regs.
+  // [tma_tile_inner] is scheduled as [S, TIDx, Vect], so the vectorization
+  // factor can't exceed tma_tile_inner / bdimx.
+  NVF_ERROR(
+      tma_tile_inner % bdimx == 0, "tma_tile_inner must be divisible by bdimx");
+  constexpr int64_t max_vectorization_size_in_bit = 128;
+  int64_t vect_factor_dtype =
+      max_vectorization_size_in_bit / prop.max_dtype_size_bit_for_vectorization;
+  int64_t vect_factor_tma_tile_size = tma_tile_inner / bdimx;
+  params->vectorization_factor =
+      std::min(vect_factor_dtype, vect_factor_tma_tile_size);
+
+  // TMA store
+  params->use_tma_store = false;
+
+  if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) {
+    debug() << "\n==== Pointwise TMA Scheduler Heuristics ====\n";
+    debug() << "Domain sizes:\n";
+    debug() << "  n_elems: " << prop.n_elems << "\n";
+    debug() << "  tma_domain_inner: " << tma_domain_inner << "\n";
+    debug() << "  tma_outer_domain_size: " << tma_outer_domain_size << "\n";
+    debug() << "\nMemory and CTA configuration:\n";
+    debug() << "  cta_per_sm: " << cta_per_sm << "\n";
+    debug() << "  bits_per_sm: " << bits_per_sm << "\n";
+    debug() << "  bits_per_cta: " << bits_per_cta << "\n";
+    debug() << "  bits_per_element: " << bits_per_element << "\n";
+    debug() << "  elements_per_cta: " << elements_per_cta << "\n";
+    debug() << "\nTMA tile configuration:\n";
+    debug() << "  max_size_per_tma_tile_dim: " << max_size_per_tma_tile_dim
+            << "\n";
+    debug() << "  max_tma_tile_inner: " << max_tma_tile_inner << "\n";
+    debug() << "  tma_tile_inner: " << tma_tile_inner << "\n";
+    debug() << "  tma_tile_outer: " << tma_tile_outer << "\n";
+    debug() << "  tma_tile_size: " << (tma_tile_inner * tma_tile_outer) << "\n";
+    debug() << "  use_tma_load: " << params->use_tma_load << "\n";
+    debug() << "  use_tma_store: " << params->use_tma_store << "\n";
+    debug() << "\nThread block configuration:\n";
+    debug() << "  blockDim.x (TIDx): " << bdimx << "\n";
+    debug() << "  blockDim.y (TIDy): " << bdimy << "\n";
+    debug() << "  threads_per_cta: " << (bdimx * bdimy) << "\n";
+    debug() << "\nVectorization:\n";
+    debug() << "  max_dtype_size_bit: "
+            << prop.max_dtype_size_bit_for_vectorization << "\n";
+    debug() << "  min_dtype_size_bit: "
+            << prop.min_dtype_size_bit_for_vectorization << "\n";
+    debug() << "  max_vectorization_size_in_bit: "
+            << max_vectorization_size_in_bit << "\n";
+    debug() << "  vectorization_factor: " << params->vectorization_factor
+            << "\n";
+    debug() << "============================================\n" << std::endl;
+  }
   return params;
 }
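
For intuition, the following standalone walk-through replays the heuristic arithmetic above for a hypothetical fusion: 2^20 elements, a single fp16 TMA-loaded input, and an assumed 524288 required bits in flight per SM (the real value comes from scheduler_utils::getRequiredBitsInFlight(); everything below is local to the sketch, not an nvFuser API):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t n_elems = int64_t(1) << 20; // hypothetical: 1M elements
  const int64_t dtype_bits = 16;            // one fp16 TMA-loaded input
  const int64_t bits_per_sm = 524288;       // assumed bits in flight per SM

  // TMA domain: the target inner size 512 divides 2^20 exactly.
  const int64_t tma_domain_inner = 512;
  const int64_t tma_domain_outer = n_elems / tma_domain_inner; // 2048

  // Elements per CTA at 8 CTAs/SM: 524288 / 8 / 16 = 4096, already a pow2.
  const int64_t elements_per_cta = bits_per_sm / 8 / dtype_bits;

  // Inner tile: start at min(512 / 2, 32) = 32, double while within
  // min(512 / 2, 256) = 256.
  int64_t tile_inner = std::min<int64_t>(tma_domain_inner / 2, 32);
  while (tile_inner * 2 <= std::min<int64_t>(tma_domain_inner / 2, 256)) {
    tile_inner *= 2; // ends at 256
  }
  // Outer tile: min(4096 / 256, min(2048, 256)) = 16.
  const int64_t tile_outer = std::min(
      elements_per_cta / tile_inner, std::min<int64_t>(tma_domain_outer, 256));

  // Block tile: bdimx = min(32, 256) = 32, bdimy = min(128 / 32, 16) = 4.
  const int64_t bdimx = std::min<int64_t>(32, tile_inner);
  const int64_t bdimy = std::min<int64_t>(128 / bdimx, tile_outer);

  // Vectorization: min(128 / 16, 256 / 32) = 8.
  const int64_t vect = std::min(128 / dtype_bits, tile_inner / bdimx);

  std::printf("tile %lld x %lld, block %lld x %lld, vect %lld\n",
              (long long)tile_inner, (long long)tile_outer, (long long)bdimx,
              (long long)bdimy, (long long)vect);
  // Prints: tile 256 x 16, block 32 x 4, vect 8
  return 0;
}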
 
 // TODO: Inline intermediate operations (avoid inlining unrolled/vectorized
 // input/output caches)
 void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) {
   FusionGuard fg(fusion);
-  NVF_THROW("Schedule pointwise using TMA");
+
+  // Always merge all dimensions without considering the break point; the break
+  // point can be equivalently accounted for when setting the TMA domain sizes.
+  auto schedule_info_opt =
+      pointwise_utils::commonPointwiseSchedule(fusion, /*break_point=*/0);
+  if (!schedule_info_opt.has_value()) {
+    // Zero-dimensional tensors, nothing to schedule
+    return;
+  }
+  auto& schedule_info = schedule_info_opt.value();
+
+  auto& cached_inputs = schedule_info.cached_inputs;
+  auto& cached_outputs = schedule_info.cached_outputs;
+  TensorView* reference_tv = schedule_info.reference_tv;
+  auto inputs_outputs =
+      scheduler_utils::getInputsOutputsWithInnerDim(reference_tv, true, true);
+  std::unordered_set<TensorView*> vectorizable_io_tvs(
+      inputs_outputs.begin(), inputs_outputs.end());
+
+  // For each cached input, use a TMA load if it has the same full logical
+  // domains as the reference tv; otherwise use LDG.
+  int64_t n_valid_dims = scheduler_utils::nLogicalDims(reference_tv);
+  std::vector<TensorView*> tma_tvs;
+  std::vector<TensorView*> ldg_tvs;
+  for (const auto& [tv, _] : cached_inputs) {
+    if (!isTvSuitableForTma(tv, n_valid_dims)) {
+      ldg_tvs.push_back(tv);
+      continue;
+    }
+    auto load_op = dynamic_cast<LoadStoreOp*>(tv->definition());
+    if (load_op) {
+      load_op->setOpType(LoadStoreOpType::CpAsyncBulkTensorTile);
+    }
+    tv->setMemoryType(MemoryType::Shared);
+    tv->cacheAfter();
+    tma_tvs.push_back(tv);
+  }
+
+  // Schedule the TMA domain [I0] -> [Do, Di]
+  reference_tv->split(0, pparams->tma_domain_inner);
+
+  // Schedule the TMA box/tile
+  // [Do, Di] -> [Do/to, to, Di/ti, ti]
+  reference_tv->split(1, pparams->tma_tile_inner);
+  reference_tv->split(0, pparams->tma_tile_outer);
+
+  // Propagate the TMA-related transformations to all tensors.
+  TransformPropagator propagator(reference_tv);
+  MaxLogicalDomainInfoSpanningTree(reference_tv).traverse(&propagator);
+
+  // Parallelize the TMA tvs; after propagation, reset the reference back to
+  // serial to further schedule the non-TMA parts.
+  auto outer_cord_pt = ParallelType::BIDy;
+  auto inner_cord_pt = ParallelType::BIDx;
+  if (pparams->flip_grid_binding) {
+    std::swap(outer_cord_pt, inner_cord_pt);
+  }
+  reference_tv->axis(0)->parallelize(outer_cord_pt);
+  reference_tv->axis(1)->parallelize(ParallelType::Bulk);
+  reference_tv->axis(2)->parallelize(inner_cord_pt);
+  reference_tv->axis(3)->parallelize(ParallelType::Bulk);
+  scheduler_utils::parallelizeAllLike(reference_tv, tma_tvs);
+  reference_tv->axis(1)->parallelize(ParallelType::Serial);
+  reference_tv->axis(3)->parallelize(ParallelType::Serial);
+
+  // Further schedule the non-TMA part, starting from [Do/to, to, Di/ti, ti]:
+  // [Do/to, to, Di/ti, ti] -> [Do/to, to/y, y, Di/ti, ti/v/x, x, v]
+  int64_t opos = 1, ipos = 3;
+  reference_tv->split(ipos, pparams->vectorization_factor);
+  reference_tv->split(ipos, pparams->lparams.bdimx());
+  reference_tv->split(opos, pparams->lparams.bdimy());
+
+  // Propagate the transformations to the non-TMA tvs.
+  std::vector<TensorView*> non_tma_tvs =
+      ir_utils::allTvsExcept(fusion, {tma_tvs.begin(), tma_tvs.end()});
+  TransformPropagator non_tma_propagator(reference_tv);
+  SetSelector selector({non_tma_tvs.begin(), non_tma_tvs.end()});
+  MaxLogicalDomainInfoSpanningTree(reference_tv, &selector)
+      .traverse(&non_tma_propagator);
+
+  reference_tv->axis(0)->parallelize(outer_cord_pt);
+  reference_tv->axis(2)->parallelize(ParallelType::TIDy);
+  reference_tv->axis(3)->parallelize(inner_cord_pt);
+  reference_tv->axis(5)->parallelize(ParallelType::TIDx);
+  int64_t vect_pos = 6; // save position for vectorization
+  scheduler_utils::parallelizeAllLike(reference_tv, non_tma_tvs);
+
+  // Vectorize regs -> global
+  if (!pparams->use_tma_store && pparams->vectorization_factor > 1) {
+    for (const auto& [_, original_idx] : cached_outputs) {
+      auto output_tv =
+          dynamic_cast<TensorView*>(fusion->outputs().at(original_idx));
+      if (output_tv && vectorizable_io_tvs.contains(output_tv)) {
+        output_tv->axis(vect_pos)->parallelize(ParallelType::Vectorize);
+      }
+    }
+  }
+
+  // Vectorize global -> regs
+  for (auto ldg_tv : ldg_tvs) {
+    if (vectorizable_io_tvs.contains(ldg_tv)) {
+      ldg_tv->axis(vect_pos)->parallelize(ParallelType::Vectorize);
+    }
+  }
+
+  // Inline all
+  inlineMost();
 }
+
 } // namespace tma
 } // namespace pointwise
 } // namespace nvfuser
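
Taken together, the split() calls in schedulePointwise turn the flattened 1-D domain into the 7-D loop nest described in the comments. The sketch below traces the extents step by step, using the same hypothetical heuristic values as the walk-through above (split() here is a local stand-in for divisible splits, not the nvFuser API):

#include <cstdint>
#include <cstdio>
#include <vector>

// split(axis, factor): replace extent e at `axis` with {e / factor, factor},
// assuming the extent is divisible by the factor.
void split(std::vector<int64_t>& dom, size_t axis, int64_t factor) {
  dom[axis] /= factor;
  dom.insert(dom.begin() + axis + 1, factor);
}

void print(const char* tag, const std::vector<int64_t>& dom) {
  std::printf("%-12s[", tag);
  for (size_t i = 0; i < dom.size(); ++i) {
    std::printf("%s%lld", i ? ", " : "", (long long)dom[i]);
  }
  std::printf("]\n");
}

int main() {
  std::vector<int64_t> dom = {int64_t(1) << 20}; // flattened [I0]
  split(dom, 0, 512); // TMA domain [Do, Di]
  print("tma domain", dom);                      // [2048, 512]
  split(dom, 1, 256); // inner tile [Do, Di/ti, ti]
  split(dom, 0, 16);  // outer tile [Do/to, to, Di/ti, ti]
  print("tma tile", dom);                        // [128, 16, 2, 256]
  // TMA tvs are parallelized here as [BIDy, Bulk, BIDx, Bulk].
  split(dom, 3, 8);   // vectorize  [..., ti/v, v]
  split(dom, 3, 32);  // TIDx       [..., ti/v/x, x, v]
  split(dom, 1, 4);   // TIDy       [Do/to, to/y, y, ...]
  print("non-tma", dom);                         // [128, 4, 4, 2, 1, 32, 8]
  // Final bindings: [BIDy, Serial, TIDy, BIDx, Serial, TIDx, Vectorize]
  return 0;
}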

csrc/scheduler/pointwise_tma.h

Lines changed: 2 additions & 1 deletion
@@ -16,7 +16,8 @@ namespace tma {
 std::unique_ptr<PointwiseParams> getPointwiseHeuristics(
     Fusion* fusion,
     SchedulerRuntimeInfo& runtime_info,
-    HeuristicDataCache* data_cache);
+    HeuristicDataCache* data_cache,
+    const pointwise_utils::FusionRuntimeProperties& prop);
 
 void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams);
 } // namespace tma
