Conversation

@liqiangxl
Collaborator

Check broadcast interference with TMA loads

Case 1: Broadcast before TMA load (in gmem tensor logical domain)

  • Already captured by the existing TMA lowering pass
  • The scheduler should avoid merging broadcast and non-broadcast domains

Case 2: Broadcast after TMA load (downstream of TMA-loaded tensor)

  • New validation added in validateTMAConsumerBroadcasts()
  • Prevents broadcast dimensions from being merged with non-broadcast dimensions

Root Cause:
TMA auto-fills out-of-bounds accesses with zeros. When a broadcast dimension
is merged with a non-broadcast dimension and the result is loaded with TMA,
TMA treats the broadcast extent as a physical tile dimension and fills the
extra rows/columns with zeros, breaking broadcast semantics (broadcast should
replicate values, not fill zeros).
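
The failure mode can be illustrated with a tiny standalone sketch (plain Python, not nvfuser code; the shapes and values are made up). Broadcasting a 4-element tensor to two rows should replicate the data, but if the merged 8-element extent is handed to TMA as a physical tile dimension, the second half is out of bounds and gets zero-filled:

```python
gmem = [10, 20, 30, 40]  # gmem tensor with logical extent 4

# Correct broadcast semantics: the broadcast row replicates the data.
replicated = [gmem[i % 4] for i in range(8)]

# What happens when the merged extent is treated as a physical dimension:
# indices past the real extent are out of bounds, and TMA auto-fills
# out-of-bounds accesses with zeros.
tma_loaded = [gmem[i] if i < 4 else 0 for i in range(8)]

print(replicated)  # [10, 20, 30, 40, 10, 20, 30, 40]
print(tma_loaded)  # [10, 20, 30, 40, 0, 0, 0, 0]
```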

Testing:
See test BroadcastDownstreamOfTMALoad for detailed example showing the
numerical error when this restriction is violated.

@liqiangxl
Collaborator Author

!test

@github-actions

github-actions bot commented Nov 19, 2025

Review updated until commit de72855

Description

  • Added validation to prevent broadcast dimensions from interfering with TMA loads

  • Implemented validateTMAConsumerBroadcasts() function using IdModel graph analysis

  • Added MMA tracking via a has_mma_ flag to skip validation when MMA operations are present

  • Added comprehensive test cases covering broadcast interference scenarios

Changes walkthrough

Relevant files

Enhancement

tma.cpp: Add broadcast validation for TMA loads
csrc/device_lower/analysis/tma.cpp

  • Added validateTMAConsumerBroadcasts() function to check broadcast
    interference with TMA
  • Modified getTMAInfo() to call the validation when no MMA operations are
    present
  • Uses IdModel PERMISSIVE mode to detect broadcast-to-bulk dependencies
  • +78/-2

device_version.cpp: Track MMA operations in GpuLower
csrc/device_lower/analysis/device_version.cpp

  • Added GpuLower::current()->setHasMma(true) in handle(MmaOp* mma_op)
  • Tracks the presence of MMA operations for the TMA validation logic
  • +1/-0

lower2device.h: Add MMA tracking to the GpuLower class
csrc/device_lower/lower2device.h

  • Added has_mma_ boolean member variable
  • Added hasMma() getter and setHasMma(bool) setter methods
  • Provides MMA tracking capability for TMA validation
  • +11/-0

Tests

test_memory.cpp: Add comprehensive TMA broadcast interference tests
tests/cpp/test_memory.cpp

  • Added BroadcastInGmemAllocationDomain test case
  • Added BroadcastDownstreamOfTMALoad test case (2D tile scenario)
  • Added BroadcastDownstreamOfTMALoad1dTile test case (1D tile scenario)
  • Added BroadcastDownstreamOfTMALoad1dTileValid test case (valid scenario)
  • All tests verify proper error handling for broadcast interference
  • +271/-0

PR Reviewer Guide

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Broadcast validation logic

The new validateTMAConsumerBroadcasts() function uses PERMISSIVE mode to map broadcast to non-broadcast dimensions. We need to verify that this approach correctly identifies all problematic broadcast patterns without producing false positives or negatives. The ValGroup traversal logic should be thoroughly tested with various broadcast scenarios.

    void validateTMAConsumerBroadcasts(TensorView* smem_tv) {
      // Check if Bulk-parallelized dimensions (TMA tile dimensions) depend on
      // broadcast dimensions through transformations.
      //
      // Uses PERMISSIVE mode which maps broadcast to non-broadcast. We find
      // ValGroups containing broadcasts and check if Bulk-parallelized ValGroups
      // are reachable from them through transitive dependencies.
    
      // Build PERMISSIVE IdModel graph - maps broadcast to non-broadcast through
      // transformations
      IdModel id_model(smem_tv->fusion(), /*build_graphs=*/false);
      id_model.maybeBuildGraph(IdMappingMode::PERMISSIVE);
      const ValGraph& permissive_graph =
          id_model.idGraph(IdMappingMode::PERMISSIVE);
    
      // Collect ValGroups containing broadcast IDs
      ValGroups broadcast_groups;
      for (const ValGroup& val_group :
           permissive_graph.disjointValSets().disjointSets()) {
        bool has_broadcast =
            std::any_of(val_group->begin(), val_group->end(), [](Val* val) {
              return val->isA<IterDomain>() && val->as<IterDomain>()->isBroadcast();
            });
        if (has_broadcast) {
          broadcast_groups.pushBack(val_group);
        }
      }
    
      // Collect ValGroups containing Bulk-parallelized IDs
      ValGroups bulk_groups;
      for (const ValGroup& val_group :
           permissive_graph.disjointValSets().disjointSets()) {
        bool has_bulk =
            std::any_of(val_group->begin(), val_group->end(), [](Val* val) {
              return val->isA<IterDomain>() &&
                  val->as<IterDomain>()->getParallelType() == ParallelType::Bulk;
            });
        if (has_bulk) {
          bulk_groups.pushBack(val_group);
        }
      }
    
      // Check if any Bulk ValGroup is reachable from broadcast ValGroups
      // This captures both direct and transitive dependencies (broadcast ->
      // intermediate -> Bulk)
      auto reachable_bulk_groups = getReachableValsFrom<ValGraphBFS>(
          broadcast_groups.vector(),
          bulk_groups.vector(),
          Direction::Forward,
          permissive_graph);
    
      // If any Bulk group is reachable from broadcasts, it's an error
      if (!reachable_bulk_groups.empty()) {
        NVF_ERROR(
            false,
            "Broadcast may interfere with TMA loading of ",
            smem_tv->toString(),
            ". Bulk-parallelized dimensions are reachable from broadcast "
            "dimensions through transformations.");
      }
    }
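
    Conceptually, the check above is a forward reachability query over the ValGroup graph: raise an error if any Bulk-parallelized group is reachable from a broadcast group through transformations. A minimal standalone sketch of that idea (plain Python with hypothetical group names, not nvfuser code):

```python
from collections import deque

# Hypothetical group graph; edges follow transformations forward.
# b0 is a broadcast ValGroup, i0 a regular iteration ValGroup; both
# merge into m1, which is split to produce the Bulk-parallelized k2.
edges = {"b0": ["m1"], "i0": ["m1"], "m1": ["k2"]}
broadcast_groups = {"b0"}
bulk_groups = {"k2"}

def bulk_reachable_from_broadcast(edges, broadcast_groups, bulk_groups):
    # BFS forward from all broadcast groups; stop as soon as a
    # Bulk group is found (the error case in the validation pass).
    seen, queue = set(broadcast_groups), deque(broadcast_groups)
    while queue:
        group = queue.popleft()
        if group in bulk_groups:
            return True
        for nxt in edges.get(group, []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(nxt)
    return False

print(bulk_reachable_from_broadcast(edges, broadcast_groups, bulk_groups))  # True
```

    In the real pass, getReachableValsFrom<ValGraphBFS> plays the role of this BFS over the PERMISSIVE graph's disjoint sets.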
    MMA detection condition

    The broadcast validation is only performed when !GpuLower::current()->hasMma(). This suggests that MMA operations handle broadcasts differently. We need to verify that this condition is correct and that MMA operations do not suffer from the same broadcast interference issue.

    if (!GpuLower::current()->hasMma()) {
      validateTMAConsumerBroadcasts(smem_tv);
    }
    Test case validation

    The test BroadcastDownstreamOfTMALoad1dTileValid is marked as TODO and currently expects compilation to fail, but the comment suggests the schedule should be valid. This test case needs clarification and proper resolution to ensure the validation logic is working correctly.

      // TODO: This test is valid. See validateTMAConsumerBroadcasts for more
      // details.
      EXPECT_THAT(
          [&]() {
            KernelExecutor ke;
            ke.compile(&fusion, {t0, t1});
          },
          ::testing::ThrowsMessage<nvfuser::nvfError>(
              ::testing::HasSubstr("Broadcast may interfere with TMA loading")));
    }

    Test failures

    • (Medium, 267) NVFuser TMA broadcast assertion failures across TmaWarpSpecializedTest suites (on GB200 and H100):

      CombinedSchedulerTest.ThunderLayerNormBackward
      Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_0_M_128_N_1024_WarpSpecializedOnTIDyRegisterSharing_NoneStageSlicePosition_None_CpAsyncBulkTensorTile
      Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_0_M_500_N_1024_WarpSpecializedOnTIDyRegisterSharing_NoneStageSlicePosition_None_CpAsyncBulkTensorTile
      Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_1_M_128_N_2048_WarpSpecializedOnTIDyRegisterSharing_64_168StageSlicePosition_None_CpAsyncBulkTensorTile
      Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_neg1_M_500_N_1024_WarpSpecializedOnTIDyRegisterSharing_64_168StageSlicePosition_None_CpAsyncBulkTensorTile
      Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_neg2_M_1024_N_1024_WarpSpecializedOnTIDyRegisterSharing_64_168StageSlicePosition_None_CpAsyncBulkTensorTile
      Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_neg2_M_128_N_1024_PipelinedMBarrierForWAR_CpAsyncBulkTensorTile
      Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_neg2_M_128_N_2048_Pipelined_CpAsyncBulkTensorTile
      Hopper/TmaCircularBufferingTest.Matmul/stage_2_prefetch_neg2_M_500_N_2048_WarpSpecializedOnTIDyRegisterSharing_64_168StageSlicePosition_None_CpAsyncBulkTensorTile
      Hopper/TmaCircularBufferingTest.Matmul/stage_4_prefetch_0_M_1024_N_1024_PipelinedMBarrierForWAR_CpAsyncBulkTensorTile
      ... with 199 more test failures omitted. Check internal logs.

    • (Medium, 2) nvFuser internal TMA broadcast assertion failure in tests.python.test_schedule_ops (on GB200 and H100):

      tests.python.test_schedule_ops.TestScheduleOps.test_var_mean_tma_user_schedule

    • (Medium, 1) NVFP4 grouped_mm numerical mismatch in test_with_id_model_indexer (on GB200):

      tests.python.direct.test_with_id_model_indexer.test_layout_op_and_cutlass_nvfp4_grouped_mm[out_dtype=torch.bfloat16-tokens_per_expert_neg_one=[115, 144, 8]-config=[1024, 128, 256]]

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    !test

    @liqiangxl
    Collaborator Author

    !test
