add tma pointwise scheduler #5553
Review updated until commit ba3543a
Relevant files: Enhancement, Tests, Documentation, Configuration changes
PR Reviewer Guide
Here are some key observations to aid the review process:
🧪 PR contains tests
⚡ Recommended focus areas for review: Division by Zero Risk
531c0b8 to dbc4bba
dbc4bba to 6d36bb1
!test
Greptile Overview
Greptile Summary
This PR adds a TMA (Tensor Memory Accelerator) pointwise scheduler for Hopper and newer GPUs. The implementation includes preliminary TMA compatibility checks, heuristic computation for determining TMA domain and tile sizes, and a complete scheduling pipeline that transforms tensors through a 2D TMA domain hierarchy.
Confidence Score: 4/5
Sequence Diagram
sequenceDiagram
participant User
participant PointwiseScheduler
participant TMACheck as mayUseTma()
participant TMAHeuristics as getPointwiseHeuristics()
participant TMASchedule as schedulePointwise()
participant NonTMA as non_tma scheduler
User->>PointwiseScheduler: computeHeuristics()
PointwiseScheduler->>TMACheck: Check TMA compatibility
TMACheck->>TMACheck: GPU >= Hopper (SM 9.0)?
TMACheck->>TMACheck: mayHaveTmaCompatibleInputs()?
alt TMA Compatible
TMACheck-->>PointwiseScheduler: true
PointwiseScheduler->>TMAHeuristics: Compute TMA heuristics
TMAHeuristics->>TMAHeuristics: Compute TMA domain [outer, inner]
TMAHeuristics->>TMAHeuristics: Calculate elements_per_cta
TMAHeuristics->>TMAHeuristics: Determine TMA tile size [to, ti]
TMAHeuristics->>TMAHeuristics: Configure thread blocks (bdimx, bdimy)
TMAHeuristics->>TMAHeuristics: Set vectorization factor
TMAHeuristics-->>PointwiseScheduler: PointwiseParams (TMA)
alt TMA params valid
PointwiseScheduler->>User: Return TMA params
User->>TMASchedule: schedule(fusion, params)
TMASchedule->>TMASchedule: Flatten and merge dims
TMASchedule->>TMASchedule: Split into TMA domain [Do, Di]
TMASchedule->>TMASchedule: Split into tiles [Do/to, to, Di/ti, ti]
TMASchedule->>TMASchedule: Parallelize TMA tensors (Bulk)
TMASchedule->>TMASchedule: Parallelize non-TMA tensors (TIDx, TIDy)
TMASchedule->>TMASchedule: Apply vectorization
TMASchedule-->>User: Scheduled fusion
else TMA params invalid (bits_per_element=0)
TMAHeuristics-->>PointwiseScheduler: nullptr
PointwiseScheduler->>NonTMA: Fallback to non-TMA
NonTMA-->>PointwiseScheduler: Non-TMA params
PointwiseScheduler->>User: Return non-TMA params
end
else Not TMA Compatible
TMACheck-->>PointwiseScheduler: false
PointwiseScheduler->>NonTMA: Use non-TMA scheduler
NonTMA-->>PointwiseScheduler: Non-TMA params
PointwiseScheduler->>User: Return non-TMA params
end
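The dispatch flow in the diagram can be sketched in a few lines. This is only an illustration: `Props`, `mayUseTmaSketch`, `tmaHeuristicsSketch`, and `computeHeuristicsSketch` are hypothetical stand-ins, not the nvFuser types or functions, and the string return values merely label which path was taken.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>

// Hypothetical stand-in for the runtime properties; field names are illustrative.
struct Props {
  int sm_major;              // compute capability major version
  bool has_tma_input;        // outcome of a mayHaveTmaCompatibleInputs-style check
  int64_t bits_per_element;  // 0 models "no TMA-compatible input found"
};

// Mirrors the two checks in the diagram: GPU >= Hopper and a compatible input.
bool mayUseTmaSketch(const Props& p) {
  return p.sm_major >= 9 && p.has_tma_input;
}

// Returns std::nullopt to model the nullptr "TMA params invalid" branch.
std::optional<std::string> tmaHeuristicsSketch(const Props& p) {
  if (p.bits_per_element <= 0) {
    return std::nullopt;
  }
  return std::string("tma");
}

// Try the TMA path first and fall back to the non-TMA path otherwise.
std::string computeHeuristicsSketch(const Props& p) {
  if (mayUseTmaSketch(p)) {
    if (auto params = tmaHeuristicsSketch(p)) {
      return *params;
    }
  }
  return "non_tma";
}
```

Both fallback branches of the diagram end in the same non-TMA call, which is why a single trailing return covers them here.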
4 files reviewed, 6 comments
csrc/scheduler/pointwise_tma.cpp
Outdated
| prop.min_dtype_size_bit_for_vectorization); | ||
| NVF_ERROR( | ||
| tma_domain_inner > 1 && prop.n_elems % tma_domain_inner == 0, | ||
| "Ilegal TMA inner domain size: ", |
syntax: typo: "Ilegal" should be "Illegal"
| "Ilegal TMA inner domain size: ", | |
| "Illegal TMA inner domain size: ", |
csrc/scheduler/pointwise_tma.cpp
Outdated
| // elements of this TMA load operation. Assuming 8 CTAs per SM, using | ||
| // empirical required bits in flight, it is just a guidance, actual tile size | ||
| // is set by tma_tile_inner and tma_tile_outer. | ||
| // Inner tila size: ensure we have at least 2 tiles in the inner TMA |
syntax: typo: "tila" should be "tile"
| // Inner tila size: ensure we have at least 2 tiles in the inner TMA | |
| // Inner tile size: ensure we have at least 2 tiles in the inner TMA |
csrc/scheduler/pointwise.cpp
Outdated
| continue; | ||
| } | ||
| // function-condition-2, We only do 2D TMA, requires at least 2 boxes in | ||
| // inner dimension each with 16 bytes. This requires a minimum innter tma |
syntax: typo: "innter" should be "inner"
| // inner dimension each with 16 bytes. This requires a minimum innter tma | |
| // domain size of 2 * 16 bytes. We also should skip if the inner tma domain |
csrc/scheduler/pointwise.cpp
Outdated
| // function-condition-3, reshape, contiguity, allocation domain, etc. | ||
| // TODO: performance checks | ||
| // performance-condition-1, input size is too small | ||
| // performance-condition-2, Innner TMA domain size is too small |
syntax: typo: "Innner" should be "Inner"
| // performance-condition-2, Innner TMA domain size is too small | |
| // performance-condition-2, Inner TMA domain size is too small |
csrc/scheduler/pointwise.cpp
Outdated
| // check does not guarantee that TMA will be used, as the actual TMA usage will | ||
| // be determined by the heuristics. | ||
| bool mayUseTma(const pointwise_utils::FusionRuntimeProperties& prop) { | ||
| // Harware, Don't use tma for pre-Blackwell GPUs |
syntax: typo: "Harware" should be "Hardware"
| // Harware, Don't use tma for pre-Blackwell GPUs | |
| // Hardware, Don't use tma for pre-Blackwell GPUs |
csrc/scheduler/pointwise_tma.cpp
Outdated
| int64_t bits_per_element = getInputBitsPerElement(prop); | ||
| int64_t elements_per_cta = ceilDiv(bits_per_cta, bits_per_element); |
logic: potential division by zero if no TMA-compatible inputs exist - bits_per_element could be 0, causing division by zero in ceilDiv
| int64_t bits_per_element = getInputBitsPerElement(prop); | |
| int64_t elements_per_cta = ceilDiv(bits_per_cta, bits_per_element); | |
| int64_t bits_per_element = getInputBitsPerElement(prop); | |
| NVF_ERROR(bits_per_element > 0, "No TMA-compatible inputs found"); | |
| int64_t elements_per_cta = ceilDiv(bits_per_cta, bits_per_element); |
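A minimal sketch of the guarded computation, assuming a plain integer ceiling division. `ceilDivSketch` and `elementsPerCtaSketch` are hypothetical names, and a standard exception stands in for `NVF_ERROR`.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Integer ceiling division for positive operands (local stand-in for ceilDiv).
int64_t ceilDivSketch(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

// Rejects a non-positive element width before dividing, so the zero
// divisor never reaches ceilDivSketch.
int64_t elementsPerCtaSketch(int64_t bits_per_cta, int64_t bits_per_element) {
  if (bits_per_element <= 0) {
    throw std::invalid_argument("No TMA-compatible inputs found");
  }
  return ceilDivSketch(bits_per_cta, bits_per_element);
}
```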
!test
4 files reviewed, 1 comment
csrc/scheduler/pointwise_tma.cpp
Outdated
| int64_t tma_tile_outer = | ||
| std::min(elements_per_cta / tma_tile_inner, max_tma_tile_outer); |
logic: tma_tile_outer could be 0 if elements_per_cta < tma_tile_inner (integer division rounds down)
| int64_t tma_tile_outer = | |
| std::min(elements_per_cta / tma_tile_inner, max_tma_tile_outer); | |
| int64_t tma_tile_outer = | |
| std::max(1L, std::min(elements_per_cta / tma_tile_inner, max_tma_tile_outer)); |
csrc/scheduler/pointwise.cpp
Outdated
| auto total_bits = elem_count * dtype_bits; | ||
|
|
||
| // Condition 1: TMA requires size divisible by 16 bytes (128 bits) | ||
| if (total_bits % 128 != 0) { |
Doesn't the multi-dimensional TMA support predication?
Yes, it does.
The comment is not clear. The first step is creating a 2D TMA domain, which requires a divisible split, and the inner dim must be a multiple of 16 bytes. This condition can be merged with the following condition-2.
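One way to read the merged condition is as a single predicate on the inner TMA domain: a divisible split, a 16-byte multiple, and room for at least two 16-byte boxes. The sketch below is an interpretation of the comments in this thread, not the code in the PR; `innerDomainOkSketch` and its parameters are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

// Checks a candidate inner TMA domain size (in elements) for a tensor of
// n_elems elements whose dtype is dtype_bits wide.
bool innerDomainOkSketch(
    int64_t n_elems,
    int64_t tma_domain_inner,
    int64_t dtype_bits) {
  if (tma_domain_inner <= 0 || dtype_bits <= 0) {
    return false;
  }
  const int64_t inner_bits = tma_domain_inner * dtype_bits;
  // The [outer, inner] split must be exact.
  const bool divisible_split = n_elems % tma_domain_inner == 0;
  // The inner dimension must be a multiple of 16 bytes (128 bits).
  const bool multiple_of_16_bytes = inner_bits % 128 == 0;
  // At least two 16-byte boxes must fit in the inner dimension.
  const bool at_least_two_boxes = inner_bits >= 2 * 128;
  return divisible_split && multiple_of_16_bytes && at_least_two_boxes;
}
```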
4 files reviewed, no comments
4 files reviewed, 2 comments
| // Hardware requirement: Don't use TMA for pre-Hopper GPUs | ||
| if (at::cuda::getCurrentDeviceProperties()->major < 9) { | ||
| return false; |
logic: PR description says "Blackwell GPUs" but code checks major < 9 which enables TMA for Hopper (SM 9.0) and above, not Blackwell (SM 10.0) and above. Check the intended GPU requirement.
csrc/scheduler/pointwise_tma.cpp
Outdated
| int64_t tma_tile_outer = | ||
| std::min(elements_per_cta / tma_tile_inner, max_tma_tile_outer); |
logic: tma_tile_outer will be 0 if elements_per_cta < tma_tile_inner (integer division rounds down), causing split(0) failure at line 214 and potentially bdimy=0 at line 108.
| int64_t tma_tile_outer = | |
| std::min(elements_per_cta / tma_tile_inner, max_tma_tile_outer); | |
| int64_t tma_tile_outer = std::max( | |
| 1L, std::min(elements_per_cta / tma_tile_inner, max_tma_tile_outer)); |
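The clamp can be checked in isolation with a small sketch. `tmaTileOuterSketch` is a hypothetical helper. It spells the lower bound as `std::max<int64_t>(1, ...)` rather than `1L`, because `std::max(1L, x)` fails template deduction on platforms where `int64_t` is `long long`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Outer tile size: as many rows as fit in the per-CTA element budget,
// capped by the hardware limit and never below 1.
int64_t tmaTileOuterSketch(
    int64_t elements_per_cta,
    int64_t tma_tile_inner,
    int64_t tma_tile_outer_max) {
  assert(tma_tile_inner > 0);
  return std::max<int64_t>(
      1, std::min(elements_per_cta / tma_tile_inner, tma_tile_outer_max));
}
```

With `elements_per_cta = 100` and `tma_tile_inner = 256`, the integer division yields 0 and the clamp raises it to 1.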
!test
4 files reviewed, 1 comment
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
4 files reviewed, no comments
| const auto& prop = prop_opt.value(); | ||
|
|
||
| bool use_tma = false; | ||
| bool use_tma = mayUseTma(prop) && isOptionEnabled(EnableOption::TmaPointwise); |
Why not structure it in the following way instead:
Move the isOptionEnabled() check inside mayUseTma():
if (!isOptionEnabled(EnableOption::TmaPointwise)) {
return false;
}
Then in this function, you only need to write:
if (mayUseTma(prop)) {
pparams = pointwise::tma::getPointwiseHeuristics(
fusion, runtime_info, data_cache, prop);
}
This is slightly more concise. But more importantly, it moves the isOptionEnabled() check next to the getCurrentDeviceProperties() check, which seems similar.
Thanks for the suggestion. I prefer keeping the current approach because of the following reason:
mayUseTma(prop) says we can use the tma-version; isOptionEnabled(EnableOption::TmaPointwise) says the user wants to use the tma-version. So they are different things.
In the future, bool use_tma = mayUseTma(prop) && isOptionEnabled(EnableOption::TmaPointwise); will be revised to bool use_tma = mayUseTma(prop) || isOptionEnabled(EnableOption::TmaPointwise); clearly passing the message that the tma version will be used because of one of the following conditions:
- (1) heuristic auto selection: mayUseTma(prop)
- (2) user selects to use it: isOptionEnabled(EnableOption::TmaPointwise)
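The difference between the current and the planned selection logic is easiest to see side by side. In this sketch the two booleans stand in for the results of `mayUseTma(prop)` and `isOptionEnabled(EnableOption::TmaPointwise)`; the function names are hypothetical.

```cpp
#include <cassert>

// Current behavior: TMA is used only when it is both possible and requested.
bool useTmaCurrentSketch(bool may_use_tma, bool option_enabled) {
  return may_use_tma && option_enabled;
}

// Planned behavior as described in the reply: either condition selects TMA.
bool useTmaPlannedSketch(bool may_use_tma, bool option_enabled) {
  return may_use_tma || option_enabled;
}
```

The two versions differ only when exactly one of the inputs is true.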
csrc/scheduler/pointwise_tma.cpp
Outdated
| const pointwise_utils::FusionRuntimeProperties& prop) { | ||
| // Hardware constants | ||
| constexpr int64_t threads_per_warp = 32; | ||
| constexpr int64_t max_size_per_tma_tile_dim = 256; |
What does this represent?
256 is the maximum number of elements per TMA dimension. Does it need a better variable name?
Ah, I see. It wasn't immediately clear to me what the TMA tile dim means.
csrc/scheduler/pointwise_tma.cpp
Outdated
| int64_t max_tma_tile_inner = | ||
| std::min(tma_domain_inner / 2, max_size_per_tma_tile_dim); |
What is this? Why is tma_domain_inner divided by 2?
This is similar to what we did in getInnerTmaDomainSize: we want to keep at least 2 boxes/tiles in the inner dimension to avoid creating two contiguous bulk dimensions, which would be merged and may lead to a bulk size larger than 256.
I realized you have this in the above:
// - Inner tile size: ensure at least 2 tiles in the inner TMA dimension
I still easily get lost with various TMA "XYZ", e.g., tiles, dimensions, domains, boxes, etc. I'd appreciate if you could be more strict with the naming. For example:
// - Inner tile size: ensure at least 2 tiles in the inner TMA domain
I got your point. Let me revise these var names and comments.
csrc/scheduler/pointwise_tma.cpp
Outdated
| int64_t tma_tile_inner = std::min(tma_domain_inner / 2, threads_per_warp); | ||
| while (tma_tile_inner * 2 <= max_tma_tile_inner) { | ||
| tma_tile_inner *= 2; | ||
| } |
I'm lost here again. Can you please give a comment to each variable here?
Revised variable names and comments:
Variable Name Changes Summary
Function Names
- getInnerTmaDomainSize() → getTmaDomainInner()
  - More concise, matches the variable it computes
Function Parameters
- target_inner_tma_domain_size → tma_domain_inner_target
  - Reordered to match pattern: <variable>_<modifier>
- min_dtype_bytes → min_dtype_bits
  - Changed from bytes to bits for consistency with codebase
  - Default value: 1 byte → 8 bits
Local Variables (in pointwise_tma.cpp)
- max_tma_tile_inner → tma_tile_inner_max
- max_tma_tile_outer → tma_tile_outer_max
  - Reordered for consistency: <variable>_max pattern
Terminology Standardization
Throughout all files, consistently use:
- tma_domain_outer and tma_domain_inner (not OuterTmaDomain, Do, or tma_outer_domain_size)
- tma_tile_outer and tma_tile_inner (not to, ti, or abbreviated forms)
Rationale: All changes improve consistency by following the pattern <concept>_<dimension>_<modifier> (e.g., tma_tile_inner_max), making the code more readable and searchable.
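With the revised names, the inner tile selection quoted earlier in this thread could be restated as the sketch below. It copies the logic of the quoted snippet into a hypothetical standalone function, `tmaTileInnerSketch`, with one comment per variable. It is not the final code in the PR.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

int64_t tmaTileInnerSketch(int64_t tma_domain_inner) {
  // Hardware constants taken from the quoted snippet.
  constexpr int64_t threads_per_warp = 32;
  constexpr int64_t max_size_per_tma_tile_dim = 256;
  assert(tma_domain_inner > 1);

  // Upper bound: keep at least 2 tiles in the inner TMA domain and stay
  // within the 256-element limit per tile dimension.
  const int64_t tma_tile_inner_max =
      std::min(tma_domain_inner / 2, max_size_per_tma_tile_dim);

  // Starting point: one warp's worth of elements, or half the inner
  // domain when the domain is too small for that.
  int64_t tma_tile_inner = std::min(tma_domain_inner / 2, threads_per_warp);

  // Grow by doubling while the doubled size still fits under the bound.
  while (tma_tile_inner * 2 <= tma_tile_inner_max) {
    tma_tile_inner *= 2;
  }
  return tma_tile_inner;
}
```

For `tma_domain_inner = 512` this yields 256, and for `tma_domain_inner = 16` it yields 8.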
8 files reviewed, no comments
!test
8 files reviewed, no comments
| // ============================================================================ | ||
| // TMA TERMINOLOGY GUIDE: | ||
| // ============================================================================ | ||
| // The TMA scheduler uses a 2-level hierarchy to organize data: | ||
| // | ||
| // 1. TMA DOMAIN: The logical split of the entire problem space into two parts: | ||
| // - tma_domain_inner: Size of the inner (contiguous) dimension | ||
| // - tma_domain_outer: Size of the outer dimension (n_elems / | ||
| // tma_domain_inner) | ||
| // - requirement: n_elems % tma_domain_inner == 0 | ||
| // Example: For 2048 elements with tma_domain_inner=512: | ||
| // TMA Domain structure: [4, 512] | ||
| // | ||
| // 2. TMA TILE: The actual hardware tile size for each TMA load operation | ||
| // - tma_tile_inner: Number of elements along the inner dimension per tile | ||
| // - tma_tile_outer: Number of elements along the outer dimension per tile | ||
| // Example: For TMA domain [4, 512] with tiles [2, 128]: | ||
| // Each TMA load fetches a [2 x 128] tile, requiring 2 x 4 = 8 tiles total | ||
| // | ||
| // Note: In general TMA terminology, a "box" is the dense rectangular region | ||
| // loaded from global memory, while a "tile" is a potentially strided subset | ||
| // selected from that box. The pointwise scheduler always uses dense tiles | ||
| // (tile = box), so we use "TMA tile" throughout to refer to both concepts. | ||
| // | ||
| // Transformation sequence: logical domain -> TMA domain -> TMA tile | ||
| // [I0, I1, ...] -- > [I] -> [Do, Di] -> [Do/to, to, Di/ti, ti] | ||
| // where Do=tma_domain_outer, Di=tma_domain_inner, | ||
| // to=tma_tile_outer, ti=tma_tile_inner | ||
| // ============================================================================ |
Thanks. This is super helpful!
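The example in the terminology guide (2048 elements, tma_domain_inner=512, tiles [2, 128]) can be reproduced with a short sketch. `TmaLayoutSketch` and `tmaLayoutSketch` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstdint>

// Sizes of the TMA domain and the number of tiles along each dimension.
struct TmaLayoutSketch {
  int64_t domain_outer;  // Do = n_elems / Di
  int64_t domain_inner;  // Di
  int64_t tiles_outer;   // ceil(Do / to)
  int64_t tiles_inner;   // ceil(Di / ti)
};

TmaLayoutSketch tmaLayoutSketch(
    int64_t n_elems,
    int64_t tma_domain_inner,
    int64_t tma_tile_outer,
    int64_t tma_tile_inner) {
  // The guide requires an exact split of the flattened problem.
  assert(tma_domain_inner > 0 && n_elems % tma_domain_inner == 0);
  assert(tma_tile_outer > 0 && tma_tile_inner > 0);
  const int64_t domain_outer = n_elems / tma_domain_inner;
  return TmaLayoutSketch{
      domain_outer,
      tma_domain_inner,
      (domain_outer + tma_tile_outer - 1) / tma_tile_outer,
      (tma_domain_inner + tma_tile_inner - 1) / tma_tile_inner};
}
```

For the guide's numbers this gives a [4, 512] domain and 2 x 4 = 8 tiles.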
| params->use_tma_load = true; | ||
| NVF_THROW("Schedule pointwise using TMA"); | ||
|
|
||
| // ========== Step 1: Compute TMA Domain Dimensions ========== |
This also helps a lot. Thanks!
csrc/scheduler/pointwise_tma.cpp
Outdated
| // process each TMA tile. Threads cooperate to move data from shared memory | ||
| // to registers and perform computation. |
I appreciate detailed comments, but this is nothing specific to TMA and, yeah, you can assume everyone involved in nvFuser would know what shared memory is.
Instead, could you explain the relationship between TMA tiles and thread blocks? For example, I'm not sure why blockDim.x is the minimum of 32 and the inner tile size. tma_tile_inner is a multiple of 32 per lines 157-159, so it seems bdimx is always 32, right? What do you intend to do with this min expression?
(1) tma_tile_inner can be smaller than 32. For example, when the input size is 17x16, tma_domain_inner = 16 and we use tma_tile_inner = 16/2 = 8, so bdimx = 8. It doesn't make sense to use more than 8 threads to process 8 elements. In most cases it will be a multiple of 32 and bdimx = 32.
(2) Using bdimx = 32 has several benefits:
(a) It avoids bank conflicts. For example, with tma_tile_inner = 32 and bdimx = 16, there is a 2-way conflict.
(b) Why not use more than 32? To leave room for vectorization. The largest tma_tile_inner is 256. To process these 256 elements and write them to gmem, using 32 threads leads to each thread processing 8 elements and then doing a vectorized write to global memory.
We can use vectorized loads from smem to registers to avoid bank conflicts when bdimx = 8 or 16; this parameter can be tuned.
Thanks. Please add this to the code as a comment.
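The relationship between the inner tile size and blockDim.x described in this thread can be sketched as follows. `bdimxSketch` and `elemsPerThreadSketch` are hypothetical helpers, and the per-thread count assumes the inner tile is evenly divided among the threads in x.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// blockDim.x: one warp, unless the inner tile has fewer than 32 elements.
int64_t bdimxSketch(int64_t tma_tile_inner) {
  constexpr int64_t threads_per_warp = 32;
  return std::min(threads_per_warp, tma_tile_inner);
}

// Elements along the inner tile dimension handled by each thread in x,
// which is the room left for vectorized writes to global memory.
int64_t elemsPerThreadSketch(int64_t tma_tile_inner) {
  return tma_tile_inner / bdimxSketch(tma_tile_inner);
}
```

For the 17x16 case (tma_tile_inner = 8) this gives bdimx = 8 with one element per thread, and for the largest tile (256) it gives bdimx = 32 with 8 elements per thread.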
Co-authored-by: Naoya Maruyama <[email protected]>
8 files reviewed, 1 comment
csrc/scheduler/pointwise_tma.cpp
Outdated
| const int64_t elements_per_cta = scheduler_utils::roundUpToN(ceilDiv(bits_per_cta, bits_per_element), 1024); | ||
| elements_per_cta = scheduler_utils::roundUpToN(elements_per_cta, 1024); |
syntax: compilation error: elements_per_cta is declared as const on line 130, but line 131 attempts to reassign it. Remove line 131 since the value is already rounded up to 1024 on line 130.
| const int64_t elements_per_cta = scheduler_utils::roundUpToN(ceilDiv(bits_per_cta, bits_per_element), 1024); | |
| elements_per_cta = scheduler_utils::roundUpToN(elements_per_cta, 1024); | |
| const int64_t elements_per_cta = scheduler_utils::roundUpToN(ceilDiv(bits_per_cta, bits_per_element), 1024); |
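Rounding up to a multiple is idempotent, which is why the second assignment is redundant as well as ill-formed. A minimal sketch, assuming `roundUpToN` rounds up to the next multiple of its second argument; `roundUpToNSketch` and `elementsPerCtaRoundedSketch` are hypothetical stand-ins.

```cpp
#include <cassert>
#include <cstdint>

// Rounds x up to the next multiple of n, for positive x and n.
int64_t roundUpToNSketch(int64_t x, int64_t n) {
  return (x + n - 1) / n * n;
}

// Single const initialization; no later reassignment is needed.
int64_t elementsPerCtaRoundedSketch(
    int64_t bits_per_cta,
    int64_t bits_per_element) {
  assert(bits_per_element > 0);
  const int64_t unrounded =
      (bits_per_cta + bits_per_element - 1) / bits_per_element;
  const int64_t elements_per_cta = roundUpToNSketch(unrounded, 1024);
  return elements_per_cta;
}
```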
| params->lparams.bindUnsafe(bdimx, ParallelType::TIDx); | ||
| params->lparams.bindUnsafe(bdimy, ParallelType::TIDy); | ||
|
|
||
| // ========== Step 5: Determine Vectorization Factor ========== |
Please comment on why vectorization is used rather than TMA, and which tensors are vectorized.
IIUC, we are not using the usual vectorization analysis? Why?
Will revise the comments.
- It's a heuristic parameter: we can do either a vectorized store or a TMA store, or both (one output using TMA, the other using a vectorized store).
- No further analysis is needed since the inner TMA tile size is already selected to be a multiple of 16 bytes.
!test
8 files reviewed, no comments
Following #5529
Add a basic tma pointwise scheduler without considering broadcast, reshapes, contiguity, allocation domains, etc.
When to use tma:
(1) Hopper and newer GPUs
(2) Has input tensor suitable for TMA load, only based on dtype and element count.
Heuristics:
(1) Inner TMA domain size: set by getInnerTmaDomainSize() with a target of 512.
(2) TMA Tile Size: Assuming 8 CTAs per SM, derive elements per CTA to meet required bytes-in-flight, then split into outer tma tile and inner tma tile.
(3) Block tile size: 32 threads in x-dim and 4 threads in y-dim
Schedule:
Same as previous manual schedule process.