
Conversation

@liqiangxl (Collaborator) commented Nov 18, 2025

Following #5529
Add a basic TMA pointwise scheduler that does not yet handle broadcasts, reshapes, contiguity, allocation domains, etc.

When to use TMA:
(1) Hopper and newer GPUs.
(2) At least one input tensor is suitable for TMA load, judged only by dtype and element count.

Heuristics:
(1) Inner TMA domain size: set by getInnerTmaDomainSize() with a target of 512 elements.
(2) TMA tile size: assuming 8 CTAs per SM, derive the elements per CTA needed to meet the required bytes in flight, then split them into an outer and an inner TMA tile (see the sketch below).
(3) Block tile size: 32 threads in the x dimension and 4 threads in the y dimension.

Schedule:
Same as the previous manual scheduling process.
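
For clarity, here is a minimal sketch of the tile-size derivation in heuristic (2). All identifiers and constants are illustrative assumptions, not this PR's exact code:

    // Sketch of the tile-size heuristic described above (illustrative).
    #include <algorithm>
    #include <cstdint>

    int64_t ceilDiv(int64_t a, int64_t b) {
      return (a + b - 1) / b;
    }

    void computeTmaTileSizes(
        int64_t required_bits_in_flight, // empirical per-SM bandwidth target
        int64_t bits_per_element,        // summed input bits per element
        int64_t tma_domain_inner,        // inner TMA domain size (target ~512)
        int64_t& tma_tile_outer,
        int64_t& tma_tile_inner) {
      constexpr int64_t kCtasPerSm = 8;       // assumed occupancy
      constexpr int64_t kMaxTmaTileDim = 256; // max elements per TMA tile dim
      constexpr int64_t kThreadsPerWarp = 32;
      // Bits each CTA must keep in flight, converted to elements.
      const int64_t bits_per_cta = ceilDiv(required_bits_in_flight, kCtasPerSm);
      const int64_t elements_per_cta = ceilDiv(bits_per_cta, bits_per_element);
      // Inner tile: keep at least 2 tiles in the inner TMA domain, start from
      // a warp's worth of elements, and grow by powers of two up to the cap.
      const int64_t tma_tile_inner_max =
          std::min(tma_domain_inner / 2, kMaxTmaTileDim);
      tma_tile_inner = std::min(tma_domain_inner / 2, kThreadsPerWarp);
      while (tma_tile_inner * 2 <= tma_tile_inner_max) {
        tma_tile_inner *= 2;
      }
      // Outer tile: spend the remaining per-CTA element budget, clamped to
      // the hardware limit and to at least 1.
      tma_tile_outer = std::max<int64_t>(
          1, std::min(elements_per_cta / tma_tile_inner, kMaxTmaTileDim));
    }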

github-actions bot commented Nov 18, 2025

Review updated until commit ba3543a

Description

  • Add automatic TMA (Tensor Memory Accelerator) pointwise scheduler for Hopper+ GPUs

  • Implement TMA domain splitting with 2D tile structure for efficient memory access

  • Add heuristics for TMA tile sizing, thread block configuration, and vectorization

  • Include fallback mechanism to non-TMA scheduler when TMA is not suitable

  • Add comprehensive test coverage for TMA scheduling with various tensor dimensions

Changes walkthrough

Relevant files

Enhancement

csrc/scheduler/pointwise_tma.cpp (+406/-4): Complete TMA scheduler implementation
  • Complete TMA scheduler implementation with detailed documentation
  • Implement getPointwiseHeuristics() with TMA domain/tile sizing logic
  • Implement schedulePointwise() with an 8-phase TMA scheduling pipeline
  • Add a TMA terminology guide and comprehensive inline comments
  • Handle TMA vs LDG input classification and parallelization strategy

csrc/scheduler/pointwise.cpp (+63/-4): Integrate TMA scheduler into pointwise scheduler
  • Add mayHaveTmaCompatibleInputs() and mayUseTma() helper functions
  • Modify computeHeuristics() to check TMA eligibility and call the TMA scheduler
  • Add fallback to the non-TMA scheduler when TMA is not applicable
  • Include the ATen CUDA context header for hardware capability checking

csrc/scheduler/utils.cpp (+60/-34): Rename and enhance TMA domain utility function
  • Rename getInnerTmaDomainSize() to getTmaDomainInner() for clarity
  • Update the parameter from min_dtype_bytes to min_dtype_bits
  • Enhance documentation with a detailed explanation of TMA constraints
  • Improve the divisor search algorithm for optimal TMA domain sizing (see the sketch after this walkthrough)

csrc/scheduler/pointwise_heuristic.h (+39/-12): Extend PointwiseParams with TMA configuration
  • Add TMA configuration fields to the PointwiseParams class
  • Update the equality check, toString(), and hash() for the TMA fields
  • Add detailed documentation for the TMA domain and tile parameters

csrc/scheduler/pointwise_tma.h (+2/-1): Update TMA scheduler header interface
  • Update the getPointwiseHeuristics() signature to include FusionRuntimeProperties
  • Maintain the schedulePointwise() function declaration

Tests

tests/cpp/test_pointwise.cpp (+85/-34): Add comprehensive TMA scheduler tests
  • Add the Tma2dTileTest parameterized test class for TMA functionality
  • Test both the auto-scheduling path and manual TMA scheduling
  • Validate TMA usage detection and correctness across dimensions
  • Add test parameters for dim0, ndims, use_tma_store, and auto_schedule

Documentation

csrc/scheduler/utils.h (+29/-28): Update TMA utility function documentation
  • Update getTmaDomainInner() documentation and parameter names
  • Change the min_dtype_bytes parameter to min_dtype_bits
  • Add comprehensive explanations of TMA terminology and constraints

Configuration changes

csrc/utils.h (+8/-0): Add TMA hardware limit constants
  • Add kMaxElementsPerTmaBoxDim and kMaxElementsPerTmaTileDim constants
  • Document the distinction between TMA box and tile terminology
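
Where the divisor search for getTmaDomainInner() is mentioned above, a rough sketch of such a search could look like the following. This is illustrative only, not the PR's actual implementation; it assumes the result must divide n_elems, be a multiple of 16 bytes, and stay strictly between 1 and n_elems so the domain remains 2D:

    #include <cstdint>
    #include <cstdlib>

    // Sketch of a divisor search for the inner TMA domain size.
    int64_t getTmaDomainInnerSketch(
        int64_t n_elems,
        int64_t min_dtype_bits,
        int64_t tma_domain_inner_target = 512) {
      int64_t best = -1;
      auto consider = [&](int64_t cand) {
        if (cand <= 1 || cand >= n_elems) {
          return; // keep the TMA domain 2D
        }
        if ((cand * min_dtype_bits) % 128 != 0) {
          return; // inner dim must be a multiple of 16 bytes (128 bits)
        }
        if (best < 0 ||
            std::abs(cand - tma_domain_inner_target) <
                std::abs(best - tma_domain_inner_target)) {
          best = cand; // prefer the divisor closest to the target
        }
      };
      // Enumerate divisors of n_elems in O(sqrt(n_elems)).
      for (int64_t d = 1; d * d <= n_elems; ++d) {
        if (n_elems % d == 0) {
          consider(d);
          consider(n_elems / d);
        }
      }
      return best; // -1 if no legal divisor exists
    }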

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Division by Zero Risk

    In getPointwiseHeuristics() at line 126, there's a potential division by zero issue. The code computes elements_per_cta using bits_per_element before checking if bits_per_element is 0. While there's a subsequent check at lines 127-129, the division occurs first, which could cause undefined behavior if bits_per_element is 0.

    const int64_t bits_per_element = getInputBitsPerElement(prop);
    if (bits_per_element == 0) {
      return nullptr;
    }
Silent Failure in TMA Compatibility Check

    In mayHaveTmaCompatibleInputs() at lines 356-359, when elem_count fails the TMA compatibility checks, the function silently continues to the next iteration without any logging or error indication. This makes debugging TMA scheduling issues difficult, as there's no visibility into why TMA was not used.

    if (elem_count % tma_domain_inner_min != 0 ||
        elem_count == tma_domain_inner_min) {
      continue;
    }
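
For illustration, a hypothetical way to surface the skip reason (debugLog is a made-up helper for this sketch, not an existing nvFuser facility):

    // Hypothetical instrumentation (debugLog is not a real nvFuser API):
    if (elem_count % tma_domain_inner_min != 0 ||
        elem_count == tma_domain_inner_min) {
      debugLog(
          "TMA not used: elem_count ", elem_count,
          " is incompatible with tma_domain_inner_min ", tma_domain_inner_min);
      continue;
    }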
Incomplete TMA Store Implementation

    The PR sets use_tma_store = false at line 219 and doesn't implement TMA store functionality, but the documentation and test framework suggest TMA store support was intended. This creates a disconnect between the API and implementation, and tests may be expecting TMA store functionality that doesn't exist.

    params->use_tma_store = false;

@liqiangxl liqiangxl changed the base branch from main to llu/pt2_utils November 18, 2025 15:54
@liqiangxl liqiangxl marked this pull request as ready for review November 18, 2025 16:05

@liqiangxl (Collaborator, Author) commented:

!test

@liqiangxl liqiangxl requested review from naoyam and tbqh November 18, 2025 16:06
greptile-apps bot commented Nov 18, 2025

Greptile Summary

    This PR adds a TMA (Tensor Memory Accelerator) pointwise scheduler for Hopper and newer GPUs. The implementation includes preliminary TMA compatibility checks, heuristic computation for determining TMA domain and tile sizes, and a complete scheduling pipeline that transforms tensors through a 2D TMA domain hierarchy.

    Key changes:

    • Added mayUseTma() check requiring GPU SM >= 9.0 and TMA-compatible inputs
    • Implemented heuristics computing TMA domain split (targeting ~512 inner elements) and tile sizes based on memory bandwidth requirements
    • Created scheduling logic that handles both TMA tensors (using Bulk parallelization) and non-TMA tensors (using thread parallelization)
    • Renamed getInnerTmaDomainSize() to getTmaDomainInner() with parameter changes from bytes to bits
    • Added comprehensive test coverage for various input sizes and dimensionalities
    • Includes proper fallback to non-TMA scheduler when TMA is not applicable

    Limitations (as documented):

    • Does not yet handle broadcasts, reshapes, contiguity checks, or allocation domains
    • Only supports tensors with matching dimensionality to the reference tensor

    Confidence Score: 4/5

    • This PR is safe to merge with the understanding that it's a basic TMA scheduler implementation
    • The implementation is well-documented and includes proper error handling, fallback mechanisms, and comprehensive tests. Previous critical issues (division by zero, tma_tile_outer=0) have been flagged in comments. The code acknowledges its limitations regarding broadcasts, reshapes, and contiguity. Confidence reduced from 5 to 4 only because those previous logic issues in comments should be verified as resolved.
    • csrc/scheduler/pointwise_tma.cpp - Verify that all previously flagged logic issues have been resolved

Important Files Changed

File Analysis

csrc/scheduler/pointwise_tma.cpp (score 4/5): Implements the TMA pointwise scheduler with heuristics and scheduling logic. Previous syntax issues (typos) and logic issues (division by zero, tma_tile_outer=0) have been addressed in comments.
csrc/scheduler/pointwise.cpp (score 5/5): Adds TMA compatibility checks and fallback logic. Clean integration with the existing pointwise scheduler.
csrc/scheduler/utils.cpp (score 5/5): Renamed function with improved documentation and changed bytes to bits for consistency. Logic remains sound.
tests/cpp/test_pointwise.cpp (score 5/5): Comprehensive test suite for the TMA scheduler, including both manual and auto-schedule paths with extensive parameter combinations.

    Sequence Diagram

    sequenceDiagram
        participant User
        participant PointwiseScheduler
        participant TMACheck as mayUseTma()
        participant TMAHeuristics as getPointwiseHeuristics()
        participant TMASchedule as schedulePointwise()
        participant NonTMA as non_tma scheduler
    
        User->>PointwiseScheduler: computeHeuristics()
        PointwiseScheduler->>TMACheck: Check TMA compatibility
        TMACheck->>TMACheck: GPU >= Hopper (SM 9.0)?
        TMACheck->>TMACheck: mayHaveTmaCompatibleInputs()?
        alt TMA Compatible
            TMACheck-->>PointwiseScheduler: true
            PointwiseScheduler->>TMAHeuristics: Compute TMA heuristics
            TMAHeuristics->>TMAHeuristics: Compute TMA domain [outer, inner]
            TMAHeuristics->>TMAHeuristics: Calculate elements_per_cta
            TMAHeuristics->>TMAHeuristics: Determine TMA tile size [to, ti]
            TMAHeuristics->>TMAHeuristics: Configure thread blocks (bdimx, bdimy)
            TMAHeuristics->>TMAHeuristics: Set vectorization factor
            TMAHeuristics-->>PointwiseScheduler: PointwiseParams (TMA)
            alt TMA params valid
                PointwiseScheduler->>User: Return TMA params
                User->>TMASchedule: schedule(fusion, params)
                TMASchedule->>TMASchedule: Flatten and merge dims
                TMASchedule->>TMASchedule: Split into TMA domain [Do, Di]
                TMASchedule->>TMASchedule: Split into tiles [Do/to, to, Di/ti, ti]
                TMASchedule->>TMASchedule: Parallelize TMA tensors (Bulk)
                TMASchedule->>TMASchedule: Parallelize non-TMA tensors (TIDx, TIDy)
                TMASchedule->>TMASchedule: Apply vectorization
                TMASchedule-->>User: Scheduled fusion
            else TMA params invalid (bits_per_element=0)
                TMAHeuristics-->>PointwiseScheduler: nullptr
                PointwiseScheduler->>NonTMA: Fallback to non-TMA
                NonTMA-->>PointwiseScheduler: Non-TMA params
                PointwiseScheduler->>User: Return non-TMA params
            end
        else Not TMA Compatible
            TMACheck-->>PointwiseScheduler: false
            PointwiseScheduler->>NonTMA: Use non-TMA scheduler
            NonTMA-->>PointwiseScheduler: Non-TMA params
            PointwiseScheduler->>User: Return non-TMA params
        end
    

@greptile-apps bot left a review: 4 files reviewed, 6 comments

    prop.min_dtype_size_bit_for_vectorization);
    NVF_ERROR(
    tma_domain_inner > 1 && prop.n_elems % tma_domain_inner == 0,
    "Ilegal TMA inner domain size: ",

greptile-apps bot:

    syntax: typo: "Ilegal" should be "Illegal"

Suggested change:
- "Ilegal TMA inner domain size: ",
+ "Illegal TMA inner domain size: ",

    // elements of this TMA load operation. Assuming 8 CTAs per SM, using
    // empirical required bits in flight, it is just a guidance, actual tile size
    // is set by tma_tile_inner and tma_tile_outer.
    // Inner tila size: ensure we have at least 2 tiles in the inner TMA

greptile-apps bot:

    syntax: typo: "tila" should be "tile"

Suggested change:
- // Inner tila size: ensure we have at least 2 tiles in the inner TMA
+ // Inner tile size: ensure we have at least 2 tiles in the inner TMA

    continue;
    }
    // function-condition-2, We only do 2D TMA, requires at least 2 boxes in
    // inner dimension each with 16 bytes. This requires a minimum innter tma

greptile-apps bot:

    syntax: typo: "innter" should be "inner"

Suggested change:
- // inner dimension each with 16 bytes. This requires a minimum innter tma
+ // inner dimension each with 16 bytes. This requires a minimum inner tma
  // domain size of 2 * 16 bytes. We also should skip if the inner tma domain

    // function-condition-3, reshape, contiguity, allocation domain, etc.
    // TODO: performance checks
    // performance-condition-1, input size is too small
    // performance-condition-2, Innner TMA domain size is too small

greptile-apps bot:

    syntax: typo: "Innner" should be "Inner"

Suggested change:
- // performance-condition-2, Innner TMA domain size is too small
+ // performance-condition-2, Inner TMA domain size is too small

    // check does not guarantee that TMA will be used, as the actual TMA usage will
    // be determined by the heuristics.
    bool mayUseTma(const pointwise_utils::FusionRuntimeProperties& prop) {
    // Harware, Don't use tma for pre-Blackwell GPUs

greptile-apps bot:

    syntax: typo: "Harware" should be "Hardware"

Suggested change:
- // Harware, Don't use tma for pre-Blackwell GPUs
+ // Hardware, Don't use tma for pre-Blackwell GPUs

    Comment on lines 83 to 84
    int64_t bits_per_element = getInputBitsPerElement(prop);
    int64_t elements_per_cta = ceilDiv(bits_per_cta, bits_per_element);

greptile-apps bot:

    logic: potential division by zero if no TMA-compatible inputs exist - bits_per_element could be 0, causing division by zero in ceilDiv

Suggested change:
- int64_t bits_per_element = getInputBitsPerElement(prop);
- int64_t elements_per_cta = ceilDiv(bits_per_cta, bits_per_element);
+ int64_t bits_per_element = getInputBitsPerElement(prop);
+ NVF_ERROR(bits_per_element > 0, "No TMA-compatible inputs found");
+ int64_t elements_per_cta = ceilDiv(bits_per_cta, bits_per_element);

@liqiangxl (Collaborator, Author) commented:

    !test

@greptile-apps bot left a review: 4 files reviewed, 1 comment

    Comment on lines 94 to 95
    int64_t tma_tile_outer =
    std::min(elements_per_cta / tma_tile_inner, max_tma_tile_outer);

greptile-apps bot:

    logic: tma_tile_outer could be 0 if elements_per_cta < tma_tile_inner (integer division rounds down)

Suggested change:
- int64_t tma_tile_outer =
-     std::min(elements_per_cta / tma_tile_inner, max_tma_tile_outer);
+ int64_t tma_tile_outer =
+     std::max(1L, std::min(elements_per_cta / tma_tile_inner, max_tma_tile_outer));

    auto total_bits = elem_count * dtype_bits;

    // Condition 1: TMA requires size divisible by 16 bytes (128 bits)
    if (total_bits % 128 != 0) {

Collaborator:

    Doesn't the multi-dimensional TMA support predication?


@liqiangxl (Author):

Yes, it supports predication.
The comment is not clear. The first step is creating a 2D TMA domain, which requires a divisible split and an inner dim that is a multiple of 16 bytes. This condition can be merged with the following condition-2.
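
A sketch of that merged condition, based on the explanation above (illustrative names, not the PR's exact mayHaveTmaCompatibleInputs(); assumes dtype_bits divides 128):

    // Sketch: a 2D TMA domain needs a divisible split whose inner
    // dimension holds at least 2 boxes of 16 bytes (128 bits) each.
    bool elemCountSupports2dTmaDomain(int64_t elem_count, int64_t dtype_bits) {
      // Minimum inner TMA domain: 2 boxes x 128 bits each, in elements.
      const int64_t tma_domain_inner_min = 2 * 128 / dtype_bits;
      // The split must be divisible, and the outer dimension must be > 1,
      // otherwise the domain degenerates to 1D.
      return elem_count % tma_domain_inner_min == 0 &&
          elem_count > tma_domain_inner_min;
    }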

@greptile-apps bot left a review: 4 files reviewed, no comments

@greptile-apps bot left a review: 4 files reviewed, 2 comments

    Comment on lines +378 to +380
    // Hardware requirement: Don't use TMA for pre-Hopper GPUs
    if (at::cuda::getCurrentDeviceProperties()->major < 9) {
    return false;

greptile-apps bot:

    logic: PR description says "Blackwell GPUs" but code checks major < 9 which enables TMA for Hopper (SM 9.0) and above, not Blackwell (SM 10.0) and above. Check the intended GPU requirement.

    Comment on lines 98 to 99
    int64_t tma_tile_outer =
    std::min(elements_per_cta / tma_tile_inner, max_tma_tile_outer);

greptile-apps bot:

    logic: tma_tile_outer will be 0 if elements_per_cta < tma_tile_inner (integer division rounds down), causing split(0) failure at line 214 and potentially bdimy=0 at line 108.

Suggested change:
- int64_t tma_tile_outer =
-     std::min(elements_per_cta / tma_tile_inner, max_tma_tile_outer);
+ int64_t tma_tile_outer = std::max(
+     1L, std::min(elements_per_cta / tma_tile_inner, max_tma_tile_outer));

@liqiangxl (Collaborator, Author) commented:

    !test

@greptile-apps bot left a review: 4 files reviewed, 1 comment

    Base automatically changed from llu/pt2_utils to main November 19, 2025 00:58
    liqiangxl and others added 2 commits November 18, 2025 19:59
    Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
@greptile-apps bot left a review: 4 files reviewed, no comments

    const auto& prop = prop_opt.value();

  - bool use_tma = false;
  + bool use_tma = mayUseTma(prop) && isOptionEnabled(EnableOption::TmaPointwise);

Collaborator:

    Why not structure it in the following way instead:

    Move the isOptionEnabled() check inside mayUseTma():

    if (!isOptionEnabled(EnableOption::TmaPointwise)) {
      return false;
    }
    

    Then in this function, you only need to write:

    if (mayUseTma(prop)) {
      pparams = pointwise::tma::getPointwiseHeuristics(
          fusion, runtime_info, data_cache, prop);
    }
    

    This is slightly more concise. But more importantly, it moves the isOptionEnabled() check next to the getCurrentDeviceProperties() check, which seems similar.


@liqiangxl (Author):

Thanks for the suggestion. I prefer keeping the current approach because of the following reason:
mayUseTma(prop) says we can use the tma-version; isOptionEnabled(EnableOption::TmaPointwise) says the user wants to use the tma-version. So they are different things.

In the future, bool use_tma = mayUseTma(prop) && isOptionEnabled(EnableOption::TmaPointwise); will be revised to bool use_tma = mayUseTma(prop) || isOptionEnabled(EnableOption::TmaPointwise); clearly passing the msg that the tma version will be used because of one of the following conditions:

• (1) heuristic auto selection mayUseTma(prop)
• (2) the user selects to use it

    const pointwise_utils::FusionRuntimeProperties& prop) {
    // Hardware constants
    constexpr int64_t threads_per_warp = 32;
    constexpr int64_t max_size_per_tma_tile_dim = 256;

Collaborator:

    What does this represent?


@liqiangxl (Author):

256 is the maximum number of elements per TMA dimension. Needs a better var name?


Collaborator:

    Ah, I see. It wasn't immediately clear to me what the TMA tile dim means.

    Comment on lines 90 to 91
    int64_t max_tma_tile_inner =
    std::min(tma_domain_inner / 2, max_size_per_tma_tile_dim);

Collaborator:

    What is this? Why is tma_domain_inner divided by 2?


@liqiangxl (Author):

This is similar to what we did in getInnerTmaDomainSize: we want to keep at least 2 boxes/tiles in the inner dimension to avoid creating two contiguous bulk dimensions, which would be merged and could lead to a bulk size larger than 256.


Collaborator:

    I realized you have this in the above:

    // - Inner tile size: ensure at least 2 tiles in the inner TMA dimension

    I still easily get lost with various TMA "XYZ", e.g., tiles, dimensions, domains, boxes, etc. I'd appreciate if you could be more strict with the naming. For example:

    // - Inner tile size: ensure at least 2 tiles in the inner TMA domain


@liqiangxl (Author):

    I got your point. Let me revise these var names and comments.

    Comment on lines 94 to 97
int64_t tma_tile_inner = std::min(tma_domain_inner / 2, threads_per_warp);
while (tma_tile_inner * 2 <= max_tma_tile_inner) {
  tma_tile_inner *= 2;
}

Collaborator:

    I'm lost here again. Can you please give a comment to each variable here?


@liqiangxl (Author):

    Revised variable names and comments:

    Variable Name Changes Summary

    Function Names

    • getInnerTmaDomainSize()getTmaDomainInner()
      • More concise, matches the variable it computes

    Function Parameters

    • target_inner_tma_domain_sizetma_domain_inner_target

      • Reordered to match pattern: <variable>_<modifier>
    • min_dtype_bytesmin_dtype_bits

      • Changed from bytes to bits for consistency with codebase
      • Default value: 1 byte → 8 bits

    Local Variables (in pointwise_tma.cpp)

    • max_tma_tile_innertma_tile_inner_max
    • max_tma_tile_outertma_tile_outer_max
      • Reordered for consistency: <variable>_max pattern

    Terminology Standardization

    Throughout all files, consistently use:

    • tma_domain_outer and tma_domain_inner (not OuterTmaDomain, Do, or tma_outer_domain_size)
    • tma_tile_outer and tma_tile_inner (not to, ti, or abbreviated forms)

    Rationale: All changes improve consistency by following the pattern <concept>_<dimension>_<modifier> (e.g., tma_tile_inner_max), making the code more readable and searchable.
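
As a worked example of the loop quoted above (numbers chosen for illustration): with tma_domain_inner = 512 and max_size_per_tma_tile_dim = 256, tma_tile_inner starts at min(512/2, 32) = 32, tma_tile_inner_max is min(512/2, 256) = 256, and the loop doubles 32 -> 64 -> 128 -> 256. For the small 17x16 input mentioned later in this thread, tma_domain_inner = 16, so tma_tile_inner starts and stays at min(16/2, 32) = 8, keeping 2 tiles in the inner TMA domain.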

@greptile-apps bot left a review: 8 files reviewed, no comments

@liqiangxl (Collaborator, Author) commented:

    !test

@greptile-apps bot left a review: 8 files reviewed, no comments

Comment on lines +24 to +52

// ============================================================================
// TMA TERMINOLOGY GUIDE:
// ============================================================================
// The TMA scheduler uses a 2-level hierarchy to organize data:
//
// 1. TMA DOMAIN: The logical split of the entire problem space into two parts:
//    - tma_domain_inner: Size of the inner (contiguous) dimension
//    - tma_domain_outer: Size of the outer dimension
//      (n_elems / tma_domain_inner)
//    - Requirement: n_elems % tma_domain_inner == 0
//    Example: For 2048 elements with tma_domain_inner=512:
//      TMA domain structure: [4, 512]
//
// 2. TMA TILE: The actual hardware tile size for each TMA load operation
//    - tma_tile_inner: Number of elements along the inner dimension per tile
//    - tma_tile_outer: Number of elements along the outer dimension per tile
//    Example: For TMA domain [4, 512] with tiles [2, 128]:
//      Each TMA load fetches a [2 x 128] tile, requiring 2 x 4 = 8 tiles total
//
// Note: In general TMA terminology, a "box" is the dense rectangular region
// loaded from global memory, while a "tile" is a potentially strided subset
// selected from that box. The pointwise scheduler always uses dense tiles
// (tile = box), so we use "TMA tile" throughout to refer to both concepts.
//
// Transformation sequence: logical domain -> TMA domain -> TMA tile
//   [I0, I1, ...] -> [I] -> [Do, Di] -> [Do/to, to, Di/ti, ti]
// where Do=tma_domain_outer, Di=tma_domain_inner,
//       to=tma_tile_outer, ti=tma_tile_inner
// ============================================================================

Collaborator:

    Thanks. This is super helpful!
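
Read as nvFuser scheduling calls, the transformation sequence above corresponds roughly to the following sketch. The grid/Bulk axis mapping shown is an assumption for illustration, not the PR's actual schedulePointwise():

    // Illustrative nvFuser scheduling sketch (assumed mapping).
    // Start: tv has logical domain [I0, I1]; all splits assumed divisible.
    tv->merge(0);                    // [I0, I1] -> [I]
    tv->split(0, tma_domain_inner);  // [I] -> [Do, Di]
    tv->split(1, tma_tile_inner);    // [Do, Di] -> [Do, Di/ti, ti]
    tv->split(0, tma_tile_outer);    // -> [Do/to, to, Di/ti, ti]
    // For TMA tensors, the tile axes become Bulk; mapping the tile-grid
    // axes to the block grid is an assumption here.
    tv->axis(0)->parallelize(ParallelType::BIDy); // Do/to
    tv->axis(1)->parallelize(ParallelType::Bulk); // to
    tv->axis(2)->parallelize(ParallelType::BIDx); // Di/ti
    tv->axis(3)->parallelize(ParallelType::Bulk); // ti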

    params->use_tma_load = true;
    NVF_THROW("Schedule pointwise using TMA");

    // ========== Step 1: Compute TMA Domain Dimensions ==========

Collaborator:

    This also helps a lot. Thanks!

    Comment on lines 172 to 173
    // process each TMA tile. Threads cooperate to move data from shared memory
    // to registers and perform computation.

Collaborator:

    I appreciate detailed comments, but this is nothing specific to TMA and, yeah, you can assume everyone involved in nvFuser would know what shared memory is.


Collaborator:

    Instead, could you explain the relationship between TMA tiles and thread blocks? For example, I'm not sure why blockDim.x is the minimum of 32 and the inner tile size. tma_tile_inner is a multiple of 32 per lines 157-159, so it seems bdimx is always 32, right? What do you intend to do with this min expression?


@liqiangxl (Author) commented Nov 21, 2025:

(1) tma_tile_inner can be smaller than 32, e.g. when the input size is 17x16: then tma_domain_inner = 16 and we will use tma_tile_inner = 16/2 = 8, so bdimx = 8. It doesn't make sense to use more than 8 threads to process 8 elements. For most cases it will be a multiple of 32 and bdimx = 32.

(2) Using bdimx = 32 has several benefits:
(a) Avoid bank conflicts, e.g. assume tma_tile_inner = 32: if bdimx = 16, there is a 2-way conflict.
[image: shared-memory bank conflict illustration]

(b) Why not use more than 32? We want to leave space for vectorization. The largest tma_tile_inner is 256. To process these 256 elements and write to gmem, using 32 threads leads to each thread processing 8 elements, which can then be written to global memory with vectorized stores.


@liqiangxl (Author):

We can use vectorized loads from smem to registers to avoid bank conflicts when bdimx = 8/16; this parameter can be tuned.


Collaborator:

    Thanks. Please add this to the code as a comment.
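
A sketch of what that code comment plus the computation might look like (illustrative, not the committed code):

    // Sketch of the thread-block heuristic, with the explanation above
    // folded into comments (illustrative; not the PR's committed code).
    // bdimx = min(32, tma_tile_inner):
    // - tma_tile_inner is usually a multiple of 32, so bdimx is usually 32.
    // - For tiny inputs (e.g. 17x16 -> tma_domain_inner=16, tma_tile_inner=8),
    //   more than tma_tile_inner threads in x would be wasted, so bdimx=8.
    // - Capping at 32 avoids smem bank conflicts (bdimx=16 with a 32-wide
    //   tile gives a 2-way conflict) and leaves room for vectorization: with
    //   the largest tile (256 elements), 32 threads process 8 elements each,
    //   which can be written back as vectorized global stores.
    int64_t bdimx = std::min<int64_t>(32, tma_tile_inner);
    int64_t bdimy = 4; // heuristic: 4 threads in the y dimension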

@greptile-apps bot left a review: 8 files reviewed, 1 comment

    Comment on lines 130 to 131
    const int64_t elements_per_cta = scheduler_utils::roundUpToN(ceilDiv(bits_per_cta, bits_per_element), 1024);
    elements_per_cta = scheduler_utils::roundUpToN(elements_per_cta, 1024);

greptile-apps bot:

    syntax: compilation error: elements_per_cta is declared as const on line 130, but line 131 attempts to reassign it. Remove line 131 since the value is already rounded up to 1024 on line 130.

Suggested change:
- const int64_t elements_per_cta = scheduler_utils::roundUpToN(ceilDiv(bits_per_cta, bits_per_element), 1024);
- elements_per_cta = scheduler_utils::roundUpToN(elements_per_cta, 1024);
+ const int64_t elements_per_cta = scheduler_utils::roundUpToN(ceilDiv(bits_per_cta, bits_per_element), 1024);

    params->lparams.bindUnsafe(bdimx, ParallelType::TIDx);
    params->lparams.bindUnsafe(bdimy, ParallelType::TIDy);

    // ========== Step 5: Determine Vectorization Factor ==========

Collaborator:

    Please comment why vectorization not TMA and what tensors are vectorized.

    IIUC, we are not using the usual vectorization analysis? Why?


@liqiangxl (Author) commented Nov 22, 2025:

Will revise comments.

1. It's a heuristic parameter; we can do either a vectorized store or a TMA store or both (one output using TMA, the other using vectorized stores).
2. No further analysis is needed since the tma tile inner size is already selected to be a multiple of 16 bytes.
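
Concretely, the factor could be derived as in this sketch, assuming (as stated above) that tma_tile_inner is a multiple of 16 bytes; identifiers are illustrative:

    // Sketch: because tma_tile_inner is chosen to be a multiple of 16 bytes,
    // a 128-bit vectorized store is always legal and no per-tensor
    // vectorization analysis is needed.
    constexpr int64_t kBitsPerVector = 128; // 16-byte stores
    const int64_t elements_per_thread = tma_tile_inner / bdimx;
    const int64_t vectorization_factor =
        std::min(kBitsPerVector / output_dtype_bits, elements_per_thread);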

@liqiangxl (Collaborator, Author) commented:

    !test

@greptile-apps bot left a review: 8 files reviewed, no comments
