
Optimize TMA Store logic to handle pipelining and aliasing. #3961

Open · rdspring1 opened this issue Feb 25, 2025 · 0 comments
Pipelining

  • Move the TMA store wait so it occurs before the next write to shared memory in the persistent for-loop (one possible ordering is sketched below).
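
A minimal sketch of the intended ordering, not nvfuser's actual codegen: `electedThread`, `writeTileToSmem`, and `issueTmaStore` are hypothetical placeholders for the `Hopper::electSync` guard, the stmatrix writes, and the `cpAsyncBulkTensorTileS2G` calls in the kernel below, and the wait count of 1 assumes two shared-memory output buffers.

```cpp
for (int tile = first_tile; tile < num_tiles; tile += tile_stride) {  // persistent loop
  if (electedThread()) {  // hypothetical guard; cp.async.bulk groups are tracked per thread
    // Wait until at most one prior TMA store group is pending, i.e. the
    // buffer we are about to overwrite has drained (assumes double buffering).
    cpAsyncBulkWaitGroup<1LL>();
  }
  __syncthreads();             // propagate store completion to all writers
  writeTileToSmem(tile);       // hypothetical: the stmatrix stores
  __syncthreads();             // all shared-memory writers finished
  fenceAsyncProxy();           // generic-proxy writes visible to the async proxy
  if (electedThread()) {
    issueTmaStore(tile);       // hypothetical: cpAsyncBulkTensorTileS2G
    cpAsyncBulkCommitGroup();  // close this iteration's store group
  }
}
```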

Baseline matmul (Without Aliasing)

  • Insert cpAsyncBulkWaitGroup before the stmatrix writes.

Epilogue (With Aliasing)

  • For an epilogue matmul, the output tile can be aliased by the epilogue input, and the load for the epilogue input is issued by the load warp group, so the TMA store must complete before that load reuses the shared-memory buffer (see the sketch below).
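
A hedged sketch of the aliasing case (assumed structure, not nvfuser's actual codegen): `isLoadWarpGroup`, `warpGroupBarrier`, and `issueTmaLoadEpilogueInput` are hypothetical. Since `cp.async.bulk` groups are tracked per committing thread, the wait happens in the warp group that committed the store, and a barrier then releases the load warp group.

```cpp
if (!isLoadWarpGroup() && electedThread()) {  // hypothetical guards
  cpAsyncBulkWaitGroup<0LL>();  // committing thread waits for its TMA stores to drain
}
warpGroupBarrier();             // hypothetical: barrier spanning both warp groups
if (isLoadWarpGroup() && electedThread()) {
  // Safe to reuse the aliased buffer for the epilogue input's TMA load
  issueTmaLoadEpilogueInput();  // hypothetical: cpAsyncBulkTensorTileG2S
}
```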

Other

  • Only the selected thread that issues the TMA store should issue the TMA store commit.
  • Add fenceAsyncProxy intelligently in the insert_syncs pass, e.g., only once after all writes to shared memory and before the TMA store that consumes it (see the sketch after this list).
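
A minimal sketch of the desired insert_syncs placement (`writeSmemChunkA/B`, `electedThread`, and `issueTmaStore` are hypothetical): a single fenceAsyncProxy after the last generic-proxy write to the TMA store's source buffer, rather than one fence per write.

```cpp
writeSmemChunkA();          // hypothetical generic-proxy writes (e.g., stmatrix)
writeSmemChunkB();          // no fence needed between generic writes
__syncthreads();            // all writers done
fenceAsyncProxy();          // single fence: order generic writes before the async proxy
if (electedThread()) {      // hypothetical single-thread guard
  issueTmaStore();          // hypothetical: cpAsyncBulkTensorTileS2G
  cpAsyncBulkCommitGroup(); // committed only by the issuing thread
}
```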

CUDA Kernel

```cpp
for (persistent-loop) {
  // epilogue computation

  cpAsyncBulkWaitGroup<0LL>();  // <<< move the TMA store wait to occur before the stmatrix writes
#pragma unroll
  for (nvfuser_index_t i50 = 0; i50 < 16; ++i50) {
    if ((b27 && (i28 < (-(16 * i50))))) {
      stmatrix4(
          (uint32_t)((
              toSmem(T8) +
              ((((nvfuser_index_t)threadIdx.y) * 32768) +
               (((i50 / 4) * 8192) +
                ((i11 * 128) +
                 (((((((nvfuser_index_t)threadIdx.x) % 32) / 16) +
                    ((i50 % 4) * 2)) ^
                   (i11 % 8)) *
                  16)))))),
          (*reinterpret_cast<Array<uint32_t, 4, 1>*>(&T7[(8 * i50)])));
    }
  }
  asm volatile("bar.sync 0, %0;" : : "r"(num_threads) : "memory");
  fenceAsyncProxy();  // <<< single fence after all shared-memory writes, before the TMA store
  if (((Hopper::electSync(4294967295U) && b15) && b18)) {
#pragma unroll
    for (nvfuser_index_t i51 = 0; i51 < 4; ++i51) {
      Hopper::cpAsyncBulkTensorTileS2G(
          (Hopper::CpAsyncBulkTensorTileS2GIndex<2>{
              ptr13,
              (Array<nvfuser_index_t, 2, 1>{(i24 + (64 * i51)), i26})}),
          (i12 + (8192 * i51)));
    }
    cpAsyncBulkCommitGroup();  // only the thread issuing the TMA store commits it
  }
}
```
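
For reference, a hedged sketch of how the runtime helpers above typically map to Hopper PTX, per the CUDA PTX ISA; nvfuser's actual runtime definitions may differ.

```cpp
// Assumed definitions (sketch): the generic->async-proxy fence, and the
// commit/wait pair that brackets a bulk TMA store group.
__device__ inline void fenceAsyncProxy() {
  asm volatile("fence.proxy.async;" ::: "memory");
}
__device__ inline void cpAsyncBulkCommitGroup() {
  asm volatile("cp.async.bulk.commit_group;" ::: "memory");
}
template <long long keep_stages>
__device__ inline void cpAsyncBulkWaitGroup() {
  // Blocks until at most `keep_stages` of this thread's committed
  // bulk async groups are still pending.
  asm volatile("cp.async.bulk.wait_group %0;" ::"n"(keep_stages) : "memory");
}
```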
rdspring1 self-assigned this Feb 25, 2025