From 7f5f5335ad332334dce6681f2aa945f40c8b8b08 Mon Sep 17 00:00:00 2001 From: Chris Fregly Date: Thu, 11 Jun 2026 15:36:32 -0700 Subject: [PATCH] docs(cute tutorial 04): fix stale B-operand print annotations (2x1SM atom splits B N/2-per-CTA) and warn about the expect-tx deadlock pitfalls The SM100_MMA_F16BF16_2x1SM_SS atom splits the B operand N/2 per CTA across the pair; the inline '// printed:' annotations still showed full-N shapes (((_256,_16),...) where partition_shape_B actually returns ((_128,_16),...) with the tutorial's own TiledMMA, verified on CUTLASS main and 4.2.0). Also fixes the NumMma_M -> NumMma_N comment typos on the B-tensor modes, and adds a warning above the tma_transaction_bytes computation: the tma_partition slice is an offset view (do not multiply by multicast participants) and B is the N/2 slice (do not size from the full tile) -- both mistakes over-expect the barrier, which then never fires. Annotation/comment-only change. --- .../blackwell/04_mma_tma_2sm_sm100.cu | 22 +++++++++++++------ 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu b/examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu index 2d4799f9fe..ca25ae2587 100644 --- a/examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu +++ b/examples/cute/tutorial/blackwell/04_mma_tma_2sm_sm100.cu @@ -203,7 +203,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // Represent the SMEM buffers for A and B Tensor tCsA = shared_storage.tensor_sA(); // (MmaA, NumMma_M, NumMma_K, Tiles_K) - Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_M, NumMma_K, Tiles_K) + Tensor tCsB = shared_storage.tensor_sB(); // (MmaB, NumMma_N, NumMma_K, Tiles_K) // // Mma partitioning for A and B @@ -218,7 +218,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) if (thread0()) { print("tCgA:\t"); print(tCgA); print("\n"); // tCgA: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0) - print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_256,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0) + print("tCgB:\t"); print(tCgB); print("\n"); // tCgB: ArithTuple(_0,0) o ((_128,_16),_1,_4,4):((_1@1,_1@0),_0,_16@0,_64@0) print("tCgC:\t"); print(tCgC); print("\n"); // tCgC: gmem_ptr[32b](GMEM_ADDR_C + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0) print("tCgD:\t"); print(tCgD); print("\n"); // tCgD: gmem_ptr[32b](GMEM_ADDR_D + offset_for_mma_tile + offset_for_mma) o ((_128,_256),_1,_1):((256,_1),_0,_0) } __syncthreads(); @@ -230,7 +230,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) // - tCrA and tCrB provide descriptor views of tCsA and tCsB respectively // - The first mode of each descriptor represents the SMEM for a single MMA operation Tensor tCrA = cta_mma.make_fragment_A(tCsA); // (MmaA, NumMma_M, NumMma_K, Tiles_K) - Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_M, NumMma_K, Tiles_K) + Tensor tCrB = cta_mma.make_fragment_B(tCsB); // (MmaB, NumMma_N, NumMma_K, Tiles_K) // TMEM Allocation // On SM100 architecture, accumulators are stored exclusively in tensor memory (TMEM). @@ -251,7 +251,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) if (thread0()) { print("tCsA:\t"); print(tCsA); print("\n"); // tCsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_128,_16),_1,_4):((_64,_1),_0,_16) - print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_256,_16),_1,_4):((_64,_1),_0,_16) + print("tCsB:\t"); print(tCsB); print("\n"); // tCsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_128,_16),_1,_4):((_64,_1),_0,_16) print("tCrA:\t"); print(tCrA); print("\n"); // tCrA: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2) print("tCrB:\t"); print(tCrB); print("\n"); // tCrB: UMMA::DescriptorIterator o (_1,_1,_4):(_0,_0,_2) print("tCtAcc:\t"); print(tCtAcc); print("\n"); // tCtAcc: tmem_[32b](TMEM_ADDR) o ((_128,_256),_1,_1):((_65536,_1),_0,_0) @@ -300,6 +300,14 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) create_tma_multicast_mask<0,2>(cluster_layout_vmnk, cta_in_cluster_coord_vmnk); // Calculate total bytes that TMA will transfer each tile to track completion, accounting for TMA.2SM + // NOTE: each tma_partition slice is an OFFSET VIEW into this CTA's stage + // buffer; sizeof(make_tensor_like(slice)) is already the full bytes + // delivered into this CTA. Do NOT multiply by the number of multicast + // participants, and do NOT size B from the full-N tile (the 2x1SM atom + // gives each CTA an N/2 slice) -- both over-expect the barrier, which + // then never fires: a silent deadlock, not an error. The only multiplier + // is size<0>(cluster_layout_vmnk) (== 2): the leader's barrier counts + // the PAIR's bytes for 2SM TMA loads. int tma_transaction_bytes = size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tAsA)) + size<0>(cluster_layout_vmnk) * sizeof(make_tensor_like(tBsB)); @@ -307,7 +315,7 @@ gemm_device(ATensor mA, // (Gemm_M, Gemm_K) print("tAgA:\t"); print(tAgA); print("\n"); // tAgA: ArithTuple(_0,0) o (((_64,_128),_1),4):(((_1@0,_1@1),_0),_64@0) print("tAsA:\t"); print(tAsA); print("\n"); // tAsA: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_A) o ((_8192,_1)):((_1,_0)) print("tBgB:\t"); print(tBgB); print("\n"); // tBgB: ArithTuple(_0,0) o (((_64,_256),_1),4):(((_1@0,_1@1),_0),_64@0) - print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_16384,_1)):((_1,_0)) + print("tBsB:\t"); print(tBsB); print("\n"); // tBsB: Sw<3,4,3>_smem_ptr[16b](SMEM_ADDR_B) o ((_8192,_1)):((_1,_0)) printf("tma_transaction_bytes: %d\n", tma_transaction_bytes); printf("tma_mcast_mask_a: %x\n", tma_mcast_mask_a); printf("tma_mcast_mask_b: %x\n", tma_mcast_mask_b); @@ -503,7 +511,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A, // Print and inspect mma_shape_A, and mma_shape_B for this example. print("mma_shape_A:\t"); print(mma_shape_A); print("\n"); // mma_shape_A: ((_128,_16),_1,_4) - print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_256,_16),_1,_4) + print("mma_shape_B:\t"); print(mma_shape_B); print("\n"); // mma_shape_B: ((_128,_16),_1,_4) // A and B tensors are swizzled in SMEM to improve MMA performance. // * However, expressing swizzled layouts is very hard. @@ -513,7 +521,7 @@ void gemm_host_f16xf16_f32_f32_tnt(TypeA const* device_ptr_A, LayoutA layout_A, // Print and inspect sA_layout and sB_layout for this example. print("sA_layout:\t"); print(sA_layout); print("\n"); // sA_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16) - print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_256,_16),_1,_4):((_64,_1),_0,_16) + print("sB_layout:\t"); print(sB_layout); print("\n"); // sB_layout: Sw<3,4,3> o smem_ptr[16b](unset) o ((_128,_16),_1,_4):((_64,_1),_0,_16) // Now we can find the SMEM allocation size using SMEMStorage = SharedStorage;