@liqiangxl liqiangxl commented Jan 6, 2026

Behavior before this PR:
The inner-outer persistent scheduler always prioritizes shared memory for storing persistent buffers.

Issue:
When the cached inputs are not persistent buffers, the data flow is gmem (input) --> regs (cached input) --> smem (persistent buffer) --> regs (for computations). The regs --> smem --> regs round trip is redundant.

Using shared memory for persistent buffers only makes sense when the cached inputs are themselves persistent buffers, in which case the data flow is: gmem --(TMA or CpAsync)--> smem (persistent buffer) --> regs (for computations).

After this PR:
Shared memory is prioritized for persistent buffers only when the cached inputs are persistent buffers.
With this change, the newly added test improves from 50.5% SOL to 61.6% SOL on GB200.
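
The data flows described above can be compared with a small standalone model. This is only an illustrative sketch: `dataPath` and `numCopies` are hypothetical helpers, not nvFuser APIs, and the model counts the arrows in each path as written in the description, not actual hardware instructions or latency.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical helper, not an nvFuser API: returns the memory spaces a value
// passes through on its way from a fusion input to the computation.
//   cached_inputs_are_persistent: the cached input is itself the persistent buffer
//   always_prefer_smem: true models the behavior before this PR
std::vector<std::string> dataPath(
    bool cached_inputs_are_persistent,
    bool always_prefer_smem) {
  if (cached_inputs_are_persistent) {
    // TMA or CpAsync copies straight into the shared-memory persistent buffer.
    return {"gmem", "smem", "regs"};
  }
  if (always_prefer_smem) {
    // Old behavior: the cached input is staged in registers, written to a
    // shared-memory persistent buffer, then read back into registers.
    return {"gmem", "regs", "smem", "regs"};
  }
  // New behavior: the persistent buffer stays in registers.
  return {"gmem", "regs", "regs"};
}

// Number of arrows (copies between stages) along a path.
std::size_t numCopies(const std::vector<std::string>& path) {
  assert(!path.empty());
  return path.size() - 1;
}
```

In this model the path for non-persistent cached inputs shrinks from three copies to two, while the path for persistent cached inputs is the same before and after the change.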

@liqiangxl
Collaborator Author

!test


github-actions bot commented Jan 6, 2026

Review updated until commit a0797d3

Description

  • Optimize persistent buffer storage by prioritizing shared memory only when cached inputs are persistent

  • Avoid inefficient data path: gmem → regs → smem → regs when cached inputs aren't persistent

  • Enable efficient CpAsync/TMA direct copies from global memory to shared memory for persistent cached inputs

  • Add test case demonstrating an 11.1 percentage-point SOL improvement (50.5% → 61.6% on GB200)
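
The 11.1 figure is the difference between the two SOL percentages. A quick check of the arithmetic is sketched below; `solGainPoints` and `relativeGain` are hypothetical helper names, and the relative figure assumes both percentages are measured against the same peak, which the PR does not state explicitly.

```cpp
#include <cassert>

// Difference between two SOL percentages, in percentage points.
double solGainPoints(double before_pct, double after_pct) {
  return after_pct - before_pct;
}

// Relative change in achieved throughput, assuming both percentages are
// measured against the same peak (an assumption, not stated in the PR).
double relativeGain(double before_pct, double after_pct) {
  assert(before_pct > 0.0);
  return after_pct / before_pct - 1.0;
}
```

For 50.5 and 61.6 this gives about 11.1 points, or roughly a 22% relative gain under the stated assumption.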

Changes walkthrough

Relevant files

Enhancement
normalization_inner_outer_utils.cpp: Conditional shared memory allocation for persistent buffers
csrc/scheduler/normalization_inner_outer_utils.cpp (+20/-2)

  • Added tracking of total register buffer size alongside shared memory size
  • Implemented logic to detect whether cached inputs are persistent (project_to_input, or all buffers are fusion inputs)
  • Modified buffer allocation to prefer shared memory only for persistent cached inputs
  • Updated the smem_persistent_buffers assignment to be conditional on cached input persistence

Tests
test_combined_inner_outer_reduction.cpp: Test case for non-persistent cached inputs optimization
tests/cpp/test_combined_inner_outer_reduction.cpp (+31/-0)

  • Added test case CachedInputsAreNotPersistentFusedReshape to validate the new behavior
  • The test creates a fusion with reshape operations and validates scheduler performance
  • Documents the expected performance improvement from 50.5% to 61.6% SOL on GB200

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    🔒 No security concerns identified
    ⚡ Recommended focus areas for review
    Logic correctness

    The logic for determining cached_inputs_are_persistent uses std::all_of to check if all buffers are fusion inputs. This might be too restrictive - there could be cases where some buffers are fusion inputs and others are intermediate tensors that should still be treated as persistent. Consider if this condition accurately captures all scenarios where shared memory should be preferred.

    std::all_of(buffers.begin(), buffers.end(), [](TensorView* tv) {
      return tv->isFusionInput();
    });
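
To make the concern concrete, the sketch below shows how `std::all_of` behaves on a mixed buffer list and on an empty one. `MockTensorView` is a hypothetical stand-in for nvFuser's `TensorView`, and `anyIsFusionInput` is included only for comparison; this demonstrates standard-library semantics and is not a proposed change to the PR.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical stand-in for nvFuser's TensorView; only one query is modeled.
struct MockTensorView {
  bool is_fusion_input;
  bool isFusionInput() const { return is_fusion_input; }
};

// Mirrors the predicate under review: true only if every buffer is a fusion
// input. Note that std::all_of returns true for an empty range.
bool allAreFusionInputs(const std::vector<MockTensorView>& buffers) {
  return std::all_of(
      buffers.begin(), buffers.end(), [](const MockTensorView& tv) {
        return tv.isFusionInput();
      });
}

// Looser alternative, shown for comparison only: true if at least one buffer
// is a fusion input. std::any_of returns false for an empty range.
bool anyIsFusionInput(const std::vector<MockTensorView>& buffers) {
  return std::any_of(
      buffers.begin(), buffers.end(), [](const MockTensorView& tv) {
        return tv.isFusionInput();
      });
}
```

With one fusion input and one intermediate tensor, `allAreFusionInputs` is false while `anyIsFusionInput` is true, which is the mixed case the reviewer note asks about. An empty buffer list satisfies `allAreFusionInputs`, so callers may want to consider that edge case as well.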
    Performance validation

    While the performance improvement from 50.5% to 61.6% SOL is significant, it would be valuable to validate this improvement across different GPU architectures and tensor sizes to ensure the optimization is robust and not specific to the GB200/test case.

    // Prefer shared memory persistent buffers when cached inputs are persistent,
    // enabling efficient CpAsync or TMA direct copies from global memory to
    // shared memory. Otherwise, register buffers are preferred to avoid the
    // inefficient data path: gmem (input) → regs (cached input) → smem
    // (persistent buffer) → regs (computation), where the regs → smem → regs path
    // is redundant.
    bool cached_inputs_are_persistent = buffer_params.project_to_input ||
        std::all_of(buffers.begin(), buffers.end(), [](TensorView* tv) {
          return tv->isFusionInput();
        });
    if (cached_inputs_are_persistent) {
      buffer_params.smem_buffer_size_bit += total_smem_buffer_size_bit;
    } else {
      buffer_params.regs_buffer_size_bit += total_regs_buffer_size_bit;
    }
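
The accounting in the snippet above can be exercised in isolation with mock types. In this sketch `MockBuffer`, `MockBufferParams`, and `assignInitialStorage` are hypothetical; the field names follow the snippet, but the per-buffer sizes and the way the totals are summed are assumptions made for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a persistent buffer and its footprint in each
// memory space (sizes in bits, as in the snippet's field names).
struct MockBuffer {
  bool is_fusion_input;
  std::int64_t smem_size_bit;
  std::int64_t regs_size_bit;
};

// Hypothetical stand-in for the scheduler's buffer_params.
struct MockBufferParams {
  bool project_to_input = false;
  std::int64_t smem_buffer_size_bit = 0;
  std::int64_t regs_buffer_size_bit = 0;
};

// Charges all buffers to shared memory when cached inputs are persistent,
// otherwise to registers. Returns the computed flag.
bool assignInitialStorage(
    MockBufferParams& params,
    const std::vector<MockBuffer>& buffers) {
  std::int64_t total_smem_buffer_size_bit = 0;
  std::int64_t total_regs_buffer_size_bit = 0;
  for (const MockBuffer& b : buffers) {
    total_smem_buffer_size_bit += b.smem_size_bit;
    total_regs_buffer_size_bit += b.regs_size_bit;
  }
  const bool cached_inputs_are_persistent = params.project_to_input ||
      std::all_of(buffers.begin(), buffers.end(), [](const MockBuffer& b) {
        return b.is_fusion_input;
      });
  if (cached_inputs_are_persistent) {
    params.smem_buffer_size_bit += total_smem_buffer_size_bit;
  } else {
    params.regs_buffer_size_bit += total_regs_buffer_size_bit;
  }
  return cached_inputs_are_persistent;
}
```

With one intermediate buffer in the list and `project_to_input` false, everything is charged to registers and the shared memory total stays at zero, which is the starting state the later review comments are concerned with.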

    Test failures

    • (High, 5) CUDA misaligned address in multidevice Transformer backward test (test_transformer_backward[SEQUENCE_PARALLEL])

      Test Name A100 GB200 GB200 (dist.) H100 H100 (dist.) Source
      tests.python.multidevice.test_transformer.test_transformer_backward[SEQUENCE_PARALLEL]
    • (High, 4) CUBLAS_STATUS_EXECUTION_FAILED in nvFuser transformer_backward (test_transformer_backward[TENSOR_PARALLEL]) across multiple runners

      Test Name GB200 GB200 (dist.) H100 H100 (dist.) Source
      tests.python.multidevice.test_transformer.test_transformer_backward[TENSOR_PARALLEL]
    • (High, 1) CUDA misaligned address in Flash Attention backward during multidevice transformer (A100, test_transformer_backward)

      Test Name A100 Source
      tests.python.multidevice.test_transformer.test_transformer_backward[TENSOR_PARALLEL]

    @greptile-apps greptile-apps bot left a comment

    1 file reviewed, 1 comment

    @liqiangxl liqiangxl force-pushed the llu/opt_inner_outer branch from 1631b53 to 8cfa421 Compare January 6, 2026 18:20
    @NVIDIA NVIDIA deleted a comment from greptile-apps bot Jan 6, 2026
    @liqiangxl
    Collaborator Author

    @greptileai review this PR

    @liqiangxl liqiangxl changed the title from "prioritze using smem only when project to inputs is true" to "prioritze using smem only when cached inputs are persistent" Jan 6, 2026

    greptile-apps bot commented Jan 6, 2026

    Greptile Summary

    Optimizes persistent buffer allocation by preferring registers over shared memory when cached inputs aren't persistent buffers, avoiding the inefficient regs → smem → regs data path. When cached inputs are persistent (fusion inputs or projectable), shared memory is still used to enable efficient CpAsync/TMA direct copies from global memory.

    Key Changes:

    • Added cached_inputs_are_persistent flag checking if project_to_input=true or all buffers are fusion inputs
    • Conditionally allocates persistent buffers to registers vs shared memory based on this flag
    • Test demonstrates 50.5% → 61.6% SOL performance improvement on GB200 for reshape+reduction patterns

    Critical Issue Found:

    • The fallback allocation logic (lines 269-315) assumes all buffers start in shared memory, but with this change they may start in registers when cached_inputs_are_persistent=false
    • This causes smem_buffer_size_bit to become negative when the fallback path subtracts buffer sizes from zero
    • The path is triggered when initial resource allocation exceeds limits
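
The failure mode described in the list above can be reproduced with a toy model. The actual fallback code is not shown in this conversation, so `MockBufferParams` and `moveToRegsUnchecked` are hypothetical, and the assumption that the fallback also adds the moved size to the register total is taken from the "move buffers from smem to regs" wording rather than from the source.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for the scheduler's buffer_params.
struct MockBufferParams {
  std::int64_t smem_buffer_size_bit = 0;
  std::int64_t regs_buffer_size_bit = 0;
};

// Models one fallback step as described in the review: move a buffer's
// accounting from shared memory to registers, with no check that the buffer
// was ever counted in shared memory.
void moveToRegsUnchecked(
    MockBufferParams& params,
    std::int64_t buffer_size_bit) {
  params.smem_buffer_size_bit -= buffer_size_bit;
  params.regs_buffer_size_bit += buffer_size_bit;
}
```

If a 4096-bit buffer was already charged to registers (the new non-persistent case) and the step above runs anyway, the shared memory total becomes -4096 and, in this model, the register total double counts the buffer.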

    Confidence Score: 2/5

    • This PR has a critical logic bug in the fallback allocation path that could produce incorrect memory calculations
    • The optimization logic is sound and demonstrates real performance improvements, but the fallback path at lines 269-315 doesn't account for the new case where buffers start in registers instead of shared memory. When cached_inputs_are_persistent=false and resources are constrained, the code will subtract from a zero smem_buffer_size_bit, producing negative values and incorrect allocation decisions. This bug may not surface in the test case if resources are sufficient, but could cause failures in constrained scenarios.
    • Pay close attention to csrc/scheduler/normalization_inner_outer_utils.cpp lines 269-315 (fallback allocation logic)

    Important Files Changed

    Filename | Overview
    csrc/scheduler/normalization_inner_outer_utils.cpp | Adds conditional logic to prefer registers over shared memory when cached inputs aren't persistent, but the fallback path at lines 269-315 assumes buffers start in smem, which breaks when cached_inputs_are_persistent=false
    tests/cpp/test_combined_inner_outer_reduction.cpp | Adds a test for the reshape+reduction scenario where cached inputs aren't persistent; validates the optimization with a good performance improvement

    Sequence Diagram

    sequenceDiagram
        participant Scheduler as Inner Outer Scheduler
        participant Utils as getPersistentBufferStorageParams
        participant Buffers as Persistent Buffers
        participant Memory as Memory Allocation
    
        Scheduler->>Utils: Request buffer storage params
        Utils->>Utils: Determine if cached_inputs_are_persistent
        alt cached_inputs_are_persistent=true (buffers are fusion inputs or project_to_input)
            Note over Utils: OLD & NEW: Use smem path
            Utils->>Buffers: Allocate to smem_buffer_size_bit
            Buffers-->>Memory: gmem → (CpAsync/TMA) → smem → regs
        else cached_inputs_are_persistent=false (NEW BEHAVIOR)
            Note over Utils: NEW: Use registers path
            Utils->>Buffers: Allocate to regs_buffer_size_bit
            Buffers-->>Memory: gmem → regs (cached) → regs (computation)
            Note over Memory: Avoids redundant regs→smem→regs path
        end
        
        alt Resources fit
            Utils->>Scheduler: Return with has_enough_regs_and_smem=true
        else Resources don't fit
            Note over Utils: ISSUE: Fallback assumes buffers in smem
            Utils->>Utils: Move buffers from smem to regs
            Note over Utils: BUG: When cached_inputs_are_persistent=false,<br/>smem_buffer_size_bit=0, subtraction produces negatives
            Utils->>Scheduler: Return allocation result
        end
    

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    !test

    @greptile-apps greptile-apps bot left a comment

    Additional Comments (1)

    1. csrc/scheduler/normalization_inner_outer_utils.cpp, line 269-284 (link)

      logic: The fallback logic assumes all buffers start in shared memory (line 275 comment), but when cached_inputs_are_persistent=false (line 246), buffers are placed in registers and smem_buffer_size_bit=0. Line 284 subtracts from smem_buffer_size_bit, which will produce negative values. Add an early return or skip this fallback path when !cached_inputs_are_persistent.

      Is there a hybrid allocation strategy needed when cached_inputs_are_persistent=false but resources are insufficient?
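
One possible shape for the suggested early return is sketched below. The reviewer's original code suggestion is not present in this conversation, so everything here is hypothetical: `MockBufferParams`, `runGuardedFallback`, the `smem_limit_bit` parameter, and the order in which buffers are moved are assumptions for illustration. The sketch does not address the hybrid allocation question.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the scheduler's buffer_params.
struct MockBufferParams {
  std::int64_t smem_buffer_size_bit = 0;
  std::int64_t regs_buffer_size_bit = 0;
};

// Moves buffers from shared memory to registers until the shared memory
// total fits within smem_limit_bit. Returns false without touching params
// when the buffers were never placed in shared memory.
bool runGuardedFallback(
    MockBufferParams& params,
    bool cached_inputs_are_persistent,
    const std::vector<std::int64_t>& buffer_sizes_bit,
    std::int64_t smem_limit_bit) {
  if (!cached_inputs_are_persistent) {
    // Buffers already live in registers; there is nothing to move.
    return false;
  }
  for (std::int64_t size_bit : buffer_sizes_bit) {
    if (params.smem_buffer_size_bit <= smem_limit_bit) {
      break;
    }
    // Guard against subtracting a buffer that was never counted in smem.
    assert(params.smem_buffer_size_bit >= size_bit);
    params.smem_buffer_size_bit -= size_bit;
    params.regs_buffer_size_bit += size_bit;
  }
  return true;
}
```

With this guard the shared memory total can no longer go negative in the non-persistent case, although the caller would still need a separate strategy when the register budget alone is insufficient.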

    2 files reviewed, 1 comment

    @liqiangxl liqiangxl marked this pull request as draft January 8, 2026 16:25