From cf15bc7b6422ffc7d4226be2fb1f598fe12b455c Mon Sep 17 00:00:00 2001
From: Jungwook Park
Date: Mon, 27 Jan 2025 04:40:52 -0800
Subject: [PATCH 1/4] [AMD][BACKEND] Disable pingpong with non-local_load input.

Pingpong pass only expects to handle local_load ops as A/B. Avoid using
the transform when a different op is detected.
---
 .../TritonAMDGPUTransforms/BlockPingpong.cpp  | 60 ++++++++++---------
 1 file changed, 32 insertions(+), 28 deletions(-)

diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp
index 63681e941816..69cd89c6d0bb 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp
@@ -151,35 +151,39 @@ LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v,
                                         int64_t sliceWidth) {
   SmallVector<Operation *> slices;
   SmallVector<Operation *> subviews;
-  auto memDesc = v.getDefiningOp()->getOperand(0);
-  auto type = cast<ttg::MemDescType>(memDesc.getType());
-  SmallVector<int64_t> shape = llvm::to_vector(type.getShape());
-  Type elementType = type.getElementType();
-  int64_t kIdx = opIdx == 0 ? 1 : 0;
-  shape[kIdx] = sliceWidth;
-  // Each slice cannot be smaller than the smallest supported mfma width.
-  if (sliceWidth < 16)
-    return failure();
-  auto dotOperandEnc = ttg::DotOperandEncodingAttr::get(
-      builder.getContext(), opIdx, dotEncoding, kWidth);
-  auto subviewDescType = ttg::MemDescType::get(
-      shape, elementType, type.getEncoding(), type.getMemorySpace(),
-      type.getMutableMemory(), type.getAllocShape());
-  for (int i = 0; i < numSlices; i++) {
-    SmallVector<Value> offsetsVal;
-    SmallVector<int64_t> offsets = {0, 0};
-    offsets[kIdx] = i;
-    for (int64_t off : offsets) {
-      offsetsVal.push_back(constOffsets[off]);
+  // TODO: support transformed input to dot
+  if (auto maybeLocal = v.getDefiningOp<ttg::LocalLoadOp>()) {
+    auto memDesc = maybeLocal.getSrc();
+    auto type = cast<ttg::MemDescType>(memDesc.getType());
+    SmallVector<int64_t> shape = llvm::to_vector(type.getShape());
+    Type elementType = type.getElementType();
+    int64_t kIdx = opIdx == 0 ? 1 : 0;
+    shape[kIdx] = sliceWidth;
+    // Each slice cannot be smaller than the smallest supported mfma width.
+    if (sliceWidth < 16)
+      return failure();
+    auto dotOperandEnc = ttg::DotOperandEncodingAttr::get(
+        builder.getContext(), opIdx, dotEncoding, kWidth);
+    auto subviewDescType = ttg::MemDescType::get(
+        shape, elementType, type.getEncoding(), type.getMemorySpace(),
+        type.getMutableMemory(), type.getAllocShape());
+    for (int i = 0; i < numSlices; i++) {
+      SmallVector<Value> offsetsVal;
+      SmallVector<int64_t> offsets = {0, 0};
+      offsets[kIdx] = i;
+      for (int64_t off : offsets) {
+        offsetsVal.push_back(constOffsets[off]);
+      }
+      Value newSmem = builder.create<ttg::MemDescSubviewOp>(
+          v.getLoc(), subviewDescType, memDesc, offsetsVal);
+      Value prefetchSlice = builder.create<ttg::LocalLoadOp>(
+          v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
+          newSmem);
+      subviews.push_back(newSmem.getDefiningOp());
+      slices.push_back(prefetchSlice.getDefiningOp());
     }
-    Value newSmem = builder.create<ttg::MemDescSubviewOp>(
-        v.getLoc(), subviewDescType, memDesc, offsetsVal);
-    Value prefetchSlice = builder.create<ttg::LocalLoadOp>(
-        v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
-        newSmem);
-    subviews.push_back(newSmem.getDefiningOp());
-    slices.push_back(prefetchSlice.getDefiningOp());
-  }
+  } else
+    return failure();
   subViewOps.push_back(subviews);
   loadSliceOps.push_back(slices);
   return success();

From 596ef89b8feb314b0ea1aab134689d455bf467a2 Mon Sep 17 00:00:00 2001
From: Jungwook Park
Date: Tue, 28 Jan 2025 16:15:19 -0800
Subject: [PATCH 2/4] address review

- improve code
- add test
---
 test/TritonGPU/amd/amd-block-pingpong.mlir    | 72 +++++++++++++++++++
 .../TritonAMDGPUTransforms/BlockPingpong.cpp  | 62 ++++++++--------
 2 files changed, 103 insertions(+), 31 deletions(-)

diff --git a/test/TritonGPU/amd/amd-block-pingpong.mlir b/test/TritonGPU/amd/amd-block-pingpong.mlir
index a761cac37666..7bc7871e06d4 100644
--- a/test/TritonGPU/amd/amd-block-pingpong.mlir
+++ b/test/TritonGPU/amd/amd-block-pingpong.mlir
@@ -288,6 +288,78 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ

// -----

// CHECK-LABEL: pingpong_medium_cast
// CHECK-COUNT-2: local_load
// CHECK-NOT: setprio
// CHECK-NOT: barrier

#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0)
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
#shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
  tt.func public @pingpong_medium_cast(%arg0: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
    %cst = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma>
    %c1_i32
= arith.constant 1 : i32 + %cst_0 = arith.constant dense<64> : tensor<64x128xi32, #blocked> + %cst_1 = arith.constant dense<64> : tensor<256x64xi32, #blocked1> + %c0_i32 = arith.constant 0 : i32 + %c64_i32 = arith.constant 64 : i32 + %0 = tt.splat %arg0 : !tt.ptr -> tensor<256x1x!tt.ptr, #blocked1> + %1 = tt.get_program_id x : i32 + %2 = tt.splat %1 : i32 -> tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %3 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %4 = arith.addi %2, %3 : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> + %5 = tt.expand_dims %4 {axis = 1 : i32} : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %6 = tt.splat %arg6 : i32 -> tensor<256x1xi32, #blocked1> + %7 = arith.muli %5, %6 : tensor<256x1xi32, #blocked1> + %8 = tt.addptr %0, %7 : tensor<256x1x!tt.ptr, #blocked1>, tensor<256x1xi32, #blocked1> + %9 = tt.broadcast %8 : tensor<256x1x!tt.ptr, #blocked1> -> tensor<256x64x!tt.ptr, #blocked1> + %10 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> + %11 = tt.expand_dims %10 {axis = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x64xi32, #blocked1> + %12 = tt.broadcast %11 : tensor<1x64xi32, #blocked1> -> tensor<256x64xi32, #blocked1> + %13 = tt.addptr %9, %12 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %14 = tt.splat %arg1 : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked> + %15 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> + %16 = tt.expand_dims %15 {axis = 1 : i32} : tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %17 = tt.addptr %14, %16 : tensor<64x1x!tt.ptr, #blocked>, tensor<64x1xi32, #blocked> + %18 = tt.broadcast %17 : tensor<64x1x!tt.ptr, #blocked> -> tensor<64x128x!tt.ptr, #blocked> + %19 = tt.splat %arg7 : i32 -> tensor<64x128xi32, #blocked> + %20 = tt.addptr %18, %19 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %21 = ttg.local_alloc : () -> !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> + %22 = ttg.local_alloc : () -> !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> + %23 = ttg.memdesc_subview %21[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> + %24 = ttg.memdesc_subview %22[%c0_i32, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> + %25:6 = scf.for %arg10 = %c0_i32 to %c64_i32 step %c1_i32 iter_args(%arg11 = %cst, %arg12 = %13, %arg13 = %20, %arg14 = %c0_i32, %arg15 = %23, %arg16 = %24) -> (tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>) : i32 { + %26 = tt.addptr %arg12, %cst_1 : tensor<256x64x!tt.ptr, #blocked1>, tensor<256x64xi32, #blocked1> + %27 = tt.load %26 : tensor<256x64x!tt.ptr, #blocked1> + %28 = tt.addptr %arg13, %cst_0 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + %29 = tt.load %28 : tensor<64x128x!tt.ptr, #blocked> + %cast2 = tt.bitcast %29 : tensor<64x128xf16, #blocked> -> tensor<64x128xi16, #blocked> + %30 = ttg.local_load %arg15 : 
!ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>
      %31 = ttg.local_load %arg16 : !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable> -> tensor<64x128xi16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
      %cast = tt.bitcast %31 : tensor<64x128xi16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>
      %32 = tt.dot %30, %cast, %arg11 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma>
      %33 = arith.addi %arg14, %c1_i32 : i32
      %34 = arith.cmpi slt, %33, %c1_i32 : i32
      %35 = arith.select %34, %33, %c0_i32 : i32
      %36 = ttg.memdesc_subview %21[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
      ttg.local_store %27, %36 : tensor<256x64xf16, #blocked1> -> !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>
      %37 = ttg.memdesc_subview %22[%35, %c0_i32, %c0_i32] : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>
      ttg.local_store %cast2, %37 : tensor<64x128xi16, #blocked> -> !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>
      scf.yield %32, %26, %28, %35, %36, %37 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked1>, tensor<64x128x!tt.ptr, #blocked>, i32, !ttg.memdesc<256x64xf16, #shared, #ttg.shared_memory, mutable>, !ttg.memdesc<64x128xi16, #shared1, #ttg.shared_memory, mutable>
    }
    ttg.local_dealloc %21 : !ttg.memdesc<1x256x64xf16, #shared, #ttg.shared_memory, mutable>
    ttg.local_dealloc %22 : !ttg.memdesc<1x64x128xi16, #shared1, #ttg.shared_memory, mutable>
    tt.return
  }
}


// -----


// CHECK-LABEL: pingpong_reject
// CHECK-COUNT-2: local_load
// CHECK-NOT: local_load
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp
index 69cd89c6d0bb..718d62a81f59 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp
@@ -152,38 +152,38 @@ LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v,
   SmallVector<Operation *> slices;
   SmallVector<Operation *> subviews;
   // TODO: support transformed input to dot
-  if (auto maybeLocal = v.getDefiningOp<ttg::LocalLoadOp>()) {
-    auto memDesc = maybeLocal.getSrc();
-    auto type = cast<ttg::MemDescType>(memDesc.getType());
-    SmallVector<int64_t> shape = llvm::to_vector(type.getShape());
-    Type elementType = type.getElementType();
-    int64_t kIdx = opIdx == 0 ? 1 : 0;
-    shape[kIdx] = sliceWidth;
-    // Each slice cannot be smaller than the smallest supported mfma width.
-    if (sliceWidth < 16)
-      return failure();
-    auto dotOperandEnc = ttg::DotOperandEncodingAttr::get(
-        builder.getContext(), opIdx, dotEncoding, kWidth);
-    auto subviewDescType = ttg::MemDescType::get(
-        shape, elementType, type.getEncoding(), type.getMemorySpace(),
-        type.getMutableMemory(), type.getAllocShape());
-    for (int i = 0; i < numSlices; i++) {
-      SmallVector<Value> offsetsVal;
-      SmallVector<int64_t> offsets = {0, 0};
-      offsets[kIdx] = i;
-      for (int64_t off : offsets) {
-        offsetsVal.push_back(constOffsets[off]);
-      }
-      Value newSmem = builder.create<ttg::MemDescSubviewOp>(
-          v.getLoc(), subviewDescType, memDesc, offsetsVal);
-      Value prefetchSlice = builder.create<ttg::LocalLoadOp>(
-          v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
-          newSmem);
-      subviews.push_back(newSmem.getDefiningOp());
-      slices.push_back(prefetchSlice.getDefiningOp());
-    }
-  } else
+  ttg::LocalLoadOp localLoad;
+  if !(localLoad = v.getDefiningOp<ttg::LocalLoadOp>())
+    return failure();
+  auto memDesc = localLoad.getSrc();
+  auto type = cast<ttg::MemDescType>(memDesc.getType());
+  SmallVector<int64_t> shape = llvm::to_vector(type.getShape());
+  Type elementType = type.getElementType();
+  int64_t kIdx = opIdx == 0 ? 1 : 0;
+  shape[kIdx] = sliceWidth;
+  // Each slice cannot be smaller than the smallest supported mfma width.
+  if (sliceWidth < 16)
     return failure();
+  auto dotOperandEnc = ttg::DotOperandEncodingAttr::get(
+      builder.getContext(), opIdx, dotEncoding, kWidth);
+  auto subviewDescType = ttg::MemDescType::get(
+      shape, elementType, type.getEncoding(), type.getMemorySpace(),
+      type.getMutableMemory(), type.getAllocShape());
+  for (int i = 0; i < numSlices; i++) {
+    SmallVector<Value> offsetsVal;
+    SmallVector<int64_t> offsets = {0, 0};
+    offsets[kIdx] = i;
+    for (int64_t off : offsets) {
+      offsetsVal.push_back(constOffsets[off]);
+    }
+    Value newSmem = builder.create<ttg::MemDescSubviewOp>(
+        v.getLoc(), subviewDescType, memDesc, offsetsVal);
+    Value prefetchSlice = builder.create<ttg::LocalLoadOp>(
+        v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
+        newSmem);
+    subviews.push_back(newSmem.getDefiningOp());
+    slices.push_back(prefetchSlice.getDefiningOp());
+  }
   subViewOps.push_back(subviews);
   loadSliceOps.push_back(slices);
   return success();

From e265975ceb827d1e7069201607d2a5daeaaffeee Mon Sep 17 00:00:00 2001
From: Jungwook Park
Date: Tue, 28 Jan 2025 16:34:30 -0800
Subject: [PATCH 3/4] fix

---
 third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp
index 718d62a81f59..d17e39aa70ca 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp
@@ -152,8 +152,8 @@ LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v,
   SmallVector<Operation *> slices;
   SmallVector<Operation *> subviews;
   // TODO: support transformed input to dot
-  ttg::LocalLoadOp localLoad;
-  if !(localLoad = v.getDefiningOp<ttg::LocalLoadOp>())
+  auto localLoad = v.getDefiningOp<ttg::LocalLoadOp>();
+  if (!localLoad)
     return failure();
   auto memDesc = localLoad.getSrc();
   auto type = cast<ttg::MemDescType>(memDesc.getType());

From 35f80978175559b8807be349ba7bffcf98c6baaf Mon Sep 17 00:00:00 2001
From: Jungwook Park
Date: Wed, 29 Jan 2025 03:35:18 -0800
Subject: [PATCH 4/4] Cleanup tests

---
 test/TritonGPU/amd/amd-block-pingpong.mlir | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/test/TritonGPU/amd/amd-block-pingpong.mlir b/test/TritonGPU/amd/amd-block-pingpong.mlir
index 7bc7871e06d4..684132d0984c 100644
---
a/test/TritonGPU/amd/amd-block-pingpong.mlir +++ b/test/TritonGPU/amd/amd-block-pingpong.mlir @@ -128,7 +128,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> -#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0) #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> #shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { @@ -227,7 +226,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> -#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0) #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> #shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { @@ -295,7 +293,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> -#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0) #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> #shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { @@ -368,7 +365,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.targ #blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}> #blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}> -#loc = loc("/home/jung/rocm/triton/python/perf-kernels/tools/tune_gemm/matmul_kernel.py":6:0) #mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>#shared = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}> #shared1 = #ttg.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}> module attributes 
{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {