Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)

// Represent the SMEM buffers for A and B
Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_M, NumMma_K, Tiles_K)
Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_N, NumMma_K, Tiles_K)

//
// Mma partitioning for A and B
Expand All @@ -218,7 +218,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)

if (thread0()) {
print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_256,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0)
print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0)
} __syncthreads();
Expand All @@ -230,7 +230,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
// - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively
// - The first mode of each descriptor represents the SMEM for a single MMA operation
Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K)
Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_M, NumMma_K, Tiles_K)
Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_N, NumMma_K, Tiles_K)

// TMEM Allocation
// On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM).
Expand All @@ -251,7 +251,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)

if (thread0()) {
print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2)
print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0)
Expand Down Expand Up @@ -300,14 +300,22 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K)
create_tma_multicast_mask<0,2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk);

// Calculate total bytes that TMA will transfer each tile to track completion, accounting for TMA.2SM
// NOTE: each tma_partition slice is an OFFSET VIEW into this CTA's stage
// buffer; sizeof(make_tensor_like(slice)) is already the full bytes
// delivered into this CTA. Do NOT multiply by the number of multicast
// participants, and do NOT size B from the full-N tile (the 2x1SM atom
// gives each CTA an N/2 slice) -- both over-expect the barrier, which
// then never fires: a silent deadlock, not an error. The only multiplier
// is size<0>(cluster_layout_vmnk) (== 2): the leader's barrier counts
// the PAIR's bytes for 2SM TMA loads.
int tma_transaction_bytes = size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tAsA))
+ size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tBsB));

if (thread0()) {
print("tAgA:\t"); print(tAgA); print("\n"); // tAgA: ArithTuple(_0,0) o (((_64,_128),_1),4):(((_1@0,_1@1),_0),_64@0)
print("tAsA:\t"); print(tAsA); print("\n"); // tAsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_8192,_1)):((_1,_0))
print("tBgB:\t"); print(tBgB); print("\n"); // tBgB: ArithTuple(_0,0) o (((_64,_256),_1),4):(((_1@0,_1@1),_0),_64@0)
print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_16384,_1)):((_1,_0))
print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_8192,_1)):((_1,_0))
printf("tma_transaction_bytes: %d\n", tma_transaction_bytes);
printf("tma_mcast_mask_a: %x\n", tma_mcast_mask_a);
printf("tma_mcast_mask_b: %x\n", tma_mcast_mask_b);
Expand Down Expand Up @@ -503,7 +511,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,

// Print and inspect mma_shape_A, and mma_shape_B for this example.
print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4)
print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4)
print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_128,_16),_1,_4)

// A and B tensors are swizzled in SMEM to improve MMA performance.
// * However, expressing swizzled layouts is very hard.
Expand All @@ -513,7 +521,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A,

// Print and inspect sA_layout and sB_layout for this example.
print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16)
print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16)
print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16)

// Now we can find the SMEM allocation size
using SMEMStorage = SharedStorage<TypeA, TypeB, decltype(sA_layout), decltype(sB_layout)>;
Expand Down