
Align smem buffer for TMA store at 128B #4071

Draft · wants to merge 3 commits into base: main
Conversation

jacobhinkle (Collaborator) commented Mar 13, 2025

Previously, #3023 addressed 128B smem alignment for TMA loads. This PR tweaks that logic to also cover the producer smem buffer of a TMA store.

Fixes #3966


github-actions bot commented Mar 13, 2025

Review updated until commit 92ebf85

Description

  • Align smem buffer for TMA store at 128B

  • Add test for TMA store alignment

  • Fix test to ensure proper functionality


Changes walkthrough 📝

Relevant files
Enhancement
alias_memory.cpp
Update allocation info for CpAsyncBulk                                     

csrc/device_lower/pass/alias_memory.cpp

  • Update allocation info to consider CpAsyncBulk in uses
+4/-1     
Tests
test_matmul.cpp
Add and fix TMA store alignment test                                         

tests/cpp/test_matmul.cpp

  • Add test for TMA store alignment
  • Fix test to ensure proper functionality
+62/-0

    PR Reviewer Guide 🔍

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review

    Logic Verification

    Ensure that the new logic correctly identifies cases where is_cp_async_bulk should be set to true, especially considering the addition of std::any_of to check uses.

    (ir_utils::isCpAsyncBulk(tv->definition()) ||
     std::any_of(tv->uses().begin(), tv->uses().end(), [](Expr* expr) {
       return ir_utils::isCpAsyncBulk(expr);
     })));
    Test Coverage

    Verify that the new test case AlignTMAStore adequately covers the scenario where TMA store smem buffers need to be aligned, and that it fails without the PR changes.

    TEST_F(HopperMatmulTest, AlignTMAStore) {
      Fusion fusion;
      FusionGuard fg(&fusion);
    
      const auto dtype = DataType::Half;
    
      auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // M, K
      auto tv1 = makeContigConcreteTensor({1, -1, -1}, dtype); // K, N
      fusion.addInput(tv0);
      fusion.addInput(tv1);
    
      auto tv2 = fusedMultiplySum(tv0, tv1, {1});
    
      // Reorder the accumulator as [M, N, K]
      // [M, K, N] -> [M, N, K]
      tv2->reorder({{-2, -1}});
      tv2->commitLeafToLogical();
    
      auto tv3 = castOp(DataType::Half, tv2);
      fusion.addOutput(tv3);
    
      MatMulTileOptions gemm_tile;
      gemm_tile.cta_tile = GemmTile(192, 208, 64);
      gemm_tile.warp_tile = GemmTile(192, 104, 64);
    
      MatmulParams mparams;
      mparams.supported_vec_size = {8, 8, 8};
      mparams.mma_macro = MmaMacro::Hopper_64_104_16;
      mparams.tile_sizes = gemm_tile;
      mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor;
      mparams.async_gmem_load_operands = true;
      mparams.circular_buffer_options.circular_buffer_smem_write = false;
      mparams.circular_buffer_options.circular_buffer_smem_read = false;
      mparams.circular_buffer_options.smem_circular_buffer_stage = 4;
      mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1;
      mparams.splitk_factor = 1;
      mparams.use_smem_epilogue = true;
      mparams.cluster_dims = {1, 1, 1};
      mparams.promote_prologue_smem_reuse = true;
    
      constexpr int64_t M = 5320, N = 33928, K = 3464;
    
      auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
    
      auto a_ref = at::randn({M, K, 1}, options);
      auto b_ref = at::randn({1, K, N}, options);
      auto out_ref = at::matmul(a_ref.squeeze(), b_ref.squeeze()).to(at::kHalf);
      const std::vector<c10::IValue> inputs = {a_ref, b_ref};
    
      mparams.cparams.index_type = DataType::Int32;
    
      SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul)
          ->schedule(&fusion, &mparams);
    
      KernelExecutor ke;
      ke.compile(&fusion, inputs);
      auto outputs = ke.run(inputs);
      EXPECT_TRUE(at::allclose(outputs[0].as<at::Tensor>(), out_ref, 1e-5, 1e-5));
    }

    Successfully merging this pull request may close these issues.

    cudaErrorMisalignedAddress when sweeping matmul problems with NN, TN, and TT layouts.